mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Compare commits
26 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd83f7520e | ||
|
|
b9a4e7a09b | ||
|
|
c30e57ede8 | ||
|
|
0dba1b336d | ||
|
|
820afe9319 | ||
|
|
5a97f4bc75 | ||
|
|
94da404cc5 | ||
|
|
1da476d858 | ||
|
|
1daaff6bd4 | ||
|
|
e252e44403 | ||
|
|
778ad8abd2 | ||
|
|
68cf381b50 | ||
|
|
337f73e711 | ||
|
|
04ba966a6e | ||
|
|
71c8cf84e0 | ||
|
|
db1aec94e5 | ||
|
|
553e1868e1 | ||
|
|
938ceb49b2 | ||
|
|
c0f03b79a8 | ||
|
|
a492638133 | ||
|
|
e17d6c8ebf | ||
|
|
ffcfe5ea3e | ||
|
|
719e18adb6 | ||
|
|
92d471daf5 | ||
|
|
66babf9ee1 | ||
|
|
60df2df324 |
201
.agents/skills/lora-manager-e2e/SKILL.md
Normal file
201
.agents/skills/lora-manager-e2e/SKILL.md
Normal file
@@ -0,0 +1,201 @@
|
||||
---
|
||||
name: lora-manager-e2e
|
||||
description: End-to-end testing and validation for LoRa Manager features. Use when performing automated E2E validation of LoRa Manager standalone mode, including starting/restarting the server, using Chrome DevTools MCP to interact with the web UI at http://127.0.0.1:8188/loras, and verifying frontend-to-backend functionality. Covers workflow validation, UI interaction testing, and integration testing between the standalone Python backend and the browser frontend.
|
||||
---
|
||||
|
||||
# LoRa Manager E2E Testing
|
||||
|
||||
This skill provides workflows and utilities for end-to-end testing of LoRa Manager using Chrome DevTools MCP.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- LoRa Manager project cloned and dependencies installed (`pip install -r requirements.txt`)
|
||||
- Chrome browser available for debugging
|
||||
- Chrome DevTools MCP connected
|
||||
|
||||
## Quick Start Workflow
|
||||
|
||||
### 1. Start LoRa Manager Standalone
|
||||
|
||||
```python
|
||||
# Use the provided script to start the server
|
||||
python .agents/skills/lora-manager-e2e/scripts/start_server.py --port 8188
|
||||
```
|
||||
|
||||
Or manually:
|
||||
```bash
|
||||
cd /home/miao/workspace/ComfyUI/custom_nodes/ComfyUI-Lora-Manager
|
||||
python standalone.py --port 8188
|
||||
```
|
||||
|
||||
Wait for server ready message before proceeding.
|
||||
|
||||
### 2. Open Chrome Debug Mode
|
||||
|
||||
```bash
|
||||
# Chrome with remote debugging on port 9222
|
||||
google-chrome --remote-debugging-port=9222 --user-data-dir=/tmp/chrome-lora-manager http://127.0.0.1:8188/loras
|
||||
```
|
||||
|
||||
### 3. Connect Chrome DevTools MCP
|
||||
|
||||
Ensure the MCP server is connected to Chrome at `http://localhost:9222`.
|
||||
|
||||
### 4. Navigate and Interact
|
||||
|
||||
Use Chrome DevTools MCP tools to:
|
||||
- Take snapshots: `take_snapshot`
|
||||
- Click elements: `click`
|
||||
- Fill forms: `fill` or `fill_form`
|
||||
- Evaluate scripts: `evaluate_script`
|
||||
- Wait for elements: `wait_for`
|
||||
|
||||
## Common E2E Test Patterns
|
||||
|
||||
### Pattern: Full Page Load Verification
|
||||
|
||||
```python
|
||||
# Navigate to LoRA list page
|
||||
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||
|
||||
# Wait for page to load
|
||||
wait_for(text="LoRAs", timeout=10000)
|
||||
|
||||
# Take snapshot to verify UI state
|
||||
snapshot = take_snapshot()
|
||||
```
|
||||
|
||||
### Pattern: Restart Server for Configuration Changes
|
||||
|
||||
```python
|
||||
# Stop current server (if running)
|
||||
# Start with new configuration
|
||||
python .agents/skills/lora-manager-e2e/scripts/start_server.py --port 8188 --restart
|
||||
|
||||
# Wait and refresh browser
|
||||
navigate_page(type="reload", ignoreCache=True)
|
||||
wait_for(text="LoRAs", timeout=15000)
|
||||
```
|
||||
|
||||
### Pattern: Verify Backend API via Frontend
|
||||
|
||||
```python
|
||||
# Execute script in browser to call backend API
|
||||
result = evaluate_script(function="""
|
||||
async () => {
|
||||
const response = await fetch('/loras/api/list');
|
||||
const data = await response.json();
|
||||
return { count: data.length, firstItem: data[0]?.name };
|
||||
}
|
||||
""")
|
||||
```
|
||||
|
||||
### Pattern: Form Submission Flow
|
||||
|
||||
```python
|
||||
# Fill a form (e.g., search or filter)
|
||||
fill_form(elements=[
|
||||
{"uid": "search-input", "value": "character"},
|
||||
])
|
||||
|
||||
# Click submit button
|
||||
click(uid="search-button")
|
||||
|
||||
# Wait for results
|
||||
wait_for(text="Results", timeout=5000)
|
||||
|
||||
# Verify results via snapshot
|
||||
snapshot = take_snapshot()
|
||||
```
|
||||
|
||||
### Pattern: Modal Dialog Interaction
|
||||
|
||||
```python
|
||||
# Open modal (e.g., add LoRA)
|
||||
click(uid="add-lora-button")
|
||||
|
||||
# Wait for modal to appear
|
||||
wait_for(text="Add LoRA", timeout=3000)
|
||||
|
||||
# Fill modal form
|
||||
fill_form(elements=[
|
||||
{"uid": "lora-name", "value": "Test LoRA"},
|
||||
{"uid": "lora-path", "value": "/path/to/lora.safetensors"},
|
||||
])
|
||||
|
||||
# Submit
|
||||
click(uid="modal-submit-button")
|
||||
|
||||
# Wait for success message or close
|
||||
wait_for(text="Success", timeout=5000)
|
||||
```
|
||||
|
||||
## Available Scripts
|
||||
|
||||
### scripts/start_server.py
|
||||
|
||||
Starts or restarts the LoRa Manager standalone server.
|
||||
|
||||
```bash
|
||||
python scripts/start_server.py [--port PORT] [--restart] [--wait]
|
||||
```
|
||||
|
||||
Options:
|
||||
- `--port`: Server port (default: 8188)
|
||||
- `--restart`: Kill existing server before starting
|
||||
- `--wait`: Wait for server to be ready before exiting
|
||||
|
||||
### scripts/wait_for_server.py
|
||||
|
||||
Polls server until ready or timeout.
|
||||
|
||||
```bash
|
||||
python scripts/wait_for_server.py [--port PORT] [--timeout SECONDS]
|
||||
```
|
||||
|
||||
## Test Scenarios Reference
|
||||
|
||||
See [references/test-scenarios.md](references/test-scenarios.md) for detailed test scenarios including:
|
||||
- LoRA list display and filtering
|
||||
- Model metadata editing
|
||||
- Recipe creation and management
|
||||
- Settings configuration
|
||||
- Import/export functionality
|
||||
|
||||
## Network Request Verification
|
||||
|
||||
Use `list_network_requests` and `get_network_request` to verify API calls:
|
||||
|
||||
```python
|
||||
# List recent XHR/fetch requests
|
||||
requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||
|
||||
# Get details of specific request
|
||||
details = get_network_request(reqid=123)
|
||||
```
|
||||
|
||||
## Console Message Monitoring
|
||||
|
||||
```python
|
||||
# Check for errors or warnings
|
||||
messages = list_console_messages(types=["error", "warn"])
|
||||
```
|
||||
|
||||
## Performance Testing
|
||||
|
||||
```python
|
||||
# Start performance trace
|
||||
performance_start_trace(reload=True, autoStop=False)
|
||||
|
||||
# Perform actions...
|
||||
|
||||
# Stop and analyze
|
||||
results = performance_stop_trace()
|
||||
```
|
||||
|
||||
## Cleanup
|
||||
|
||||
Always ensure proper cleanup after tests:
|
||||
1. Stop the standalone server
|
||||
2. Close browser pages (keep at least one open)
|
||||
3. Clear temporary data if needed
|
||||
324
.agents/skills/lora-manager-e2e/references/mcp-cheatsheet.md
Normal file
324
.agents/skills/lora-manager-e2e/references/mcp-cheatsheet.md
Normal file
@@ -0,0 +1,324 @@
|
||||
# Chrome DevTools MCP Cheatsheet for LoRa Manager
|
||||
|
||||
Quick reference for common MCP commands used in LoRa Manager E2E testing.
|
||||
|
||||
## Navigation
|
||||
|
||||
```python
|
||||
# Navigate to LoRA list page
|
||||
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||
|
||||
# Reload page with cache clear
|
||||
navigate_page(type="reload", ignoreCache=True)
|
||||
|
||||
# Go back/forward
|
||||
navigate_page(type="back")
|
||||
navigate_page(type="forward")
|
||||
```
|
||||
|
||||
## Waiting
|
||||
|
||||
```python
|
||||
# Wait for text to appear
|
||||
wait_for(text="LoRAs", timeout=10000)
|
||||
|
||||
# Wait for specific element (via evaluate_script)
|
||||
evaluate_script(function="""
|
||||
() => {
|
||||
return new Promise((resolve) => {
|
||||
const check = () => {
|
||||
if (document.querySelector('.lora-card')) {
|
||||
resolve(true);
|
||||
} else {
|
||||
setTimeout(check, 100);
|
||||
}
|
||||
};
|
||||
check();
|
||||
});
|
||||
}
|
||||
""")
|
||||
```
|
||||
|
||||
## Taking Snapshots
|
||||
|
||||
```python
|
||||
# Full page snapshot
|
||||
snapshot = take_snapshot()
|
||||
|
||||
# Verbose snapshot (more details)
|
||||
snapshot = take_snapshot(verbose=True)
|
||||
|
||||
# Save to file
|
||||
take_snapshot(filePath="test-snapshots/page-load.json")
|
||||
```
|
||||
|
||||
## Element Interaction
|
||||
|
||||
```python
|
||||
# Click element
|
||||
click(uid="element-uid-from-snapshot")
|
||||
|
||||
# Double click
|
||||
click(uid="element-uid", dblClick=True)
|
||||
|
||||
# Fill input
|
||||
fill(uid="search-input", value="test query")
|
||||
|
||||
# Fill multiple inputs
|
||||
fill_form(elements=[
|
||||
{"uid": "input-1", "value": "value 1"},
|
||||
{"uid": "input-2", "value": "value 2"},
|
||||
])
|
||||
|
||||
# Hover
|
||||
hover(uid="lora-card-1")
|
||||
|
||||
# Upload file
|
||||
upload_file(uid="file-input", filePath="/path/to/file.safetensors")
|
||||
```
|
||||
|
||||
## Keyboard Input
|
||||
|
||||
```python
|
||||
# Press key
|
||||
press_key(key="Enter")
|
||||
press_key(key="Escape")
|
||||
press_key(key="Tab")
|
||||
|
||||
# Keyboard shortcuts
|
||||
press_key(key="Control+A") # Select all
|
||||
press_key(key="Control+F") # Find
|
||||
```
|
||||
|
||||
## JavaScript Evaluation
|
||||
|
||||
```python
|
||||
# Simple evaluation
|
||||
result = evaluate_script(function="() => document.title")
|
||||
|
||||
# Async evaluation
|
||||
result = evaluate_script(function="""
|
||||
async () => {
|
||||
const response = await fetch('/loras/api/list');
|
||||
return await response.json();
|
||||
}
|
||||
""")
|
||||
|
||||
# Check element existence
|
||||
exists = evaluate_script(function="""
|
||||
() => document.querySelector('.lora-card') !== null
|
||||
""")
|
||||
|
||||
# Get element count
|
||||
count = evaluate_script(function="""
|
||||
() => document.querySelectorAll('.lora-card').length
|
||||
""")
|
||||
```
|
||||
|
||||
## Network Monitoring
|
||||
|
||||
```python
|
||||
# List all network requests
|
||||
requests = list_network_requests()
|
||||
|
||||
# Filter by resource type
|
||||
xhr_requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||
|
||||
# Get specific request details
|
||||
details = get_network_request(reqid=123)
|
||||
|
||||
# Include preserved requests from previous navigations
|
||||
all_requests = list_network_requests(includePreservedRequests=True)
|
||||
```
|
||||
|
||||
## Console Monitoring
|
||||
|
||||
```python
|
||||
# List all console messages
|
||||
messages = list_console_messages()
|
||||
|
||||
# Filter by type
|
||||
errors = list_console_messages(types=["error", "warn"])
|
||||
|
||||
# Include preserved messages
|
||||
all_messages = list_console_messages(includePreservedMessages=True)
|
||||
|
||||
# Get specific message
|
||||
details = get_console_message(msgid=1)
|
||||
```
|
||||
|
||||
## Performance Testing
|
||||
|
||||
```python
|
||||
# Start trace with page reload
|
||||
performance_start_trace(reload=True, autoStop=False)
|
||||
|
||||
# Start trace without reload
|
||||
performance_start_trace(reload=False, autoStop=True, filePath="trace.json.gz")
|
||||
|
||||
# Stop trace
|
||||
results = performance_stop_trace()
|
||||
|
||||
# Stop and save
|
||||
performance_stop_trace(filePath="trace-results.json.gz")
|
||||
|
||||
# Analyze specific insight
|
||||
insight = performance_analyze_insight(
|
||||
insightSetId="results.insightSets[0].id",
|
||||
insightName="LCPBreakdown"
|
||||
)
|
||||
```
|
||||
|
||||
## Page Management
|
||||
|
||||
```python
|
||||
# List open pages
|
||||
pages = list_pages()
|
||||
|
||||
# Select a page
|
||||
select_page(pageId=0, bringToFront=True)
|
||||
|
||||
# Create new page
|
||||
new_page(url="http://127.0.0.1:8188/loras")
|
||||
|
||||
# Close page (keep at least one open!)
|
||||
close_page(pageId=1)
|
||||
|
||||
# Resize page
|
||||
resize_page(width=1920, height=1080)
|
||||
```
|
||||
|
||||
## Screenshots
|
||||
|
||||
```python
|
||||
# Full page screenshot
|
||||
take_screenshot(fullPage=True)
|
||||
|
||||
# Viewport screenshot
|
||||
take_screenshot()
|
||||
|
||||
# Element screenshot
|
||||
take_screenshot(uid="lora-card-1")
|
||||
|
||||
# Save to file
|
||||
take_screenshot(filePath="screenshots/page.png", format="png")
|
||||
|
||||
# JPEG with quality
|
||||
take_screenshot(filePath="screenshots/page.jpg", format="jpeg", quality=90)
|
||||
```
|
||||
|
||||
## Dialog Handling
|
||||
|
||||
```python
|
||||
# Accept dialog
|
||||
handle_dialog(action="accept")
|
||||
|
||||
# Accept with text input
|
||||
handle_dialog(action="accept", promptText="user input")
|
||||
|
||||
# Dismiss dialog
|
||||
handle_dialog(action="dismiss")
|
||||
```
|
||||
|
||||
## Device Emulation
|
||||
|
||||
```python
|
||||
# Mobile viewport
|
||||
emulate(viewport={"width": 375, "height": 667, "isMobile": True, "hasTouch": True})
|
||||
|
||||
# Tablet viewport
|
||||
emulate(viewport={"width": 768, "height": 1024, "isMobile": True, "hasTouch": True})
|
||||
|
||||
# Desktop viewport
|
||||
emulate(viewport={"width": 1920, "height": 1080})
|
||||
|
||||
# Network throttling
|
||||
emulate(networkConditions="Slow 3G")
|
||||
emulate(networkConditions="Fast 4G")
|
||||
|
||||
# CPU throttling
|
||||
emulate(cpuThrottlingRate=4) # 4x slowdown
|
||||
|
||||
# Geolocation
|
||||
emulate(geolocation={"latitude": 37.7749, "longitude": -122.4194})
|
||||
|
||||
# User agent
|
||||
emulate(userAgent="Mozilla/5.0 (Custom)")
|
||||
|
||||
# Reset emulation
|
||||
emulate(viewport=None, networkConditions="No emulation", userAgent=None)
|
||||
```
|
||||
|
||||
## Drag and Drop
|
||||
|
||||
```python
|
||||
# Drag element to another
|
||||
drag(from_uid="draggable-item", to_uid="drop-zone")
|
||||
```
|
||||
|
||||
## Common LoRa Manager Test Patterns
|
||||
|
||||
### Verify LoRA Cards Loaded
|
||||
|
||||
```python
|
||||
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||
wait_for(text="LoRAs", timeout=10000)
|
||||
|
||||
# Check if cards loaded
|
||||
result = evaluate_script(function="""
|
||||
() => {
|
||||
const cards = document.querySelectorAll('.lora-card');
|
||||
return {
|
||||
count: cards.length,
|
||||
hasData: cards.length > 0
|
||||
};
|
||||
}
|
||||
""")
|
||||
```
|
||||
|
||||
### Search and Verify Results
|
||||
|
||||
```python
|
||||
fill(uid="search-input", value="character")
|
||||
press_key(key="Enter")
|
||||
wait_for(timeout=2000) # Wait for debounce
|
||||
|
||||
# Check results
|
||||
result = evaluate_script(function="""
|
||||
() => {
|
||||
const cards = document.querySelectorAll('.lora-card');
|
||||
const names = Array.from(cards).map(c => c.dataset.name || c.textContent);
|
||||
return { count: cards.length, names };
|
||||
}
|
||||
""")
|
||||
```
|
||||
|
||||
### Check API Response
|
||||
|
||||
```python
|
||||
# Trigger API call
|
||||
evaluate_script(function="""
|
||||
() => window.loraApiCallPromise = fetch('/loras/api/list').then(r => r.json())
|
||||
""")
|
||||
|
||||
# Wait and get result
|
||||
import time
|
||||
time.sleep(1)
|
||||
|
||||
result = evaluate_script(function="""
|
||||
async () => await window.loraApiCallPromise
|
||||
""")
|
||||
```
|
||||
|
||||
### Monitor Console for Errors
|
||||
|
||||
```python
|
||||
# Before test: clear console (navigate reloads)
|
||||
navigate_page(type="reload")
|
||||
|
||||
# ... perform actions ...
|
||||
|
||||
# Check for errors
|
||||
errors = list_console_messages(types=["error"])
|
||||
assert len(errors) == 0, f"Console errors: {errors}"
|
||||
```
|
||||
272
.agents/skills/lora-manager-e2e/references/test-scenarios.md
Normal file
272
.agents/skills/lora-manager-e2e/references/test-scenarios.md
Normal file
@@ -0,0 +1,272 @@
|
||||
# LoRa Manager E2E Test Scenarios
|
||||
|
||||
This document provides detailed test scenarios for end-to-end validation of LoRa Manager features.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
1. [LoRA List Page](#lora-list-page)
|
||||
2. [Model Details](#model-details)
|
||||
3. [Recipes](#recipes)
|
||||
4. [Settings](#settings)
|
||||
5. [Import/Export](#importexport)
|
||||
|
||||
---
|
||||
|
||||
## LoRA List Page
|
||||
|
||||
### Scenario: Page Load and Display
|
||||
|
||||
**Objective**: Verify the LoRA list page loads correctly and displays models.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to `http://127.0.0.1:8188/loras`
|
||||
2. Wait for page title "LoRAs" to appear
|
||||
3. Take snapshot to verify:
|
||||
- Header with "LoRAs" title is visible
|
||||
- Search/filter controls are present
|
||||
- Grid/list view toggle exists
|
||||
- LoRA cards are displayed (if models exist)
|
||||
- Pagination controls (if applicable)
|
||||
|
||||
**Expected Result**: Page loads without errors, UI elements are present.
|
||||
|
||||
### Scenario: Search Functionality
|
||||
|
||||
**Objective**: Verify search filters LoRA models correctly.
|
||||
|
||||
**Steps**:
|
||||
1. Ensure at least one LoRA exists with known name (e.g., "test-character")
|
||||
2. Navigate to LoRA list page
|
||||
3. Enter search term in search box: "test"
|
||||
4. Press Enter or click search button
|
||||
5. Wait for results to update
|
||||
|
||||
**Expected Result**: Only LoRAs matching search term are displayed.
|
||||
|
||||
**Verification Script**:
|
||||
```python
|
||||
# After search, verify filtered results
|
||||
evaluate_script(function="""
|
||||
() => {
|
||||
const cards = document.querySelectorAll('.lora-card');
|
||||
const names = Array.from(cards).map(c => c.dataset.name);
|
||||
return { count: cards.length, names };
|
||||
}
|
||||
""")
|
||||
```
|
||||
|
||||
### Scenario: Filter by Tags
|
||||
|
||||
**Objective**: Verify tag filtering works correctly.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to LoRA list page
|
||||
2. Click on a tag (e.g., "character", "style")
|
||||
3. Wait for filtered results
|
||||
|
||||
**Expected Result**: Only LoRAs with selected tag are displayed.
|
||||
|
||||
### Scenario: View Mode Toggle
|
||||
|
||||
**Objective**: Verify grid/list view toggle works.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to LoRA list page
|
||||
2. Click list view button
|
||||
3. Verify list layout
|
||||
4. Click grid view button
|
||||
5. Verify grid layout
|
||||
|
||||
**Expected Result**: View mode changes correctly, layout updates.
|
||||
|
||||
---
|
||||
|
||||
## Model Details
|
||||
|
||||
### Scenario: Open Model Details
|
||||
|
||||
**Objective**: Verify clicking a LoRA opens its details.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to LoRA list page
|
||||
2. Click on a LoRA card
|
||||
3. Wait for details panel/modal to open
|
||||
|
||||
**Expected Result**: Details panel shows:
|
||||
- Model name
|
||||
- Preview image
|
||||
- Metadata (trigger words, tags, etc.)
|
||||
- Action buttons (edit, delete, etc.)
|
||||
|
||||
### Scenario: Edit Model Metadata
|
||||
|
||||
**Objective**: Verify metadata editing works end-to-end.
|
||||
|
||||
**Steps**:
|
||||
1. Open a LoRA's details
|
||||
2. Click "Edit" button
|
||||
3. Modify trigger words field
|
||||
4. Add/remove tags
|
||||
5. Save changes
|
||||
6. Refresh page
|
||||
7. Reopen the same LoRA
|
||||
|
||||
**Expected Result**: Changes persist after refresh.
|
||||
|
||||
### Scenario: Delete Model
|
||||
|
||||
**Objective**: Verify model deletion works.
|
||||
|
||||
**Steps**:
|
||||
1. Open a LoRA's details
|
||||
2. Click "Delete" button
|
||||
3. Confirm deletion in dialog
|
||||
4. Wait for removal
|
||||
|
||||
**Expected Result**: Model removed from list, success message shown.
|
||||
|
||||
---
|
||||
|
||||
## Recipes
|
||||
|
||||
### Scenario: Recipe List Display
|
||||
|
||||
**Objective**: Verify recipes page loads and displays recipes.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to `http://127.0.0.1:8188/recipes`
|
||||
2. Wait for "Recipes" title
|
||||
3. Take snapshot
|
||||
|
||||
**Expected Result**: Recipe list displayed with cards/items.
|
||||
|
||||
### Scenario: Create New Recipe
|
||||
|
||||
**Objective**: Verify recipe creation workflow.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to recipes page
|
||||
2. Click "New Recipe" button
|
||||
3. Fill recipe form:
|
||||
- Name: "Test Recipe"
|
||||
- Description: "E2E test recipe"
|
||||
- Add LoRA models
|
||||
4. Save recipe
|
||||
5. Verify recipe appears in list
|
||||
|
||||
**Expected Result**: New recipe created and displayed.
|
||||
|
||||
### Scenario: Apply Recipe
|
||||
|
||||
**Objective**: Verify applying a recipe to ComfyUI.
|
||||
|
||||
**Steps**:
|
||||
1. Open a recipe
|
||||
2. Click "Apply" or "Load in ComfyUI"
|
||||
3. Verify action completes
|
||||
|
||||
**Expected Result**: Recipe applied successfully.
|
||||
|
||||
---
|
||||
|
||||
## Settings
|
||||
|
||||
### Scenario: Settings Page Load
|
||||
|
||||
**Objective**: Verify settings page displays correctly.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to `http://127.0.0.1:8188/settings`
|
||||
2. Wait for "Settings" title
|
||||
3. Take snapshot
|
||||
|
||||
**Expected Result**: Settings form with various options displayed.
|
||||
|
||||
### Scenario: Change Setting and Restart
|
||||
|
||||
**Objective**: Verify settings persist after restart.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to settings page
|
||||
2. Change a setting (e.g., default view mode)
|
||||
3. Save settings
|
||||
4. Restart server: `python scripts/start_server.py --restart --wait`
|
||||
5. Refresh browser page
|
||||
6. Navigate to settings
|
||||
|
||||
**Expected Result**: Changed setting value persists.
|
||||
|
||||
---
|
||||
|
||||
## Import/Export
|
||||
|
||||
### Scenario: Export Models List
|
||||
|
||||
**Objective**: Verify export functionality.
|
||||
|
||||
**Steps**:
|
||||
1. Navigate to LoRA list
|
||||
2. Click "Export" button
|
||||
3. Select format (JSON/CSV)
|
||||
4. Download file
|
||||
|
||||
**Expected Result**: File downloaded with correct data.
|
||||
|
||||
### Scenario: Import Models
|
||||
|
||||
**Objective**: Verify import functionality.
|
||||
|
||||
**Steps**:
|
||||
1. Prepare import file
|
||||
2. Navigate to import page
|
||||
3. Upload file
|
||||
4. Verify import results
|
||||
|
||||
**Expected Result**: Models imported successfully, confirmation shown.
|
||||
|
||||
---
|
||||
|
||||
## API Integration Tests
|
||||
|
||||
### Scenario: Verify API Endpoints
|
||||
|
||||
**Objective**: Verify backend API responds correctly.
|
||||
|
||||
**Test via browser console**:
|
||||
```javascript
|
||||
// List LoRAs
|
||||
fetch('/loras/api/list').then(r => r.json()).then(console.log)
|
||||
|
||||
// Get LoRA details
|
||||
fetch('/loras/api/detail/<id>').then(r => r.json()).then(console.log)
|
||||
|
||||
// Search LoRAs
|
||||
fetch('/loras/api/search?q=test').then(r => r.json()).then(console.log)
|
||||
```
|
||||
|
||||
**Expected Result**: APIs return valid JSON with expected structure.
|
||||
|
||||
---
|
||||
|
||||
## Console Error Monitoring
|
||||
|
||||
During all tests, monitor browser console for errors:
|
||||
|
||||
```python
|
||||
# Check for JavaScript errors
|
||||
messages = list_console_messages(types=["error"])
|
||||
assert len(messages) == 0, f"Console errors found: {messages}"
|
||||
```
|
||||
|
||||
## Network Request Verification
|
||||
|
||||
Verify key API calls are made:
|
||||
|
||||
```python
|
||||
# List XHR requests
|
||||
requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||
|
||||
# Look for specific endpoints
|
||||
lora_list_requests = [r for r in requests if "/api/list" in r.get("url", "")]
|
||||
assert len(lora_list_requests) > 0, "LoRA list API not called"
|
||||
```
|
||||
193
.agents/skills/lora-manager-e2e/scripts/example_e2e_test.py
Executable file
193
.agents/skills/lora-manager-e2e/scripts/example_e2e_test.py
Executable file
@@ -0,0 +1,193 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Example E2E test demonstrating LoRa Manager testing workflow.
|
||||
|
||||
This script shows how to:
|
||||
1. Start the standalone server
|
||||
2. Use Chrome DevTools MCP to interact with the UI
|
||||
3. Verify functionality end-to-end
|
||||
|
||||
Note: This is a template. Actual execution requires Chrome DevTools MCP.
|
||||
"""
|
||||
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
def run_test():
|
||||
"""Run example E2E test flow."""
|
||||
|
||||
print("=" * 60)
|
||||
print("LoRa Manager E2E Test Example")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 1: Start server
|
||||
print("\n[1/5] Starting LoRa Manager standalone server...")
|
||||
result = subprocess.run(
|
||||
[sys.executable, "start_server.py", "--port", "8188", "--wait", "--timeout", "30"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
if result.returncode != 0:
|
||||
print(f"Failed to start server: {result.stderr}")
|
||||
return 1
|
||||
print("Server ready!")
|
||||
|
||||
# Step 2: Open Chrome (manual step - show command)
|
||||
print("\n[2/5] Open Chrome with debug mode:")
|
||||
print("google-chrome --remote-debugging-port=9222 --user-data-dir=/tmp/chrome-lora-manager http://127.0.0.1:8188/loras")
|
||||
print("(In actual test, this would be automated via MCP)")
|
||||
|
||||
# Step 3: Navigate and verify page load
|
||||
print("\n[3/5] Page Load Verification:")
|
||||
print("""
|
||||
MCP Commands to execute:
|
||||
1. navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||
2. wait_for(text="LoRAs", timeout=10000)
|
||||
3. snapshot = take_snapshot()
|
||||
""")
|
||||
|
||||
# Step 4: Test search functionality
|
||||
print("\n[4/5] Search Functionality Test:")
|
||||
print("""
|
||||
MCP Commands to execute:
|
||||
1. fill(uid="search-input", value="test")
|
||||
2. press_key(key="Enter")
|
||||
3. wait_for(text="Results", timeout=5000)
|
||||
4. result = evaluate_script(function="""
|
||||
() => {
|
||||
const cards = document.querySelectorAll('.lora-card');
|
||||
return { count: cards.length };
|
||||
}
|
||||
""")
|
||||
""")
|
||||
|
||||
# Step 5: Verify API
|
||||
print("\n[5/5] API Verification:")
|
||||
print("""
|
||||
MCP Commands to execute:
|
||||
1. api_result = evaluate_script(function="""
|
||||
async () => {
|
||||
const response = await fetch('/loras/api/list');
|
||||
const data = await response.json();
|
||||
return { count: data.length, status: response.status };
|
||||
}
|
||||
""")
|
||||
2. Verify api_result['status'] == 200
|
||||
""")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Test flow completed!")
|
||||
print("=" * 60)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def example_restart_flow():
|
||||
"""Example: Testing configuration change that requires restart."""
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Example: Server Restart Flow")
|
||||
print("=" * 60)
|
||||
|
||||
print("""
|
||||
Scenario: Change setting and verify after restart
|
||||
|
||||
Steps:
|
||||
1. Navigate to settings page
|
||||
- navigate_page(type="url", url="http://127.0.0.1:8188/settings")
|
||||
|
||||
2. Change a setting (e.g., theme)
|
||||
- fill(uid="theme-select", value="dark")
|
||||
- click(uid="save-settings-button")
|
||||
|
||||
3. Restart server
|
||||
- subprocess.run([python, "start_server.py", "--restart", "--wait"])
|
||||
|
||||
4. Refresh browser
|
||||
- navigate_page(type="reload", ignoreCache=True)
|
||||
- wait_for(text="LoRAs", timeout=15000)
|
||||
|
||||
5. Verify setting persisted
|
||||
- navigate_page(type="url", url="http://127.0.0.1:8188/settings")
|
||||
- theme = evaluate_script(function="() => document.querySelector('#theme-select').value")
|
||||
- assert theme == "dark"
|
||||
""")
|
||||
|
||||
|
||||
def example_modal_interaction():
|
||||
"""Example: Testing modal dialog interaction."""
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Example: Modal Dialog Interaction")
|
||||
print("=" * 60)
|
||||
|
||||
print("""
|
||||
Scenario: Add new LoRA via modal
|
||||
|
||||
Steps:
|
||||
1. Open modal
|
||||
- click(uid="add-lora-button")
|
||||
- wait_for(text="Add LoRA", timeout=3000)
|
||||
|
||||
2. Fill form
|
||||
- fill_form(elements=[
|
||||
{"uid": "lora-name", "value": "Test Character"},
|
||||
{"uid": "lora-path", "value": "/models/test.safetensors"},
|
||||
])
|
||||
|
||||
3. Submit
|
||||
- click(uid="modal-submit-button")
|
||||
|
||||
4. Verify success
|
||||
- wait_for(text="Successfully added", timeout=5000)
|
||||
- snapshot = take_snapshot()
|
||||
""")
|
||||
|
||||
|
||||
def example_network_monitoring():
|
||||
"""Example: Network request monitoring."""
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Example: Network Request Monitoring")
|
||||
print("=" * 60)
|
||||
|
||||
print("""
|
||||
Scenario: Verify API calls during user interaction
|
||||
|
||||
Steps:
|
||||
1. Clear network log (implicit on navigation)
|
||||
- navigate_page(type="url", url="http://127.0.0.1:8188/loras")
|
||||
|
||||
2. Perform action that triggers API call
|
||||
- fill(uid="search-input", value="character")
|
||||
- press_key(key="Enter")
|
||||
|
||||
3. List network requests
|
||||
- requests = list_network_requests(resourceTypes=["xhr", "fetch"])
|
||||
|
||||
4. Find search API call
|
||||
- search_requests = [r for r in requests if "/api/search" in r.get("url", "")]
|
||||
- assert len(search_requests) > 0, "Search API was not called"
|
||||
|
||||
5. Get request details
|
||||
- if search_requests:
|
||||
details = get_network_request(reqid=search_requests[0]["reqid"])
|
||||
- Verify request method, response status, etc.
|
||||
""")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("LoRa Manager E2E Test Examples\n")
|
||||
print("This script demonstrates E2E testing patterns.\n")
|
||||
print("Note: Actual execution requires Chrome DevTools MCP connection.\n")
|
||||
|
||||
run_test()
|
||||
example_restart_flow()
|
||||
example_modal_interaction()
|
||||
example_network_monitoring()
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("All examples shown!")
|
||||
print("=" * 60)
|
||||
169
.agents/skills/lora-manager-e2e/scripts/start_server.py
Executable file
169
.agents/skills/lora-manager-e2e/scripts/start_server.py
Executable file
@@ -0,0 +1,169 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Start or restart LoRa Manager standalone server for E2E testing.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import subprocess
|
||||
import sys
|
||||
import time
|
||||
import socket
|
||||
import signal
|
||||
import os
|
||||
|
||||
|
||||
def find_server_process(port: int) -> list[int]:
|
||||
"""Find PIDs of processes listening on the given port."""
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["lsof", "-ti", f":{port}"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False
|
||||
)
|
||||
if result.returncode == 0 and result.stdout.strip():
|
||||
return [int(pid) for pid in result.stdout.strip().split("\n") if pid]
|
||||
except FileNotFoundError:
|
||||
# lsof not available, try netstat
|
||||
try:
|
||||
result = subprocess.run(
|
||||
["netstat", "-tlnp"],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=False
|
||||
)
|
||||
pids = []
|
||||
for line in result.stdout.split("\n"):
|
||||
if f":{port}" in line:
|
||||
parts = line.split()
|
||||
for part in parts:
|
||||
if "/" in part:
|
||||
try:
|
||||
pid = int(part.split("/")[0])
|
||||
pids.append(pid)
|
||||
except ValueError:
|
||||
pass
|
||||
return pids
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
return []
|
||||
|
||||
|
||||
def kill_server(port: int) -> None:
|
||||
"""Kill processes using the specified port."""
|
||||
pids = find_server_process(port)
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, signal.SIGTERM)
|
||||
print(f"Sent SIGTERM to process {pid}")
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
# Wait for processes to terminate
|
||||
time.sleep(1)
|
||||
|
||||
# Force kill if still running
|
||||
pids = find_server_process(port)
|
||||
for pid in pids:
|
||||
try:
|
||||
os.kill(pid, signal.SIGKILL)
|
||||
print(f"Sent SIGKILL to process {pid}")
|
||||
except ProcessLookupError:
|
||||
pass
|
||||
|
||||
|
||||
def is_server_ready(port: int, timeout: float = 0.5) -> bool:
|
||||
"""Check if server is accepting connections."""
|
||||
try:
|
||||
with socket.create_connection(("127.0.0.1", port), timeout=timeout):
|
||||
return True
|
||||
except (socket.timeout, ConnectionRefusedError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def wait_for_server(port: int, timeout: int = 30) -> bool:
|
||||
"""Wait for server to become ready."""
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
if is_server_ready(port):
|
||||
return True
|
||||
time.sleep(0.5)
|
||||
return False
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Start LoRa Manager standalone server for E2E testing"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8188,
|
||||
help="Server port (default: 8188)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--restart",
|
||||
action="store_true",
|
||||
help="Kill existing server before starting"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--wait",
|
||||
action="store_true",
|
||||
help="Wait for server to be ready before exiting"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Timeout for waiting (default: 30)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Get project root (parent of .agents directory)
|
||||
script_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
skill_dir = os.path.dirname(script_dir)
|
||||
project_root = os.path.dirname(os.path.dirname(os.path.dirname(skill_dir)))
|
||||
|
||||
# Restart if requested
|
||||
if args.restart:
|
||||
print(f"Killing existing server on port {args.port}...")
|
||||
kill_server(args.port)
|
||||
time.sleep(1)
|
||||
|
||||
# Check if already running
|
||||
if is_server_ready(args.port):
|
||||
print(f"Server already running on port {args.port}")
|
||||
return 0
|
||||
|
||||
# Start server
|
||||
print(f"Starting LoRa Manager standalone server on port {args.port}...")
|
||||
cmd = [sys.executable, "standalone.py", "--port", str(args.port)]
|
||||
|
||||
# Start in background
|
||||
process = subprocess.Popen(
|
||||
cmd,
|
||||
cwd=project_root,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
start_new_session=True
|
||||
)
|
||||
|
||||
print(f"Server process started with PID {process.pid}")
|
||||
|
||||
# Wait for ready if requested
|
||||
if args.wait:
|
||||
print(f"Waiting for server to be ready (timeout: {args.timeout}s)...")
|
||||
if wait_for_server(args.port, args.timeout):
|
||||
print(f"Server ready at http://127.0.0.1:{args.port}/loras")
|
||||
return 0
|
||||
else:
|
||||
print(f"Timeout waiting for server")
|
||||
return 1
|
||||
|
||||
print(f"Server starting at http://127.0.0.1:{args.port}/loras")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
61
.agents/skills/lora-manager-e2e/scripts/wait_for_server.py
Executable file
61
.agents/skills/lora-manager-e2e/scripts/wait_for_server.py
Executable file
@@ -0,0 +1,61 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Wait for LoRa Manager server to become ready.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import socket
|
||||
import sys
|
||||
import time
|
||||
|
||||
|
||||
def is_server_ready(port: int, timeout: float = 0.5) -> bool:
|
||||
"""Check if server is accepting connections."""
|
||||
try:
|
||||
with socket.create_connection(("127.0.0.1", port), timeout=timeout):
|
||||
return True
|
||||
except (socket.timeout, ConnectionRefusedError, OSError):
|
||||
return False
|
||||
|
||||
|
||||
def wait_for_server(port: int, timeout: int = 30) -> bool:
|
||||
"""Wait for server to become ready."""
|
||||
start = time.time()
|
||||
while time.time() - start < timeout:
|
||||
if is_server_ready(port):
|
||||
return True
|
||||
time.sleep(0.5)
|
||||
return False
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Wait for LoRa Manager server to become ready"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8188,
|
||||
help="Server port (default: 8188)"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Timeout in seconds (default: 30)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
print(f"Waiting for server on port {args.port} (timeout: {args.timeout}s)...")
|
||||
|
||||
if wait_for_server(args.port, args.timeout):
|
||||
print(f"Server ready at http://127.0.0.1:{args.port}/loras")
|
||||
return 0
|
||||
else:
|
||||
print(f"Timeout: Server not ready after {args.timeout}s")
|
||||
return 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
sys.exit(main())
|
||||
@@ -34,6 +34,11 @@ Enhance your Civitai browsing experience with our companion browser extension! S
|
||||
|
||||
## Release Notes
|
||||
|
||||
### v0.9.14
|
||||
* **LoRA Cycler Node** - Introduced a new LoRA Cycler node that enables iteration through specified LoRAs with support for repeat count and pause iteration functionality. Refer to the new "Lora Cycler" template workflow for concrete example.
|
||||
* **Enhanced Prompt Node with Tag Autocomplete** - Enhanced the Prompt node with comprehensive tag autocomplete based on merged Danbooru + e621 tags. Supports tag search and autocomplete functionality. Implemented a command system with shortcuts like `/char` or `/artist` for category-specific tag searching. Added `/ac` or `/noac` commands to quickly enable or disable autocomplete. Refer to the "Lora Manager Basic" template workflow in ComfyUI -> Templates -> ComfyUI-Lora-Manager for detailed tips.
|
||||
* **Bug Fixes & Stability** - Addressed multiple bugs and improved overall stability.
|
||||
|
||||
### v0.9.12
|
||||
* **LoRA Randomizer System** - Introduced a comprehensive LoRA randomization system featuring LoRA Pool and LoRA Randomizer nodes for flexible and dynamic generation workflows.
|
||||
* **LoRA Randomizer Template** - Refer to the new "LoRA Randomizer" template workflow for detailed examples of flexible randomization modes, lock & reuse options, and other features.
|
||||
|
||||
File diff suppressed because one or more lines are too long
@@ -9,9 +9,9 @@
|
||||
"back": "Zurück",
|
||||
"next": "Weiter",
|
||||
"backToTop": "Nach oben",
|
||||
"add": "Hinzufügen",
|
||||
"settings": "Einstellungen",
|
||||
"help": "Hilfe"
|
||||
"help": "Hilfe",
|
||||
"add": "Hinzufügen"
|
||||
},
|
||||
"status": {
|
||||
"loading": "Wird geladen...",
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "Cache-Korruption erkannt"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "Cache-Probleme erkannt"
|
||||
},
|
||||
"content": "{invalid} von {total} Cache-Einträgen sind ungültig ({rate}). Dies kann zu fehlenden Modellen oder Fehlern führen. Ein Neuaufbau des Caches wird empfohlen.",
|
||||
"rebuildCache": "Cache neu aufbauen",
|
||||
"dismiss": "Verwerfen",
|
||||
"rebuilding": "Cache wird neu aufgebaut...",
|
||||
"rebuildFailed": "Fehler beim Neuaufbau des Caches: {error}",
|
||||
"retry": "Wiederholen"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "Cache Corruption Detected"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "Cache Issues Detected"
|
||||
},
|
||||
"content": "{invalid} of {total} cache entries are invalid ({rate}). This may cause missing models or errors. Rebuilding the cache is recommended.",
|
||||
"rebuildCache": "Rebuild Cache",
|
||||
"dismiss": "Dismiss",
|
||||
"rebuilding": "Rebuilding cache...",
|
||||
"rebuildFailed": "Failed to rebuild cache: {error}",
|
||||
"retry": "Retry"
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "Corrupción de caché detectada"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "Problemas de caché detectados"
|
||||
},
|
||||
"content": "{invalid} de {total} entradas de caché son inválidas ({rate}). Esto puede causar modelos faltantes o errores. Se recomienda reconstruir la caché.",
|
||||
"rebuildCache": "Reconstruir caché",
|
||||
"dismiss": "Descartar",
|
||||
"rebuilding": "Reconstruyendo caché...",
|
||||
"rebuildFailed": "Error al reconstruir la caché: {error}",
|
||||
"retry": "Reintentar"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "Corruption du cache détectée"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "Problèmes de cache détectés"
|
||||
},
|
||||
"content": "{invalid} des {total} entrées de cache sont invalides ({rate}). Cela peut provoquer des modèles manquants ou des erreurs. Il est recommandé de reconstruire le cache.",
|
||||
"rebuildCache": "Reconstruire le cache",
|
||||
"dismiss": "Ignorer",
|
||||
"rebuilding": "Reconstruction du cache...",
|
||||
"rebuildFailed": "Échec de la reconstruction du cache : {error}",
|
||||
"retry": "Réessayer"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,9 +9,9 @@
|
||||
"back": "חזור",
|
||||
"next": "הבא",
|
||||
"backToTop": "חזור למעלה",
|
||||
"add": "הוסף",
|
||||
"settings": "הגדרות",
|
||||
"help": "עזרה"
|
||||
"help": "עזרה",
|
||||
"add": "הוסף"
|
||||
},
|
||||
"status": {
|
||||
"loading": "טוען...",
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "זוהתה שחיתות במטמון"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "זוהו בעיות במטמון"
|
||||
},
|
||||
"content": "{invalid} מתוך {total} רשומות מטמון אינן תקינות ({rate}). זה עלול לגרום לדגמים חסרים או לשגיאות. מומלץ לבנות מחדש את המטמון.",
|
||||
"rebuildCache": "בניית מטמון מחדש",
|
||||
"dismiss": "ביטול",
|
||||
"rebuilding": "בונה מחדש את המטמון...",
|
||||
"rebuildFailed": "נכשלה בניית המטמון מחדש: {error}",
|
||||
"retry": "נסה שוב"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "キャッシュの破損が検出されました"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "キャッシュの問題が検出されました"
|
||||
},
|
||||
"content": "{total}個のキャッシュエントリのうち{invalid}個が無効です({rate})。モデルが見つからない原因になったり、エラーが発生する可能性があります。キャッシュの再構築を推奨します。",
|
||||
"rebuildCache": "キャッシュを再構築",
|
||||
"dismiss": "閉じる",
|
||||
"rebuilding": "キャッシュを再構築中...",
|
||||
"rebuildFailed": "キャッシュの再構築に失敗しました: {error}",
|
||||
"retry": "再試行"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "캐시 손상이 감지되었습니다"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "캐시 문제가 감지되었습니다"
|
||||
},
|
||||
"content": "{total}개의 캐시 항목 중 {invalid}개가 유효하지 않습니다 ({rate}). 모델 누락이나 오류가 발생할 수 있습니다. 캐시를 재구축하는 것이 좋습니다.",
|
||||
"rebuildCache": "캐시 재구축",
|
||||
"dismiss": "무시",
|
||||
"rebuilding": "캐시 재구축 중...",
|
||||
"rebuildFailed": "캐시 재구축 실패: {error}",
|
||||
"retry": "다시 시도"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "Обнаружено повреждение кэша"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "Обнаружены проблемы с кэшем"
|
||||
},
|
||||
"content": "{invalid} из {total} записей кэша недействительны ({rate}). Это может привести к отсутствию моделей или ошибкам. Рекомендуется перестроить кэш.",
|
||||
"rebuildCache": "Перестроить кэш",
|
||||
"dismiss": "Отклонить",
|
||||
"rebuilding": "Перестроение кэша...",
|
||||
"rebuildFailed": "Не удалось перестроить кэш: {error}",
|
||||
"retry": "Повторить"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "来爱发电为Lora Manager项目发电,支持项目持续开发的同时,获取浏览器插件验证码,按季支付更优惠!支付宝/微信方便支付。感谢支持!🚀",
|
||||
"supportCta": "为LM发电",
|
||||
"learnMore": "浏览器插件教程"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "检测到缓存损坏"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "检测到缓存问题"
|
||||
},
|
||||
"content": "{total} 个缓存条目中有 {invalid} 个无效({rate})。这可能导致模型丢失或错误。建议重建缓存。",
|
||||
"rebuildCache": "重建缓存",
|
||||
"dismiss": "忽略",
|
||||
"rebuilding": "正在重建缓存...",
|
||||
"rebuildFailed": "重建缓存失败:{error}",
|
||||
"retry": "重试"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1572,6 +1572,20 @@
|
||||
"content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
|
||||
"supportCta": "Support on Ko-fi",
|
||||
"learnMore": "LM Civitai Extension Tutorial"
|
||||
},
|
||||
"cacheHealth": {
|
||||
"corrupted": {
|
||||
"title": "檢測到快取損壞"
|
||||
},
|
||||
"degraded": {
|
||||
"title": "檢測到快取問題"
|
||||
},
|
||||
"content": "{total} 個快取項目中有 {invalid} 個無效({rate})。這可能會導致模型遺失或錯誤。建議重建快取。",
|
||||
"rebuildCache": "重建快取",
|
||||
"dismiss": "關閉",
|
||||
"rebuilding": "重建快取中...",
|
||||
"rebuildFailed": "重建快取失敗:{error}",
|
||||
"retry": "重試"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@
|
||||
"private": true,
|
||||
"type": "module",
|
||||
"scripts": {
|
||||
"test": "vitest run",
|
||||
"test": "npm run test:js && npm run test:vue",
|
||||
"test:js": "vitest run",
|
||||
"test:vue": "cd vue-widgets && npx vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "node scripts/run_frontend_coverage.js"
|
||||
},
|
||||
|
||||
93
py/config.py
93
py/config.py
@@ -441,82 +441,53 @@ class Config:
|
||||
logger.info("Failed to write symlink cache %s: %s", cache_path, exc)
|
||||
|
||||
def _scan_symbolic_links(self):
|
||||
"""Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
|
||||
"""Scan symbolic links in LoRA, Checkpoint, and Embedding root directories.
|
||||
|
||||
Only scans the first level of each root directory to avoid performance
|
||||
issues with large file systems. Detects symlinks and Windows junctions
|
||||
at the root level only (not nested symlinks in subdirectories).
|
||||
"""
|
||||
start = time.perf_counter()
|
||||
|
||||
# Reset mappings before rescanning to avoid stale entries
|
||||
self._path_mappings.clear()
|
||||
self._seed_root_symlink_mappings()
|
||||
visited_dirs: Set[str] = set()
|
||||
for root in self._symlink_roots():
|
||||
self._scan_directory_links(root, visited_dirs)
|
||||
self._scan_first_level_symlinks(root)
|
||||
logger.debug(
|
||||
"Symlink scan finished in %.2f ms with %d mappings",
|
||||
(time.perf_counter() - start) * 1000,
|
||||
len(self._path_mappings),
|
||||
)
|
||||
|
||||
def _scan_directory_links(self, root: str, visited_dirs: Set[str]):
|
||||
"""Iteratively scan directory symlinks to avoid deep recursion."""
|
||||
def _scan_first_level_symlinks(self, root: str):
|
||||
"""Scan only the first level of a directory for symlinks.
|
||||
|
||||
This avoids traversing the entire directory tree which can be extremely
|
||||
slow for large model collections. Only symlinks directly under the root
|
||||
are detected.
|
||||
"""
|
||||
try:
|
||||
# Note: We only use realpath for the initial root if it's not already resolved
|
||||
# to ensure we have a valid entry point.
|
||||
root_real = self._normalize_path(os.path.realpath(root))
|
||||
except OSError:
|
||||
root_real = self._normalize_path(root)
|
||||
with os.scandir(root) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
# Only detect symlinks including Windows junctions
|
||||
# Skip normal directories to avoid deep traversal
|
||||
if not self._entry_is_symlink(entry):
|
||||
continue
|
||||
|
||||
if root_real in visited_dirs:
|
||||
return
|
||||
# Resolve the symlink target
|
||||
target_path = os.path.realpath(entry.path)
|
||||
if not os.path.isdir(target_path):
|
||||
continue
|
||||
|
||||
visited_dirs.add(root_real)
|
||||
# Stack entries: (display_path, real_resolved_path)
|
||||
stack: List[Tuple[str, str]] = [(root, root_real)]
|
||||
|
||||
while stack:
|
||||
current_display, current_real = stack.pop()
|
||||
try:
|
||||
with os.scandir(current_display) as it:
|
||||
for entry in it:
|
||||
try:
|
||||
# 1. Detect symlinks including Windows junctions
|
||||
is_link = self._entry_is_symlink(entry)
|
||||
|
||||
if is_link:
|
||||
# Only resolve realpath when we actually find a link
|
||||
target_path = os.path.realpath(entry.path)
|
||||
if not os.path.isdir(target_path):
|
||||
continue
|
||||
|
||||
normalized_target = self._normalize_path(target_path)
|
||||
self.add_path_mapping(entry.path, target_path)
|
||||
|
||||
if normalized_target in visited_dirs:
|
||||
continue
|
||||
|
||||
visited_dirs.add(normalized_target)
|
||||
stack.append((target_path, normalized_target))
|
||||
continue
|
||||
|
||||
# 2. Process normal directories
|
||||
if not entry.is_dir(follow_symlinks=False):
|
||||
continue
|
||||
|
||||
# For normal directories, we avoid realpath() call by
|
||||
# incrementally building the real path relative to current_real.
|
||||
# This is safe because 'entry' is NOT a symlink.
|
||||
entry_real = self._normalize_path(os.path.join(current_real, entry.name))
|
||||
|
||||
if entry_real in visited_dirs:
|
||||
continue
|
||||
|
||||
visited_dirs.add(entry_real)
|
||||
stack.append((entry.path, entry_real))
|
||||
except Exception as inner_exc:
|
||||
logger.debug(
|
||||
"Error processing directory entry %s: %s", entry.path, inner_exc
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning links in {current_display}: {e}")
|
||||
self.add_path_mapping(entry.path, target_path)
|
||||
except Exception as inner_exc:
|
||||
logger.debug(
|
||||
"Error processing directory entry %s: %s", entry.path, inner_exc
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error scanning links in {root}: {e}")
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
import os
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Check if running in standalone mode
|
||||
standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
|
||||
@@ -14,7 +17,7 @@ if not standalone_mode:
|
||||
# Initialize registry
|
||||
registry = MetadataRegistry()
|
||||
|
||||
print("ComfyUI Metadata Collector initialized")
|
||||
logger.info("ComfyUI Metadata Collector initialized")
|
||||
|
||||
def get_metadata(prompt_id=None):
|
||||
"""Helper function to get metadata from the registry"""
|
||||
@@ -23,7 +26,7 @@ if not standalone_mode:
|
||||
else:
|
||||
# Standalone mode - provide dummy implementations
|
||||
def init():
|
||||
print("ComfyUI Metadata Collector disabled in standalone mode")
|
||||
logger.info("ComfyUI Metadata Collector disabled in standalone mode")
|
||||
|
||||
def get_metadata(prompt_id=None):
|
||||
"""Dummy implementation for standalone mode"""
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import sys
|
||||
import inspect
|
||||
import logging
|
||||
from .metadata_registry import MetadataRegistry
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class MetadataHook:
|
||||
"""Install hooks for metadata collection"""
|
||||
|
||||
@@ -23,7 +26,7 @@ class MetadataHook:
|
||||
|
||||
# If we can't find the execution module, we can't install hooks
|
||||
if execution is None:
|
||||
print("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
logger.warning("Could not locate ComfyUI execution module, metadata collection disabled")
|
||||
return
|
||||
|
||||
# Detect whether we're using the new async version of ComfyUI
|
||||
@@ -37,16 +40,16 @@ class MetadataHook:
|
||||
is_async = inspect.iscoroutinefunction(execution._map_node_over_list)
|
||||
|
||||
if is_async:
|
||||
print("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
logger.info("Detected async ComfyUI execution, installing async metadata hooks")
|
||||
MetadataHook._install_async_hooks(execution, map_node_func_name)
|
||||
else:
|
||||
print("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
logger.info("Detected sync ComfyUI execution, installing sync metadata hooks")
|
||||
MetadataHook._install_sync_hooks(execution)
|
||||
|
||||
print("Metadata collection hooks installed for runtime values")
|
||||
logger.info("Metadata collection hooks installed for runtime values")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error installing metadata hooks: {str(e)}")
|
||||
logger.error(f"Error installing metadata hooks: {str(e)}")
|
||||
|
||||
@staticmethod
|
||||
def _install_sync_hooks(execution):
|
||||
@@ -82,7 +85,7 @@ class MetadataHook:
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
logger.error(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Execute the original function
|
||||
results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
|
||||
@@ -113,7 +116,7 @@ class MetadataHook:
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
logger.error(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
@@ -159,7 +162,7 @@ class MetadataHook:
|
||||
if node_id is not None:
|
||||
registry.record_node_execution(node_id, class_type, input_data_all, None)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
logger.error(f"Error collecting metadata (pre-execution): {str(e)}")
|
||||
|
||||
# Call original function with all args/kwargs
|
||||
results = await original_map_node_over_list(
|
||||
@@ -176,7 +179,7 @@ class MetadataHook:
|
||||
if node_id is not None:
|
||||
registry.update_node_execution(node_id, class_type, results)
|
||||
except Exception as e:
|
||||
print(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
logger.error(f"Error collecting metadata (post-execution): {str(e)}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -126,9 +126,7 @@ class LoraCyclerLM:
|
||||
"current_index": [clamped_index],
|
||||
"next_index": [next_index],
|
||||
"total_count": [total_count],
|
||||
"current_lora_name": [
|
||||
current_lora.get("model_name", current_lora["file_name"])
|
||||
],
|
||||
"current_lora_name": [current_lora["file_name"]],
|
||||
"current_lora_filename": [current_lora["file_name"]],
|
||||
"next_lora_name": [next_display_name],
|
||||
"next_lora_filename": [next_lora["file_name"]],
|
||||
|
||||
@@ -8,6 +8,9 @@ from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||
from ..metadata_collector import get_metadata
|
||||
from PIL import Image, PngImagePlugin
|
||||
import piexif
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class SaveImageLM:
|
||||
NAME = "Save Image (LoraManager)"
|
||||
@@ -385,7 +388,7 @@ class SaveImageLM:
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception as e:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
logger.error(f"Error adding EXIF data: {e}")
|
||||
img.save(file_path, format="JPEG", **save_kwargs)
|
||||
elif file_format == "webp":
|
||||
try:
|
||||
@@ -403,7 +406,7 @@ class SaveImageLM:
|
||||
exif_bytes = piexif.dump(exif_dict)
|
||||
save_kwargs["exif"] = exif_bytes
|
||||
except Exception as e:
|
||||
print(f"Error adding EXIF data: {e}")
|
||||
logger.error(f"Error adding EXIF data: {e}")
|
||||
|
||||
img.save(file_path, format="WEBP", **save_kwargs)
|
||||
|
||||
@@ -414,7 +417,7 @@ class SaveImageLM:
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving image: {e}")
|
||||
logger.error(f"Error saving image: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
|
||||
RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler:
|
||||
except ExampleImagesDownloadError as exc:
|
||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||
|
||||
async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse:
|
||||
"""Lightweight check to see if any models need example images downloaded."""
|
||||
try:
|
||||
payload = await request.json()
|
||||
model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding'])
|
||||
result = await self._download_manager.check_pending_models(model_types)
|
||||
return web.json_response(result)
|
||||
except Exception as exc:
|
||||
return web.json_response(
|
||||
{'success': False, 'error': str(exc)},
|
||||
status=500
|
||||
)
|
||||
|
||||
|
||||
class ExampleImagesManagementHandler:
|
||||
"""HTTP adapters for import/delete endpoints."""
|
||||
@@ -161,6 +174,7 @@ class ExampleImagesHandlerSet:
|
||||
"resume_example_images": self.download.resume_example_images,
|
||||
"stop_example_images": self.download.stop_example_images,
|
||||
"force_download_example_images": self.download.force_download_example_images,
|
||||
"check_example_images_needed": self.download.check_example_images_needed,
|
||||
"import_example_images": self.management.import_example_images,
|
||||
"delete_example_image": self.management.delete_example_image,
|
||||
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,
|
||||
|
||||
259
py/services/cache_entry_validator.py
Normal file
259
py/services/cache_entry_validator.py
Normal file
@@ -0,0 +1,259 @@
|
||||
"""
|
||||
Cache Entry Validator
|
||||
|
||||
Validates and repairs cache entries to prevent runtime errors from
|
||||
missing or invalid critical fields.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
import logging
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of validating a single cache entry."""
|
||||
is_valid: bool
|
||||
repaired: bool
|
||||
errors: List[str] = field(default_factory=list)
|
||||
entry: Optional[Dict[str, Any]] = None
|
||||
|
||||
|
||||
class CacheEntryValidator:
|
||||
"""
|
||||
Validates and repairs cache entry core fields.
|
||||
|
||||
Critical fields that cause runtime errors when missing:
|
||||
- file_path: KeyError in multiple locations
|
||||
- sha256: KeyError/AttributeError in hash operations
|
||||
|
||||
Medium severity fields that may cause sorting/display issues:
|
||||
- size: KeyError during sorting
|
||||
- modified: KeyError during sorting
|
||||
- model_name: AttributeError on .lower() calls
|
||||
|
||||
Low severity fields:
|
||||
- tags: KeyError/TypeError in recipe operations
|
||||
"""
|
||||
|
||||
# Field definitions: (default_value, is_required)
|
||||
CORE_FIELDS: Dict[str, Tuple[Any, bool]] = {
|
||||
'file_path': ('', True),
|
||||
'sha256': ('', True),
|
||||
'file_name': ('', False),
|
||||
'model_name': ('', False),
|
||||
'folder': ('', False),
|
||||
'size': (0, False),
|
||||
'modified': (0.0, False),
|
||||
'tags': ([], False),
|
||||
'preview_url': ('', False),
|
||||
'base_model': ('', False),
|
||||
'from_civitai': (True, False),
|
||||
'favorite': (False, False),
|
||||
'exclude': (False, False),
|
||||
'db_checked': (False, False),
|
||||
'preview_nsfw_level': (0, False),
|
||||
'notes': ('', False),
|
||||
'usage_tips': ('', False),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def validate(cls, entry: Dict[str, Any], *, auto_repair: bool = True) -> ValidationResult:
|
||||
"""
|
||||
Validate a single cache entry.
|
||||
|
||||
Args:
|
||||
entry: The cache entry dictionary to validate
|
||||
auto_repair: If True, attempt to repair missing/invalid fields
|
||||
|
||||
Returns:
|
||||
ValidationResult with validation status and optionally repaired entry
|
||||
"""
|
||||
if entry is None:
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
repaired=False,
|
||||
errors=['Entry is None'],
|
||||
entry=None
|
||||
)
|
||||
|
||||
if not isinstance(entry, dict):
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
repaired=False,
|
||||
errors=[f'Entry is not a dict: {type(entry).__name__}'],
|
||||
entry=None
|
||||
)
|
||||
|
||||
errors: List[str] = []
|
||||
repaired = False
|
||||
working_entry = dict(entry) if auto_repair else entry
|
||||
|
||||
for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
|
||||
value = working_entry.get(field_name)
|
||||
|
||||
# Check if field is missing or None
|
||||
if value is None:
|
||||
if is_required:
|
||||
errors.append(f"Required field '{field_name}' is missing or None")
|
||||
if auto_repair:
|
||||
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||
repaired = True
|
||||
continue
|
||||
|
||||
# Validate field type and value
|
||||
field_error = cls._validate_field(field_name, value, default_value)
|
||||
if field_error:
|
||||
errors.append(field_error)
|
||||
if auto_repair:
|
||||
working_entry[field_name] = cls._get_default_copy(default_value)
|
||||
repaired = True
|
||||
|
||||
# Special validation: file_path must not be empty for required field
|
||||
file_path = working_entry.get('file_path', '')
|
||||
if not file_path or (isinstance(file_path, str) and not file_path.strip()):
|
||||
errors.append("Required field 'file_path' is empty")
|
||||
# Cannot repair empty file_path - entry is invalid
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
repaired=repaired,
|
||||
errors=errors,
|
||||
entry=working_entry if auto_repair else None
|
||||
)
|
||||
|
||||
# Special validation: sha256 must not be empty for required field
|
||||
sha256 = working_entry.get('sha256', '')
|
||||
if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
|
||||
errors.append("Required field 'sha256' is empty")
|
||||
# Cannot repair empty sha256 - entry is invalid
|
||||
return ValidationResult(
|
||||
is_valid=False,
|
||||
repaired=repaired,
|
||||
errors=errors,
|
||||
entry=working_entry if auto_repair else None
|
||||
)
|
||||
|
||||
# Normalize sha256 to lowercase if needed
|
||||
if isinstance(sha256, str):
|
||||
normalized_sha = sha256.lower().strip()
|
||||
if normalized_sha != sha256:
|
||||
working_entry['sha256'] = normalized_sha
|
||||
repaired = True
|
||||
|
||||
# Determine if entry is valid
|
||||
# Entry is valid if no critical required field errors remain after repair
|
||||
# Critical fields are file_path and sha256
|
||||
CRITICAL_REQUIRED_FIELDS = {'file_path', 'sha256'}
|
||||
has_critical_errors = any(
|
||||
"Required field" in error and
|
||||
any(f"'{field}'" in error for field in CRITICAL_REQUIRED_FIELDS)
|
||||
for error in errors
|
||||
)
|
||||
|
||||
is_valid = not has_critical_errors
|
||||
|
||||
return ValidationResult(
|
||||
is_valid=is_valid,
|
||||
repaired=repaired,
|
||||
errors=errors,
|
||||
entry=working_entry if auto_repair else entry
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_batch(
|
||||
cls,
|
||||
entries: List[Dict[str, Any]],
|
||||
*,
|
||||
auto_repair: bool = True
|
||||
) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
|
||||
"""
|
||||
Validate a batch of cache entries.
|
||||
|
||||
Args:
|
||||
entries: List of cache entry dictionaries to validate
|
||||
auto_repair: If True, attempt to repair missing/invalid fields
|
||||
|
||||
Returns:
|
||||
Tuple of (valid_entries, invalid_entries)
|
||||
"""
|
||||
if not entries:
|
||||
return [], []
|
||||
|
||||
valid_entries: List[Dict[str, Any]] = []
|
||||
invalid_entries: List[Dict[str, Any]] = []
|
||||
|
||||
for entry in entries:
|
||||
result = cls.validate(entry, auto_repair=auto_repair)
|
||||
|
||||
if result.is_valid:
|
||||
# Use repaired entry if available, otherwise original
|
||||
valid_entries.append(result.entry if result.entry else entry)
|
||||
else:
|
||||
invalid_entries.append(entry)
|
||||
# Log invalid entries for debugging
|
||||
file_path = entry.get('file_path', '<unknown>') if isinstance(entry, dict) else '<not a dict>'
|
||||
logger.warning(
|
||||
f"Invalid cache entry for '{file_path}': {', '.join(result.errors)}"
|
||||
)
|
||||
|
||||
return valid_entries, invalid_entries
|
||||
|
||||
@classmethod
|
||||
def _validate_field(cls, field_name: str, value: Any, default_value: Any) -> Optional[str]:
|
||||
"""
|
||||
Validate a specific field value.
|
||||
|
||||
Returns an error message if invalid, None if valid.
|
||||
"""
|
||||
expected_type = type(default_value)
|
||||
|
||||
# Special handling for numeric types
|
||||
if expected_type == int:
|
||||
if not isinstance(value, (int, float)):
|
||||
return f"Field '{field_name}' should be numeric, got {type(value).__name__}"
|
||||
elif expected_type == float:
|
||||
if not isinstance(value, (int, float)):
|
||||
return f"Field '{field_name}' should be numeric, got {type(value).__name__}"
|
||||
elif expected_type == bool:
|
||||
# Be lenient with boolean fields - accept truthy/falsy values
|
||||
pass
|
||||
elif expected_type == str:
|
||||
if not isinstance(value, str):
|
||||
return f"Field '{field_name}' should be string, got {type(value).__name__}"
|
||||
elif expected_type == list:
|
||||
if not isinstance(value, (list, tuple)):
|
||||
return f"Field '{field_name}' should be list, got {type(value).__name__}"
|
||||
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _get_default_copy(cls, default_value: Any) -> Any:
|
||||
"""Get a copy of the default value to avoid shared mutable state."""
|
||||
if isinstance(default_value, list):
|
||||
return list(default_value)
|
||||
if isinstance(default_value, dict):
|
||||
return dict(default_value)
|
||||
return default_value
|
||||
|
||||
@classmethod
|
||||
def get_file_path_safe(cls, entry: Dict[str, Any], default: str = '') -> str:
|
||||
"""Safely get file_path from an entry."""
|
||||
if not isinstance(entry, dict):
|
||||
return default
|
||||
value = entry.get('file_path')
|
||||
if isinstance(value, str):
|
||||
return value
|
||||
return default
|
||||
|
||||
@classmethod
|
||||
def get_sha256_safe(cls, entry: Dict[str, Any], default: str = '') -> str:
|
||||
"""Safely get sha256 from an entry."""
|
||||
if not isinstance(entry, dict):
|
||||
return default
|
||||
value = entry.get('sha256')
|
||||
if isinstance(value, str):
|
||||
return value.lower()
|
||||
return default
|
||||
201
py/services/cache_health_monitor.py
Normal file
201
py/services/cache_health_monitor.py
Normal file
@@ -0,0 +1,201 @@
|
||||
"""
|
||||
Cache Health Monitor
|
||||
|
||||
Monitors cache health status and determines when user intervention is needed.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Optional
|
||||
import logging
|
||||
|
||||
from .cache_entry_validator import CacheEntryValidator, ValidationResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CacheHealthStatus(Enum):
|
||||
"""Health status of the cache."""
|
||||
HEALTHY = "healthy"
|
||||
DEGRADED = "degraded"
|
||||
CORRUPTED = "corrupted"
|
||||
|
||||
|
||||
@dataclass
|
||||
class HealthReport:
|
||||
"""Report of cache health check."""
|
||||
status: CacheHealthStatus
|
||||
total_entries: int
|
||||
valid_entries: int
|
||||
invalid_entries: int
|
||||
repaired_entries: int
|
||||
invalid_paths: List[str] = field(default_factory=list)
|
||||
message: str = ""
|
||||
|
||||
@property
|
||||
def corruption_rate(self) -> float:
|
||||
"""Calculate the percentage of invalid entries."""
|
||||
if self.total_entries <= 0:
|
||||
return 0.0
|
||||
return self.invalid_entries / self.total_entries
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""Convert to dictionary for JSON serialization."""
|
||||
return {
|
||||
'status': self.status.value,
|
||||
'total_entries': self.total_entries,
|
||||
'valid_entries': self.valid_entries,
|
||||
'invalid_entries': self.invalid_entries,
|
||||
'repaired_entries': self.repaired_entries,
|
||||
'corruption_rate': f"{self.corruption_rate:.1%}",
|
||||
'invalid_paths': self.invalid_paths[:10], # Limit to first 10
|
||||
'message': self.message,
|
||||
}
|
||||
|
||||
|
||||
class CacheHealthMonitor:
|
||||
"""
|
||||
Monitors cache health and determines appropriate status.
|
||||
|
||||
Thresholds:
|
||||
- HEALTHY: 0% invalid entries
|
||||
- DEGRADED: 0-5% invalid entries (auto-repaired, user should rebuild)
|
||||
- CORRUPTED: >5% invalid entries (significant data loss likely)
|
||||
"""
|
||||
|
||||
# Threshold percentages
|
||||
DEGRADED_THRESHOLD = 0.01 # 1% - show warning
|
||||
CORRUPTED_THRESHOLD = 0.05 # 5% - critical warning
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
degraded_threshold: float = DEGRADED_THRESHOLD,
|
||||
corrupted_threshold: float = CORRUPTED_THRESHOLD
|
||||
):
|
||||
"""
|
||||
Initialize the health monitor.
|
||||
|
||||
Args:
|
||||
degraded_threshold: Corruption rate threshold for DEGRADED status
|
||||
corrupted_threshold: Corruption rate threshold for CORRUPTED status
|
||||
"""
|
||||
self.degraded_threshold = degraded_threshold
|
||||
self.corrupted_threshold = corrupted_threshold
|
||||
|
||||
def check_health(
|
||||
self,
|
||||
entries: List[Dict[str, Any]],
|
||||
*,
|
||||
auto_repair: bool = True
|
||||
) -> HealthReport:
|
||||
"""
|
||||
Check the health of cache entries.
|
||||
|
||||
Args:
|
||||
entries: List of cache entry dictionaries to check
|
||||
auto_repair: If True, attempt to repair entries during validation
|
||||
|
||||
Returns:
|
||||
HealthReport with status and statistics
|
||||
"""
|
||||
if not entries:
|
||||
return HealthReport(
|
||||
status=CacheHealthStatus.HEALTHY,
|
||||
total_entries=0,
|
||||
valid_entries=0,
|
||||
invalid_entries=0,
|
||||
repaired_entries=0,
|
||||
message="Cache is empty"
|
||||
)
|
||||
|
||||
total_entries = len(entries)
|
||||
valid_entries: List[Dict[str, Any]] = []
|
||||
invalid_entries: List[Dict[str, Any]] = []
|
||||
repaired_count = 0
|
||||
invalid_paths: List[str] = []
|
||||
|
||||
for entry in entries:
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=auto_repair)
|
||||
|
||||
if result.is_valid:
|
||||
valid_entries.append(result.entry if result.entry else entry)
|
||||
if result.repaired:
|
||||
repaired_count += 1
|
||||
else:
|
||||
invalid_entries.append(entry)
|
||||
# Extract file path for reporting
|
||||
file_path = CacheEntryValidator.get_file_path_safe(entry, '<unknown>')
|
||||
invalid_paths.append(file_path)
|
||||
|
||||
invalid_count = len(invalid_entries)
|
||||
valid_count = len(valid_entries)
|
||||
|
||||
# Determine status based on corruption rate
|
||||
corruption_rate = invalid_count / total_entries if total_entries > 0 else 0.0
|
||||
|
||||
if invalid_count == 0:
|
||||
status = CacheHealthStatus.HEALTHY
|
||||
message = "Cache is healthy"
|
||||
elif corruption_rate >= self.corrupted_threshold:
|
||||
status = CacheHealthStatus.CORRUPTED
|
||||
message = (
|
||||
f"Cache is corrupted: {invalid_count} invalid entries "
|
||||
f"({corruption_rate:.1%}). Rebuild recommended."
|
||||
)
|
||||
elif corruption_rate >= self.degraded_threshold or invalid_count > 0:
|
||||
status = CacheHealthStatus.DEGRADED
|
||||
message = (
|
||||
f"Cache has {invalid_count} invalid entries "
|
||||
f"({corruption_rate:.1%}). Consider rebuilding cache."
|
||||
)
|
||||
else:
|
||||
# This shouldn't happen, but handle gracefully
|
||||
status = CacheHealthStatus.HEALTHY
|
||||
message = "Cache is healthy"
|
||||
|
||||
# Log the health check result
|
||||
if status != CacheHealthStatus.HEALTHY:
|
||||
logger.warning(
|
||||
f"Cache health check: {status.value} - "
|
||||
f"{invalid_count}/{total_entries} invalid, "
|
||||
f"{repaired_count} repaired"
|
||||
)
|
||||
if invalid_paths:
|
||||
logger.debug(f"Invalid entry paths: {invalid_paths[:5]}")
|
||||
|
||||
return HealthReport(
|
||||
status=status,
|
||||
total_entries=total_entries,
|
||||
valid_entries=valid_count,
|
||||
invalid_entries=invalid_count,
|
||||
repaired_entries=repaired_count,
|
||||
invalid_paths=invalid_paths,
|
||||
message=message
|
||||
)
|
||||
|
||||
def should_notify_user(self, report: HealthReport) -> bool:
|
||||
"""
|
||||
Determine if the user should be notified about cache health.
|
||||
|
||||
Args:
|
||||
report: The health report to evaluate
|
||||
|
||||
Returns:
|
||||
True if user should be notified
|
||||
"""
|
||||
return report.status != CacheHealthStatus.HEALTHY
|
||||
|
||||
def get_notification_severity(self, report: HealthReport) -> str:
|
||||
"""
|
||||
Get the severity level for user notification.
|
||||
|
||||
Args:
|
||||
report: The health report to evaluate
|
||||
|
||||
Returns:
|
||||
Severity string: 'warning' or 'error'
|
||||
"""
|
||||
if report.status == CacheHealthStatus.CORRUPTED:
|
||||
return 'error'
|
||||
return 'warning'
|
||||
@@ -30,36 +30,36 @@ class LoraScanner(ModelScanner):
|
||||
|
||||
async def diagnose_hash_index(self):
|
||||
"""Diagnostic method to verify hash index functionality"""
|
||||
print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
|
||||
logger.debug("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n")
|
||||
|
||||
# First check if the hash index has any entries
|
||||
if hasattr(self, '_hash_index'):
|
||||
index_entries = len(self._hash_index._hash_to_path)
|
||||
print(f"Hash index has {index_entries} entries", file=sys.stderr)
|
||||
logger.debug(f"Hash index has {index_entries} entries")
|
||||
|
||||
# Print a few example entries if available
|
||||
if index_entries > 0:
|
||||
print("\nSample hash index entries:", file=sys.stderr)
|
||||
logger.debug("\nSample hash index entries:")
|
||||
count = 0
|
||||
for hash_val, path in self._hash_index._hash_to_path.items():
|
||||
if count < 5: # Just show the first 5
|
||||
print(f"Hash: {hash_val[:8]}... -> Path: {path}", file=sys.stderr)
|
||||
logger.debug(f"Hash: {hash_val[:8]}... -> Path: {path}")
|
||||
count += 1
|
||||
else:
|
||||
break
|
||||
else:
|
||||
print("Hash index not initialized", file=sys.stderr)
|
||||
logger.debug("Hash index not initialized")
|
||||
|
||||
# Try looking up by a known hash for testing
|
||||
if not hasattr(self, '_hash_index') or not self._hash_index._hash_to_path:
|
||||
print("No hash entries to test lookup with", file=sys.stderr)
|
||||
logger.debug("No hash entries to test lookup with")
|
||||
return
|
||||
|
||||
test_hash = next(iter(self._hash_index._hash_to_path.keys()))
|
||||
test_path = self._hash_index.get_path(test_hash)
|
||||
print(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}", file=sys.stderr)
|
||||
logger.debug(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}")
|
||||
|
||||
# Also test reverse lookup
|
||||
test_hash_result = self._hash_index.get_hash(test_path)
|
||||
print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
|
||||
logger.debug(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n")
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from dataclasses import dataclass, field
|
||||
from operator import itemgetter
|
||||
from natsort import natsorted
|
||||
|
||||
# Supported sort modes: (sort_key, order)
|
||||
@@ -229,17 +228,17 @@ class ModelCache:
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'date':
|
||||
# Sort by modified timestamp
|
||||
# Sort by modified timestamp (use .get() with default to handle missing fields)
|
||||
result = sorted(
|
||||
data,
|
||||
key=itemgetter('modified'),
|
||||
key=lambda x: x.get('modified', 0.0),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'size':
|
||||
# Sort by file size
|
||||
# Sort by file size (use .get() with default to handle missing fields)
|
||||
result = sorted(
|
||||
data,
|
||||
key=itemgetter('size'),
|
||||
key=lambda x: x.get('size', 0),
|
||||
reverse=reverse
|
||||
)
|
||||
elif sort_key == 'usage':
|
||||
|
||||
@@ -20,6 +20,8 @@ from .service_registry import ServiceRegistry
|
||||
from .websocket_manager import ws_manager
|
||||
from .persistent_model_cache import get_persistent_cache
|
||||
from .settings_manager import get_settings_manager
|
||||
from .cache_entry_validator import CacheEntryValidator
|
||||
from .cache_health_monitor import CacheHealthMonitor, CacheHealthStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -468,6 +470,39 @@ class ModelScanner:
|
||||
for tag in adjusted_item.get('tags') or []:
|
||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||
|
||||
# Validate cache entries and check health
|
||||
valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
|
||||
adjusted_raw_data, auto_repair=True
|
||||
)
|
||||
|
||||
if invalid_entries:
|
||||
monitor = CacheHealthMonitor()
|
||||
report = monitor.check_health(adjusted_raw_data, auto_repair=True)
|
||||
|
||||
if report.status != CacheHealthStatus.HEALTHY:
|
||||
# Broadcast health warning to frontend
|
||||
await ws_manager.broadcast_cache_health_warning(report, page_type)
|
||||
logger.warning(
|
||||
f"{self.model_type.capitalize()} Scanner: Cache health issue detected - "
|
||||
f"{report.invalid_entries} invalid entries, {report.repaired_entries} repaired"
|
||||
)
|
||||
|
||||
# Use only valid entries
|
||||
adjusted_raw_data = valid_entries
|
||||
|
||||
# Rebuild tags count from valid entries only
|
||||
tags_count = {}
|
||||
for item in adjusted_raw_data:
|
||||
for tag in item.get('tags') or []:
|
||||
tags_count[tag] = tags_count.get(tag, 0) + 1
|
||||
|
||||
# Remove invalid entries from hash index
|
||||
for invalid_entry in invalid_entries:
|
||||
file_path = CacheEntryValidator.get_file_path_safe(invalid_entry)
|
||||
sha256 = CacheEntryValidator.get_sha256_safe(invalid_entry)
|
||||
if file_path:
|
||||
hash_index.remove_by_path(file_path, sha256)
|
||||
|
||||
scan_result = CacheBuildResult(
|
||||
raw_data=adjusted_raw_data,
|
||||
hash_index=hash_index,
|
||||
@@ -651,7 +686,6 @@ class ModelScanner:
|
||||
|
||||
async def _initialize_cache(self) -> None:
|
||||
"""Initialize or refresh the cache"""
|
||||
print("init start", flush=True)
|
||||
self._is_initializing = True # Set flag
|
||||
try:
|
||||
start_time = time.time()
|
||||
@@ -665,7 +699,6 @@ class ModelScanner:
|
||||
scan_result = await self._gather_model_data()
|
||||
await self._apply_scan_result(scan_result)
|
||||
await self._save_persistent_cache(scan_result)
|
||||
print("init end", flush=True)
|
||||
|
||||
logger.info(
|
||||
f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
|
||||
@@ -776,6 +809,18 @@ class ModelScanner:
|
||||
model_data = self.adjust_cached_entry(dict(model_data))
|
||||
if not model_data:
|
||||
continue
|
||||
|
||||
# Validate the new entry before adding
|
||||
validation_result = CacheEntryValidator.validate(
|
||||
model_data, auto_repair=True
|
||||
)
|
||||
if not validation_result.is_valid:
|
||||
logger.warning(
|
||||
f"Skipping invalid entry during reconcile: {path}"
|
||||
)
|
||||
continue
|
||||
model_data = validation_result.entry
|
||||
|
||||
self._ensure_license_flags(model_data)
|
||||
# Add to cache
|
||||
self._cache.raw_data.append(model_data)
|
||||
@@ -1090,6 +1135,17 @@ class ModelScanner:
|
||||
processed_files += 1
|
||||
|
||||
if result:
|
||||
# Validate the entry before adding
|
||||
validation_result = CacheEntryValidator.validate(
|
||||
result, auto_repair=True
|
||||
)
|
||||
if not validation_result.is_valid:
|
||||
logger.warning(
|
||||
f"Skipping invalid scan result: {file_path}"
|
||||
)
|
||||
continue
|
||||
result = validation_result.entry
|
||||
|
||||
self._ensure_license_flags(result)
|
||||
raw_data.append(result)
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
|
||||
from ..config import config
|
||||
from .recipe_cache import RecipeCache
|
||||
from .recipe_fts_index import RecipeFTSIndex
|
||||
from .persistent_recipe_cache import PersistentRecipeCache, get_persistent_recipe_cache
|
||||
from .persistent_recipe_cache import PersistentRecipeCache, get_persistent_recipe_cache, PersistedRecipeData
|
||||
from .service_registry import ServiceRegistry
|
||||
from .lora_scanner import LoraScanner
|
||||
from .metadata_service import get_default_metadata_provider
|
||||
@@ -431,6 +431,16 @@ class RecipeScanner:
|
||||
4. Persist results for next startup
|
||||
"""
|
||||
try:
|
||||
# Ensure cache exists to avoid None reference errors
|
||||
if self._cache is None:
|
||||
self._cache = RecipeCache(
|
||||
raw_data=[],
|
||||
sorted_by_name=[],
|
||||
sorted_by_date=[],
|
||||
folders=[],
|
||||
folder_tree={},
|
||||
)
|
||||
|
||||
# Create a new event loop for this thread
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
@@ -492,7 +502,7 @@ class RecipeScanner:
|
||||
|
||||
def _reconcile_recipe_cache(
|
||||
self,
|
||||
persisted: "PersistedRecipeData",
|
||||
persisted: PersistedRecipeData,
|
||||
recipes_dir: str,
|
||||
) -> Tuple[List[Dict], bool, Dict[str, str]]:
|
||||
"""Reconcile persisted cache with current filesystem state.
|
||||
@@ -504,8 +514,6 @@ class RecipeScanner:
|
||||
Returns:
|
||||
Tuple of (recipes list, changed flag, json_paths dict).
|
||||
"""
|
||||
from .persistent_recipe_cache import PersistedRecipeData
|
||||
|
||||
recipes: List[Dict] = []
|
||||
json_paths: Dict[str, str] = {}
|
||||
changed = False
|
||||
@@ -522,32 +530,37 @@ class RecipeScanner:
|
||||
except OSError:
|
||||
continue
|
||||
|
||||
# Build lookup of persisted recipes by json_path
|
||||
persisted_by_path: Dict[str, Dict] = {}
|
||||
for recipe in persisted.raw_data:
|
||||
recipe_id = str(recipe.get('id', ''))
|
||||
if recipe_id:
|
||||
# Find the json_path from file_stats
|
||||
for json_path, (mtime, size) in persisted.file_stats.items():
|
||||
if os.path.basename(json_path).startswith(recipe_id):
|
||||
persisted_by_path[json_path] = recipe
|
||||
break
|
||||
|
||||
# Also index by recipe ID for faster lookups
|
||||
persisted_by_id: Dict[str, Dict] = {
|
||||
# Build recipe_id -> recipe lookup (O(n) instead of O(n²))
|
||||
recipe_by_id: Dict[str, Dict] = {
|
||||
str(r.get('id', '')): r for r in persisted.raw_data if r.get('id')
|
||||
}
|
||||
|
||||
# Build json_path -> recipe lookup from file_stats (O(m))
|
||||
persisted_by_path: Dict[str, Dict] = {}
|
||||
for json_path in persisted.file_stats.keys():
|
||||
basename = os.path.basename(json_path)
|
||||
if basename.lower().endswith('.recipe.json'):
|
||||
recipe_id = basename[:-len('.recipe.json')]
|
||||
if recipe_id in recipe_by_id:
|
||||
persisted_by_path[json_path] = recipe_by_id[recipe_id]
|
||||
|
||||
# Process current files
|
||||
for file_path, (current_mtime, current_size) in current_files.items():
|
||||
cached_stats = persisted.file_stats.get(file_path)
|
||||
|
||||
# Extract recipe_id from current file for fallback lookup
|
||||
basename = os.path.basename(file_path)
|
||||
recipe_id_from_file = basename[:-len('.recipe.json')] if basename.lower().endswith('.recipe.json') else None
|
||||
|
||||
if cached_stats:
|
||||
cached_mtime, cached_size = cached_stats
|
||||
# Check if file is unchanged
|
||||
if abs(current_mtime - cached_mtime) < 1.0 and current_size == cached_size:
|
||||
# Use cached data
|
||||
# Try direct path lookup first
|
||||
cached_recipe = persisted_by_path.get(file_path)
|
||||
# Fallback to recipe_id lookup if path lookup fails
|
||||
if not cached_recipe and recipe_id_from_file:
|
||||
cached_recipe = recipe_by_id.get(recipe_id_from_file)
|
||||
if cached_recipe:
|
||||
recipe_id = str(cached_recipe.get('id', ''))
|
||||
# Track folder from file path
|
||||
|
||||
@@ -63,7 +63,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
|
||||
"compact_mode": False,
|
||||
"priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
|
||||
"model_name_display": "model_name",
|
||||
"model_card_footer_action": "example_images",
|
||||
"model_card_footer_action": "replace_preview",
|
||||
"update_flag_strategy": "same_base",
|
||||
"auto_organize_exclusions": [],
|
||||
}
|
||||
|
||||
@@ -48,9 +48,14 @@ class BulkMetadataRefreshUseCase:
|
||||
for model in cache.raw_data
|
||||
if model.get("sha256")
|
||||
and (not model.get("civitai") or not model["civitai"].get("id"))
|
||||
and (
|
||||
(enable_metadata_archive_db and not model.get("db_checked", False))
|
||||
or (not enable_metadata_archive_db and model.get("from_civitai") is True)
|
||||
and not (
|
||||
# Skip models confirmed not on CivitAI when no need to retry
|
||||
model.get("from_civitai") is False
|
||||
and model.get("civitai_deleted") is True
|
||||
and (
|
||||
not enable_metadata_archive_db
|
||||
or model.get("db_checked", False)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
@@ -255,6 +255,42 @@ class WebSocketManager:
|
||||
self._download_progress.pop(download_id, None)
|
||||
logger.debug(f"Cleaned up old download progress for {download_id}")
|
||||
|
||||
async def broadcast_cache_health_warning(self, report: 'HealthReport', page_type: str = None):
|
||||
"""
|
||||
Broadcast cache health warning to frontend.
|
||||
|
||||
Args:
|
||||
report: HealthReport instance from CacheHealthMonitor
|
||||
page_type: The page type (loras, checkpoints, embeddings)
|
||||
"""
|
||||
from .cache_health_monitor import CacheHealthStatus
|
||||
|
||||
# Only broadcast if there are issues
|
||||
if report.status == CacheHealthStatus.HEALTHY:
|
||||
return
|
||||
|
||||
payload = {
|
||||
'type': 'cache_health_warning',
|
||||
'status': report.status.value,
|
||||
'message': report.message,
|
||||
'pageType': page_type,
|
||||
'details': {
|
||||
'total': report.total_entries,
|
||||
'valid': report.valid_entries,
|
||||
'invalid': report.invalid_entries,
|
||||
'repaired': report.repaired_entries,
|
||||
'corruption_rate': f"{report.corruption_rate:.1%}",
|
||||
'invalid_paths': report.invalid_paths[:5], # Limit to first 5
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"Broadcasting cache health warning: {report.status.value} "
|
||||
f"({report.invalid_entries} invalid entries)"
|
||||
)
|
||||
|
||||
await self.broadcast(payload)
|
||||
|
||||
def get_connected_clients_count(self) -> int:
|
||||
"""Get number of connected clients"""
|
||||
return len(self._websockets)
|
||||
|
||||
@@ -216,6 +216,11 @@ class DownloadManager:
|
||||
self._progress["failed_models"] = set()
|
||||
|
||||
self._is_downloading = True
|
||||
snapshot = self._progress.snapshot()
|
||||
|
||||
# Create the download task without awaiting it
|
||||
# This ensures the HTTP response is returned immediately
|
||||
# while the actual processing happens in the background
|
||||
self._download_task = asyncio.create_task(
|
||||
self._download_all_example_images(
|
||||
output_dir,
|
||||
@@ -227,7 +232,10 @@ class DownloadManager:
|
||||
)
|
||||
)
|
||||
|
||||
snapshot = self._progress.snapshot()
|
||||
# Add a callback to handle task completion/errors
|
||||
self._download_task.add_done_callback(
|
||||
lambda t: self._handle_download_task_done(t, output_dir)
|
||||
)
|
||||
except ExampleImagesDownloadError:
|
||||
# Re-raise our own exception types without wrapping
|
||||
self._is_downloading = False
|
||||
@@ -241,10 +249,25 @@ class DownloadManager:
|
||||
)
|
||||
raise ExampleImagesDownloadError(str(e)) from e
|
||||
|
||||
await self._broadcast_progress(status="running")
|
||||
# Broadcast progress in the background without blocking the response
|
||||
# This ensures the HTTP response is returned immediately
|
||||
asyncio.create_task(self._broadcast_progress(status="running"))
|
||||
|
||||
return {"success": True, "message": "Download started", "status": snapshot}
|
||||
|
||||
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
|
||||
"""Handle download task completion, including saving progress on error."""
|
||||
try:
|
||||
# This will re-raise any exception from the task
|
||||
task.result()
|
||||
except Exception as e:
|
||||
logger.error(f"Download task failed with error: {e}", exc_info=True)
|
||||
# Ensure progress is saved even on failure
|
||||
try:
|
||||
self._save_progress(output_dir)
|
||||
except Exception as save_error:
|
||||
logger.error(f"Failed to save progress after task failure: {save_error}")
|
||||
|
||||
async def get_status(self, request):
|
||||
"""Get the current status of example images download."""
|
||||
|
||||
@@ -254,6 +277,130 @@ class DownloadManager:
|
||||
"status": self._progress.snapshot(),
|
||||
}
|
||||
|
||||
async def check_pending_models(self, model_types: list[str]) -> dict:
|
||||
"""Quickly check how many models need example images downloaded.
|
||||
|
||||
This is a lightweight check that avoids the overhead of starting
|
||||
a full download task when no work is needed.
|
||||
|
||||
Returns:
|
||||
dict with keys:
|
||||
- total_models: Total number of models across specified types
|
||||
- pending_count: Number of models needing example images
|
||||
- processed_count: Number of already processed models
|
||||
- failed_count: Number of models marked as failed
|
||||
- needs_download: True if there are pending models to process
|
||||
"""
|
||||
from ..services.service_registry import ServiceRegistry
|
||||
|
||||
if self._is_downloading:
|
||||
return {
|
||||
"success": True,
|
||||
"is_downloading": True,
|
||||
"total_models": 0,
|
||||
"pending_count": 0,
|
||||
"processed_count": 0,
|
||||
"failed_count": 0,
|
||||
"needs_download": False,
|
||||
"message": "Download already in progress",
|
||||
}
|
||||
|
||||
try:
|
||||
# Get scanners
|
||||
scanners = []
|
||||
if "lora" in model_types:
|
||||
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||
scanners.append(("lora", lora_scanner))
|
||||
|
||||
if "checkpoint" in model_types:
|
||||
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
scanners.append(("checkpoint", checkpoint_scanner))
|
||||
|
||||
if "embedding" in model_types:
|
||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||
scanners.append(("embedding", embedding_scanner))
|
||||
|
||||
# Load progress file to check processed models
|
||||
settings_manager = get_settings_manager()
|
||||
active_library = settings_manager.get_active_library_name()
|
||||
output_dir = self._resolve_output_dir(active_library)
|
||||
|
||||
processed_models: set[str] = set()
|
||||
failed_models: set[str] = set()
|
||||
|
||||
if output_dir:
|
||||
progress_file = os.path.join(output_dir, ".download_progress.json")
|
||||
if os.path.exists(progress_file):
|
||||
try:
|
||||
with open(progress_file, "r", encoding="utf-8") as f:
|
||||
saved_progress = json.load(f)
|
||||
processed_models = set(saved_progress.get("processed_models", []))
|
||||
failed_models = set(saved_progress.get("failed_models", []))
|
||||
except Exception:
|
||||
pass # Ignore progress file errors for quick check
|
||||
|
||||
# Count models
|
||||
total_models = 0
|
||||
models_with_hash = 0
|
||||
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
total_models += 1
|
||||
if model.get("sha256"):
|
||||
models_with_hash += 1
|
||||
|
||||
# Calculate pending count
|
||||
# A model is pending if it has a hash and is not in processed_models
|
||||
# We also exclude failed_models unless force mode would be used
|
||||
pending_count = models_with_hash - len(processed_models.intersection(
|
||||
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
|
||||
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
|
||||
))
|
||||
|
||||
# More accurate pending count: check which models actually need processing
|
||||
pending_hashes = set()
|
||||
for scanner_type, scanner in scanners:
|
||||
cache = await scanner.get_cached_data()
|
||||
if cache and cache.raw_data:
|
||||
for model in cache.raw_data:
|
||||
raw_hash = model.get("sha256")
|
||||
if not raw_hash:
|
||||
continue
|
||||
model_hash = raw_hash.lower()
|
||||
if model_hash not in processed_models:
|
||||
# Check if model folder exists with files
|
||||
model_dir = ExampleImagePathResolver.get_model_folder(
|
||||
model_hash, active_library
|
||||
)
|
||||
if not _model_directory_has_files(model_dir):
|
||||
pending_hashes.add(model_hash)
|
||||
|
||||
pending_count = len(pending_hashes)
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"is_downloading": False,
|
||||
"total_models": total_models,
|
||||
"pending_count": pending_count,
|
||||
"processed_count": len(processed_models),
|
||||
"failed_count": len(failed_models),
|
||||
"needs_download": pending_count > 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error checking pending models: {e}", exc_info=True)
|
||||
return {
|
||||
"success": False,
|
||||
"error": str(e),
|
||||
"total_models": 0,
|
||||
"pending_count": 0,
|
||||
"processed_count": 0,
|
||||
"failed_count": 0,
|
||||
"needs_download": False,
|
||||
}
|
||||
|
||||
async def pause_download(self, request):
|
||||
"""Pause the example images download."""
|
||||
|
||||
|
||||
@@ -43,8 +43,15 @@ class ExampleImagesProcessor:
|
||||
return media_url
|
||||
|
||||
@staticmethod
|
||||
def _get_file_extension_from_content_or_headers(content, headers, fallback_url=None):
|
||||
"""Determine file extension from content magic bytes or headers"""
|
||||
def _get_file_extension_from_content_or_headers(content, headers, fallback_url=None, media_type_hint=None):
|
||||
"""Determine file extension from content magic bytes or headers
|
||||
|
||||
Args:
|
||||
content: File content bytes
|
||||
headers: HTTP response headers
|
||||
fallback_url: Original URL for extension extraction
|
||||
media_type_hint: Optional media type hint from metadata (e.g., "video" or "image")
|
||||
"""
|
||||
# Check magic bytes for common formats
|
||||
if content:
|
||||
if content.startswith(b'\xFF\xD8\xFF'):
|
||||
@@ -82,6 +89,10 @@ class ExampleImagesProcessor:
|
||||
if ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or ext in SUPPORTED_MEDIA_EXTENSIONS['videos']:
|
||||
return ext
|
||||
|
||||
# Use media type hint from metadata if available
|
||||
if media_type_hint == "video":
|
||||
return '.mp4'
|
||||
|
||||
# Default fallback
|
||||
return '.jpg'
|
||||
|
||||
@@ -136,7 +147,7 @@ class ExampleImagesProcessor:
|
||||
if success:
|
||||
# Determine file extension from content or headers
|
||||
media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||
content, headers, original_url
|
||||
content, headers, original_url, image.get("type")
|
||||
)
|
||||
|
||||
# Check if the detected file type is supported
|
||||
@@ -219,7 +230,7 @@ class ExampleImagesProcessor:
|
||||
if success:
|
||||
# Determine file extension from content or headers
|
||||
media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||
content, headers, original_url
|
||||
content, headers, original_url, image.get("type")
|
||||
)
|
||||
|
||||
# Check if the detected file type is supported
|
||||
|
||||
@@ -17,7 +17,7 @@ async def extract_lora_metadata(file_path: str) -> Dict:
|
||||
base_model = determine_base_model(metadata.get("ss_base_model_version"))
|
||||
return {"base_model": base_model}
|
||||
except Exception as e:
|
||||
print(f"Error reading metadata from {file_path}: {str(e)}")
|
||||
logger.error(f"Error reading metadata from {file_path}: {str(e)}")
|
||||
return {"base_model": "Unknown"}
|
||||
|
||||
async def extract_checkpoint_metadata(file_path: str) -> dict:
|
||||
|
||||
@@ -223,7 +223,7 @@ class MetadataManager:
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription="",
|
||||
model_type="checkpoint",
|
||||
sub_type="checkpoint",
|
||||
from_civitai=True
|
||||
)
|
||||
elif model_class.__name__ == "EmbeddingMetadata":
|
||||
@@ -238,6 +238,7 @@ class MetadataManager:
|
||||
preview_url=normalize_path(preview_url),
|
||||
tags=[],
|
||||
modelDescription="",
|
||||
sub_type="embedding",
|
||||
from_civitai=True
|
||||
)
|
||||
else: # Default to LoraMetadata
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
[project]
|
||||
name = "comfyui-lora-manager"
|
||||
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
|
||||
version = "0.9.13"
|
||||
version = "0.9.14"
|
||||
license = {file = "LICENSE"}
|
||||
dependencies = [
|
||||
"aiohttp",
|
||||
|
||||
0
scripts/sync_translation_keys.py
Normal file → Executable file
0
scripts/sync_translation_keys.py
Normal file → Executable file
@@ -113,6 +113,12 @@
|
||||
max-width: 110px;
|
||||
}
|
||||
|
||||
/* Compact mode: hide sub-type to save space */
|
||||
.compact-density .model-sub-type,
|
||||
.compact-density .model-separator {
|
||||
display: none;
|
||||
}
|
||||
|
||||
.compact-density .card-actions i {
|
||||
font-size: 0.95em;
|
||||
padding: 3px;
|
||||
|
||||
@@ -26,6 +26,7 @@ class RecipeCard {
|
||||
card.dataset.nsfwLevel = this.recipe.preview_nsfw_level || 0;
|
||||
card.dataset.created = this.recipe.created_date;
|
||||
card.dataset.id = this.recipe.id || '';
|
||||
card.dataset.folder = this.recipe.folder || '';
|
||||
|
||||
// Get base model with fallback
|
||||
const baseModelLabel = (this.recipe.base_model || '').trim() || 'Unknown';
|
||||
|
||||
@@ -198,6 +198,12 @@ class InitializationManager {
|
||||
handleProgressUpdate(data) {
|
||||
if (!data) return;
|
||||
console.log('Received progress update:', data);
|
||||
|
||||
// Handle cache health warning messages
|
||||
if (data.type === 'cache_health_warning') {
|
||||
this.handleCacheHealthWarning(data);
|
||||
return;
|
||||
}
|
||||
|
||||
// Check if this update is for our page type
|
||||
if (data.pageType && data.pageType !== this.pageType) {
|
||||
@@ -466,6 +472,29 @@ class InitializationManager {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle cache health warning messages from WebSocket
|
||||
*/
|
||||
handleCacheHealthWarning(data) {
|
||||
console.log('Cache health warning received:', data);
|
||||
|
||||
// Import bannerService dynamically to avoid circular dependencies
|
||||
import('../managers/BannerService.js').then(({ bannerService }) => {
|
||||
// Initialize banner service if not already done
|
||||
if (!bannerService.initialized) {
|
||||
bannerService.initialize().then(() => {
|
||||
bannerService.registerCacheHealthBanner(data);
|
||||
}).catch(err => {
|
||||
console.error('Failed to initialize banner service:', err);
|
||||
});
|
||||
} else {
|
||||
bannerService.registerCacheHealthBanner(data);
|
||||
}
|
||||
}).catch(err => {
|
||||
console.error('Failed to load banner service:', err);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up resources when the component is destroyed
|
||||
*/
|
||||
|
||||
@@ -4,9 +4,11 @@ import {
|
||||
removeStorageItem
|
||||
} from '../utils/storageHelpers.js';
|
||||
import { translate } from '../utils/i18nHelpers.js';
|
||||
import { state } from '../state/index.js'
|
||||
import { state } from '../state/index.js';
|
||||
import { getModelApiClient } from '../api/modelApiFactory.js';
|
||||
|
||||
const COMMUNITY_SUPPORT_BANNER_ID = 'community-support';
|
||||
const CACHE_HEALTH_BANNER_ID = 'cache-health-warning';
|
||||
const COMMUNITY_SUPPORT_BANNER_DELAY_MS = 5 * 24 * 60 * 60 * 1000; // 5 days
|
||||
const COMMUNITY_SUPPORT_FIRST_SEEN_AT_KEY = 'community_support_banner_first_seen_at';
|
||||
const COMMUNITY_SUPPORT_VERSION_KEY = 'community_support_banner_state_version';
|
||||
@@ -293,6 +295,177 @@ class BannerService {
|
||||
location.reload();
|
||||
}
|
||||
|
||||
/**
|
||||
* Register a cache health warning banner
|
||||
* @param {Object} healthData - Health data from WebSocket
|
||||
*/
|
||||
registerCacheHealthBanner(healthData) {
|
||||
if (!healthData || healthData.status === 'healthy') {
|
||||
return;
|
||||
}
|
||||
|
||||
// Remove existing cache health banner if any
|
||||
this.removeBannerElement(CACHE_HEALTH_BANNER_ID);
|
||||
|
||||
const isCorrupted = healthData.status === 'corrupted';
|
||||
const titleKey = isCorrupted
|
||||
? 'banners.cacheHealth.corrupted.title'
|
||||
: 'banners.cacheHealth.degraded.title';
|
||||
const defaultTitle = isCorrupted
|
||||
? 'Cache Corruption Detected'
|
||||
: 'Cache Issues Detected';
|
||||
|
||||
const title = translate(titleKey, {}, defaultTitle);
|
||||
|
||||
const contentKey = 'banners.cacheHealth.content';
|
||||
const defaultContent = 'Found {invalid} of {total} cache entries are invalid ({rate}). This may cause missing models or errors. Rebuilding the cache is recommended.';
|
||||
const content = translate(contentKey, {
|
||||
invalid: healthData.details?.invalid || 0,
|
||||
total: healthData.details?.total || 0,
|
||||
rate: healthData.details?.corruption_rate || '0%'
|
||||
}, defaultContent);
|
||||
|
||||
this.registerBanner(CACHE_HEALTH_BANNER_ID, {
|
||||
id: CACHE_HEALTH_BANNER_ID,
|
||||
title: title,
|
||||
content: content,
|
||||
pageType: healthData.pageType,
|
||||
actions: [
|
||||
{
|
||||
text: translate('banners.cacheHealth.rebuildCache', {}, 'Rebuild Cache'),
|
||||
icon: 'fas fa-sync-alt',
|
||||
action: 'rebuild-cache',
|
||||
type: 'primary'
|
||||
},
|
||||
{
|
||||
text: translate('banners.cacheHealth.dismiss', {}, 'Dismiss'),
|
||||
icon: 'fas fa-times',
|
||||
action: 'dismiss',
|
||||
type: 'secondary'
|
||||
}
|
||||
],
|
||||
dismissible: true,
|
||||
priority: 10, // High priority
|
||||
onRegister: (bannerElement) => {
|
||||
// Attach click handlers for actions
|
||||
const rebuildBtn = bannerElement.querySelector('[data-action="rebuild-cache"]');
|
||||
const dismissBtn = bannerElement.querySelector('[data-action="dismiss"]');
|
||||
|
||||
if (rebuildBtn) {
|
||||
rebuildBtn.addEventListener('click', (e) => {
|
||||
e.preventDefault();
|
||||
this.handleRebuildCache(bannerElement, healthData.pageType);
|
||||
});
|
||||
}
|
||||
|
||||
if (dismissBtn) {
|
||||
dismissBtn.addEventListener('click', (e) => {
|
||||
e.preventDefault();
|
||||
this.dismissBanner(CACHE_HEALTH_BANNER_ID);
|
||||
});
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle rebuild cache action from banner
|
||||
* @param {HTMLElement} bannerElement - The banner element
|
||||
* @param {string} pageType - The page type (loras, checkpoints, embeddings)
|
||||
*/
|
||||
async handleRebuildCache(bannerElement, pageType) {
|
||||
const currentPageType = pageType || this.getCurrentPageType();
|
||||
|
||||
try {
|
||||
const apiClient = getModelApiClient(currentPageType);
|
||||
|
||||
// Update banner to show rebuilding status
|
||||
const actionsContainer = bannerElement.querySelector('.banner-actions');
|
||||
if (actionsContainer) {
|
||||
actionsContainer.innerHTML = `
|
||||
<span class="banner-loading">
|
||||
<i class="fas fa-spinner fa-spin"></i>
|
||||
<span>${translate('banners.cacheHealth.rebuilding', {}, 'Rebuilding cache...')}</span>
|
||||
</span>
|
||||
`;
|
||||
}
|
||||
|
||||
await apiClient.refreshModels(true);
|
||||
|
||||
// Remove banner on success without marking as dismissed
|
||||
this.removeBannerElement(CACHE_HEALTH_BANNER_ID);
|
||||
} catch (error) {
|
||||
console.error('Cache rebuild failed:', error);
|
||||
|
||||
const actionsContainer = bannerElement.querySelector('.banner-actions');
|
||||
if (actionsContainer) {
|
||||
actionsContainer.innerHTML = `
|
||||
<span class="banner-error">
|
||||
<i class="fas fa-exclamation-triangle"></i>
|
||||
<span>${translate('banners.cacheHealth.rebuildFailed', {}, 'Rebuild failed. Please try again.')}</span>
|
||||
</span>
|
||||
<a href="#" class="banner-action banner-action-primary" data-action="rebuild-cache">
|
||||
<i class="fas fa-sync-alt"></i>
|
||||
<span>${translate('banners.cacheHealth.retry', {}, 'Retry')}</span>
|
||||
</a>
|
||||
`;
|
||||
|
||||
// Re-attach click handler
|
||||
const retryBtn = actionsContainer.querySelector('[data-action="rebuild-cache"]');
|
||||
if (retryBtn) {
|
||||
retryBtn.addEventListener('click', (e) => {
|
||||
e.preventDefault();
|
||||
this.handleRebuildCache(bannerElement, pageType);
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the current page type from the URL
|
||||
* @returns {string} Page type (loras, checkpoints, embeddings, recipes)
|
||||
*/
|
||||
getCurrentPageType() {
|
||||
const path = window.location.pathname;
|
||||
if (path.includes('/checkpoints')) return 'checkpoints';
|
||||
if (path.includes('/embeddings')) return 'embeddings';
|
||||
if (path.includes('/recipes')) return 'recipes';
|
||||
return 'loras';
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the rebuild cache endpoint for the given page type
|
||||
* @param {string} pageType - The page type
|
||||
* @returns {string} The API endpoint URL
|
||||
*/
|
||||
getRebuildEndpoint(pageType) {
|
||||
const endpoints = {
|
||||
'loras': '/api/lm/loras/reload?rebuild=true',
|
||||
'checkpoints': '/api/lm/checkpoints/reload?rebuild=true',
|
||||
'embeddings': '/api/lm/embeddings/reload?rebuild=true'
|
||||
};
|
||||
return endpoints[pageType] || endpoints['loras'];
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a banner element from DOM without marking as dismissed
|
||||
* @param {string} bannerId - Banner ID to remove
|
||||
*/
|
||||
removeBannerElement(bannerId) {
|
||||
const bannerElement = document.querySelector(`[data-banner-id="${bannerId}"]`);
|
||||
if (bannerElement) {
|
||||
bannerElement.style.animation = 'banner-slide-up 0.3s ease-in-out forwards';
|
||||
setTimeout(() => {
|
||||
bannerElement.remove();
|
||||
this.updateContainerVisibility();
|
||||
}, 300);
|
||||
}
|
||||
|
||||
// Also remove from banners map
|
||||
this.banners.delete(bannerId);
|
||||
}
|
||||
|
||||
prepareCommunitySupportBanner() {
|
||||
if (this.isBannerDismissed(COMMUNITY_SUPPORT_BANNER_ID)) {
|
||||
return;
|
||||
|
||||
@@ -21,7 +21,7 @@ export class ExampleImagesManager {
|
||||
// Auto download properties
|
||||
this.autoDownloadInterval = null;
|
||||
this.lastAutoDownloadCheck = 0;
|
||||
this.autoDownloadCheckInterval = 10 * 60 * 1000; // 10 minutes in milliseconds
|
||||
this.autoDownloadCheckInterval = 30 * 60 * 1000; // 30 minutes in milliseconds
|
||||
this.pageInitTime = Date.now(); // Track when page was initialized
|
||||
|
||||
// Initialize download path field and check download status
|
||||
@@ -808,19 +808,58 @@ export class ExampleImagesManager {
|
||||
return;
|
||||
}
|
||||
|
||||
this.lastAutoDownloadCheck = now;
|
||||
|
||||
if (!this.canAutoDownload()) {
|
||||
console.log('Auto download conditions not met, skipping check');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
console.log('Performing auto download check...');
|
||||
console.log('Performing auto download pre-check...');
|
||||
|
||||
// Step 1: Lightweight pre-check to see if any work is needed
|
||||
const checkResponse = await fetch('/api/lm/check-example-images-needed', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
model_types: ['lora', 'checkpoint', 'embedding']
|
||||
})
|
||||
});
|
||||
|
||||
if (!checkResponse.ok) {
|
||||
console.warn('Auto download pre-check HTTP error:', checkResponse.status);
|
||||
return;
|
||||
}
|
||||
|
||||
const checkData = await checkResponse.json();
|
||||
|
||||
if (!checkData.success) {
|
||||
console.warn('Auto download pre-check failed:', checkData.error);
|
||||
return;
|
||||
}
|
||||
|
||||
// Update the check timestamp only after successful pre-check
|
||||
this.lastAutoDownloadCheck = now;
|
||||
|
||||
// If download already in progress, skip
|
||||
if (checkData.is_downloading) {
|
||||
console.log('Download already in progress, skipping auto check');
|
||||
return;
|
||||
}
|
||||
|
||||
// If no models need downloading, skip
|
||||
if (!checkData.needs_download || checkData.pending_count === 0) {
|
||||
console.log(`Auto download pre-check complete: ${checkData.processed_count}/${checkData.total_models} models already processed, no work needed`);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`Auto download pre-check: ${checkData.pending_count} models need processing, starting download...`);
|
||||
|
||||
// Step 2: Start the actual download (fire-and-forget)
|
||||
const optimize = state.global.settings.optimize_example_images;
|
||||
|
||||
const response = await fetch('/api/lm/download-example-images', {
|
||||
fetch('/api/lm/download-example-images', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
@@ -830,18 +869,29 @@ export class ExampleImagesManager {
|
||||
model_types: ['lora', 'checkpoint', 'embedding'],
|
||||
auto_mode: true // Flag to indicate this is an automatic download
|
||||
})
|
||||
}).then(response => {
|
||||
if (!response.ok) {
|
||||
console.warn('Auto download start HTTP error:', response.status);
|
||||
return null;
|
||||
}
|
||||
return response.json();
|
||||
}).then(data => {
|
||||
if (data && !data.success) {
|
||||
console.warn('Auto download start failed:', data.error);
|
||||
// If already in progress, push back the next check to avoid hammering the API
|
||||
if (data.error && data.error.includes('already in progress')) {
|
||||
console.log('Download already in progress, backing off next check');
|
||||
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
|
||||
}
|
||||
} else if (data && data.success) {
|
||||
console.log('Auto download started:', data.message || 'Download started');
|
||||
}
|
||||
}).catch(error => {
|
||||
console.error('Auto download start error:', error);
|
||||
});
|
||||
|
||||
const data = await response.json();
|
||||
|
||||
if (!data.success) {
|
||||
console.warn('Auto download check failed:', data.error);
|
||||
// If already in progress, push back the next check to avoid hammering the API
|
||||
if (data.error && data.error.includes('already in progress')) {
|
||||
console.log('Download already in progress, backing off next check');
|
||||
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
|
||||
}
|
||||
}
|
||||
// Immediately return without waiting for the download fetch to complete
|
||||
// This keeps the UI responsive
|
||||
} catch (error) {
|
||||
console.error('Auto download check error:', error);
|
||||
}
|
||||
|
||||
@@ -27,6 +27,10 @@ export const BASE_MODELS = {
|
||||
FLUX_1_KREA: "Flux.1 Krea",
|
||||
FLUX_1_KONTEXT: "Flux.1 Kontext",
|
||||
FLUX_2_D: "Flux.2 D",
|
||||
FLUX_2_KLEIN_9B: "Flux.2 Klein 9B",
|
||||
FLUX_2_KLEIN_9B_BASE: "Flux.2 Klein 9B-base",
|
||||
FLUX_2_KLEIN_4B: "Flux.2 Klein 4B",
|
||||
FLUX_2_KLEIN_4B_BASE: "Flux.2 Klein 4B-base",
|
||||
AURAFLOW: "AuraFlow",
|
||||
CHROMA: "Chroma",
|
||||
PIXART_A: "PixArt a",
|
||||
@@ -40,10 +44,12 @@ export const BASE_MODELS = {
|
||||
HIDREAM: "HiDream",
|
||||
QWEN: "Qwen",
|
||||
ZIMAGE_TURBO: "ZImageTurbo",
|
||||
|
||||
ZIMAGE_BASE: "ZImageBase",
|
||||
|
||||
// Video models
|
||||
SVD: "SVD",
|
||||
LTXV: "LTXV",
|
||||
LTXV2: "LTXV2",
|
||||
WAN_VIDEO: "Wan Video",
|
||||
WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
|
||||
WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
|
||||
@@ -120,6 +126,10 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
||||
[BASE_MODELS.FLUX_1_KREA]: 'F1KR',
|
||||
[BASE_MODELS.FLUX_1_KONTEXT]: 'F1KX',
|
||||
[BASE_MODELS.FLUX_2_D]: 'F2D',
|
||||
[BASE_MODELS.FLUX_2_KLEIN_9B]: 'FK9',
|
||||
[BASE_MODELS.FLUX_2_KLEIN_9B_BASE]: 'FK9B',
|
||||
[BASE_MODELS.FLUX_2_KLEIN_4B]: 'FK4',
|
||||
[BASE_MODELS.FLUX_2_KLEIN_4B_BASE]: 'FK4B',
|
||||
|
||||
// Other diffusion models
|
||||
[BASE_MODELS.AURAFLOW]: 'AF',
|
||||
@@ -135,10 +145,12 @@ export const BASE_MODEL_ABBREVIATIONS = {
|
||||
[BASE_MODELS.HIDREAM]: 'HID',
|
||||
[BASE_MODELS.QWEN]: 'QWEN',
|
||||
[BASE_MODELS.ZIMAGE_TURBO]: 'ZIT',
|
||||
[BASE_MODELS.ZIMAGE_BASE]: 'ZIB',
|
||||
|
||||
// Video models
|
||||
[BASE_MODELS.SVD]: 'SVD',
|
||||
[BASE_MODELS.LTXV]: 'LTXV',
|
||||
[BASE_MODELS.LTXV2]: 'LTV2',
|
||||
[BASE_MODELS.WAN_VIDEO]: 'WAN',
|
||||
[BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
|
||||
[BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
|
||||
@@ -328,16 +340,16 @@ export const BASE_MODEL_CATEGORIES = {
|
||||
'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
|
||||
'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
|
||||
'Video Models': [
|
||||
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
|
||||
BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
|
||||
BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
|
||||
BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
|
||||
BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
|
||||
BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B
|
||||
],
|
||||
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D],
|
||||
'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE],
|
||||
'Other Models': [
|
||||
BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
|
||||
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO,
|
||||
BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE,
|
||||
BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
|
||||
BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
|
||||
BASE_MODELS.UNKNOWN
|
||||
|
||||
@@ -230,8 +230,58 @@ def test_new_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
assert normalized_external in second_cfg._path_mappings
|
||||
|
||||
|
||||
def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
"""Removing a deep symlink should trigger cache invalidation."""
|
||||
def test_removed_first_level_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
"""Removing a first-level symlink should trigger cache invalidation."""
|
||||
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||
|
||||
# Create first-level symlink (directly under loras root)
|
||||
external_dir = tmp_path / "external"
|
||||
external_dir.mkdir()
|
||||
symlink = loras_dir / "external_models"
|
||||
symlink.symlink_to(external_dir, target_is_directory=True)
|
||||
|
||||
# Initial scan finds the symlink
|
||||
first_cfg = config_module.Config()
|
||||
normalized_external = _normalize(str(external_dir))
|
||||
assert normalized_external in first_cfg._path_mappings
|
||||
|
||||
# Remove the symlink
|
||||
symlink.unlink()
|
||||
|
||||
# Second config should detect invalid cached mapping and rescan
|
||||
second_cfg = config_module.Config()
|
||||
assert normalized_external not in second_cfg._path_mappings
|
||||
|
||||
|
||||
def test_retargeted_first_level_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
"""Changing a first-level symlink's target should trigger cache invalidation."""
|
||||
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||
|
||||
# Create first-level symlink
|
||||
target_v1 = tmp_path / "external_v1"
|
||||
target_v1.mkdir()
|
||||
target_v2 = tmp_path / "external_v2"
|
||||
target_v2.mkdir()
|
||||
|
||||
symlink = loras_dir / "external_models"
|
||||
symlink.symlink_to(target_v1, target_is_directory=True)
|
||||
|
||||
# Initial scan
|
||||
first_cfg = config_module.Config()
|
||||
assert _normalize(str(target_v1)) in first_cfg._path_mappings
|
||||
|
||||
# Retarget the symlink
|
||||
symlink.unlink()
|
||||
symlink.symlink_to(target_v2, target_is_directory=True)
|
||||
|
||||
# Second config should detect changed target and rescan
|
||||
second_cfg = config_module.Config()
|
||||
assert _normalize(str(target_v2)) in second_cfg._path_mappings
|
||||
assert _normalize(str(target_v1)) not in second_cfg._path_mappings
|
||||
|
||||
|
||||
def test_deep_symlink_not_scanned(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
"""Deep symlinks (below first level) are not scanned to avoid performance issues."""
|
||||
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||
|
||||
# Create nested structure with deep symlink
|
||||
@@ -242,46 +292,12 @@ def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, t
|
||||
deep_symlink = subdir / "styles"
|
||||
deep_symlink.symlink_to(external_dir, target_is_directory=True)
|
||||
|
||||
# Initial scan finds the deep symlink
|
||||
first_cfg = config_module.Config()
|
||||
# Config should not detect deep symlinks (only first-level)
|
||||
cfg = config_module.Config()
|
||||
normalized_external = _normalize(str(external_dir))
|
||||
assert normalized_external in first_cfg._path_mappings
|
||||
|
||||
# Remove the deep symlink
|
||||
deep_symlink.unlink()
|
||||
|
||||
# Second config should detect invalid cached mapping and rescan
|
||||
second_cfg = config_module.Config()
|
||||
assert normalized_external not in second_cfg._path_mappings
|
||||
assert normalized_external not in cfg._path_mappings
|
||||
|
||||
|
||||
def test_retargeted_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
"""Changing a deep symlink's target should trigger cache invalidation."""
|
||||
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
|
||||
|
||||
# Create nested structure
|
||||
subdir = loras_dir / "anime"
|
||||
subdir.mkdir()
|
||||
target_v1 = tmp_path / "external_v1"
|
||||
target_v1.mkdir()
|
||||
target_v2 = tmp_path / "external_v2"
|
||||
target_v2.mkdir()
|
||||
|
||||
deep_symlink = subdir / "styles"
|
||||
deep_symlink.symlink_to(target_v1, target_is_directory=True)
|
||||
|
||||
# Initial scan
|
||||
first_cfg = config_module.Config()
|
||||
assert _normalize(str(target_v1)) in first_cfg._path_mappings
|
||||
|
||||
# Retarget the symlink
|
||||
deep_symlink.unlink()
|
||||
deep_symlink.symlink_to(target_v2, target_is_directory=True)
|
||||
|
||||
# Second config should detect changed target and rescan
|
||||
second_cfg = config_module.Config()
|
||||
assert _normalize(str(target_v2)) in second_cfg._path_mappings
|
||||
assert _normalize(str(target_v1)) not in second_cfg._path_mappings
|
||||
def test_legacy_symlink_cache_automatic_cleanup(monkeypatch: pytest.MonkeyPatch, tmp_path):
|
||||
"""Test that legacy symlink cache is automatically cleaned up after migration."""
|
||||
settings_dir = tmp_path / "settings"
|
||||
|
||||
@@ -47,6 +47,8 @@ class StubDownloadManager:
|
||||
self.resume_error: Exception | None = None
|
||||
self.stop_error: Exception | None = None
|
||||
self.force_error: Exception | None = None
|
||||
self.check_pending_result: dict[str, Any] | None = None
|
||||
self.check_pending_calls: list[list[str]] = []
|
||||
|
||||
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
||||
return {"success": True, "status": "idle"}
|
||||
@@ -75,6 +77,20 @@ class StubDownloadManager:
|
||||
raise self.force_error
|
||||
return {"success": True, "payload": payload}
|
||||
|
||||
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
|
||||
self.check_pending_calls.append(model_types)
|
||||
if self.check_pending_result is not None:
|
||||
return self.check_pending_result
|
||||
return {
|
||||
"success": True,
|
||||
"is_downloading": False,
|
||||
"total_models": 100,
|
||||
"pending_count": 10,
|
||||
"processed_count": 90,
|
||||
"failed_count": 0,
|
||||
"needs_download": True,
|
||||
}
|
||||
|
||||
|
||||
class StubImportUseCase:
|
||||
def __init__(self) -> None:
|
||||
@@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors():
|
||||
assert response.status == 400
|
||||
body = await _json(response)
|
||||
assert body == {"success": False, "error": "bad payload"}
|
||||
|
||||
|
||||
async def test_check_example_images_needed_returns_pending_counts():
|
||||
"""Test that check_example_images_needed endpoint returns pending model counts."""
|
||||
async with registrar_app() as harness:
|
||||
harness.download_manager.check_pending_result = {
|
||||
"success": True,
|
||||
"is_downloading": False,
|
||||
"total_models": 5500,
|
||||
"pending_count": 12,
|
||||
"processed_count": 5488,
|
||||
"failed_count": 45,
|
||||
"needs_download": True,
|
||||
}
|
||||
|
||||
response = await harness.client.post(
|
||||
"/api/lm/check-example-images-needed",
|
||||
json={"model_types": ["lora", "checkpoint"]},
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
body = await _json(response)
|
||||
assert body["success"] is True
|
||||
assert body["total_models"] == 5500
|
||||
assert body["pending_count"] == 12
|
||||
assert body["processed_count"] == 5488
|
||||
assert body["failed_count"] == 45
|
||||
assert body["needs_download"] is True
|
||||
assert body["is_downloading"] is False
|
||||
|
||||
# Verify the manager was called with correct model types
|
||||
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
|
||||
|
||||
|
||||
async def test_check_example_images_needed_handles_download_in_progress():
|
||||
"""Test that check_example_images_needed returns correct status when download is running."""
|
||||
async with registrar_app() as harness:
|
||||
harness.download_manager.check_pending_result = {
|
||||
"success": True,
|
||||
"is_downloading": True,
|
||||
"total_models": 0,
|
||||
"pending_count": 0,
|
||||
"processed_count": 0,
|
||||
"failed_count": 0,
|
||||
"needs_download": False,
|
||||
"message": "Download already in progress",
|
||||
}
|
||||
|
||||
response = await harness.client.post(
|
||||
"/api/lm/check-example-images-needed",
|
||||
json={"model_types": ["lora"]},
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
body = await _json(response)
|
||||
assert body["success"] is True
|
||||
assert body["is_downloading"] is True
|
||||
assert body["needs_download"] is False
|
||||
|
||||
|
||||
async def test_check_example_images_needed_handles_no_pending_models():
|
||||
"""Test that check_example_images_needed returns correct status when no work is needed."""
|
||||
async with registrar_app() as harness:
|
||||
harness.download_manager.check_pending_result = {
|
||||
"success": True,
|
||||
"is_downloading": False,
|
||||
"total_models": 5500,
|
||||
"pending_count": 0,
|
||||
"processed_count": 5500,
|
||||
"failed_count": 0,
|
||||
"needs_download": False,
|
||||
}
|
||||
|
||||
response = await harness.client.post(
|
||||
"/api/lm/check-example-images-needed",
|
||||
json={"model_types": ["lora", "checkpoint", "embedding"]},
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
body = await _json(response)
|
||||
assert body["success"] is True
|
||||
assert body["pending_count"] == 0
|
||||
assert body["needs_download"] is False
|
||||
assert body["processed_count"] == 5500
|
||||
|
||||
|
||||
async def test_check_example_images_needed_uses_default_model_types():
|
||||
"""Test that check_example_images_needed uses default model types when not specified."""
|
||||
async with registrar_app() as harness:
|
||||
response = await harness.client.post(
|
||||
"/api/lm/check-example-images-needed",
|
||||
json={}, # No model_types specified
|
||||
)
|
||||
|
||||
assert response.status == 200
|
||||
# Should use default model types
|
||||
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
|
||||
|
||||
|
||||
async def test_check_example_images_needed_returns_error_on_exception():
|
||||
"""Test that check_example_images_needed returns 500 on internal error."""
|
||||
async with registrar_app() as harness:
|
||||
# Simulate an error by setting result to an error state
|
||||
# Actually, we need to make the method raise an exception
|
||||
original_method = harness.download_manager.check_pending_models
|
||||
|
||||
async def failing_check(_model_types):
|
||||
raise RuntimeError("Database connection failed")
|
||||
|
||||
harness.download_manager.check_pending_models = failing_check
|
||||
|
||||
response = await harness.client.post(
|
||||
"/api/lm/check-example-images-needed",
|
||||
json={"model_types": ["lora"]},
|
||||
)
|
||||
|
||||
assert response.status == 500
|
||||
body = await _json(response)
|
||||
assert body["success"] is False
|
||||
assert "Database connection failed" in body["error"]
|
||||
|
||||
@@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
||||
"resume_example_images",
|
||||
"stop_example_images",
|
||||
"force_download_example_images",
|
||||
"check_example_images_needed",
|
||||
"import_example_images",
|
||||
"delete_example_image",
|
||||
"set_example_image_nsfw_level",
|
||||
|
||||
283
tests/services/test_cache_entry_validator.py
Normal file
283
tests/services/test_cache_entry_validator.py
Normal file
@@ -0,0 +1,283 @@
|
||||
"""
|
||||
Unit tests for CacheEntryValidator
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.cache_entry_validator import (
|
||||
CacheEntryValidator,
|
||||
ValidationResult,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheEntryValidator:
|
||||
"""Tests for CacheEntryValidator class"""
|
||||
|
||||
def test_validate_valid_entry(self):
|
||||
"""Test validation of a valid cache entry"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'abc123def456',
|
||||
'file_name': 'test.safetensors',
|
||||
'model_name': 'Test Model',
|
||||
'size': 1024,
|
||||
'modified': 1234567890.0,
|
||||
'tags': ['tag1', 'tag2'],
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.repaired is False
|
||||
assert len(result.errors) == 0
|
||||
assert result.entry == entry
|
||||
|
||||
def test_validate_missing_required_field_sha256(self):
|
||||
"""Test validation fails when required sha256 field is missing"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
# sha256 missing
|
||||
'file_name': 'test.safetensors',
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('sha256' in error for error in result.errors)
|
||||
|
||||
def test_validate_missing_required_field_file_path(self):
|
||||
"""Test validation fails when required file_path field is missing"""
|
||||
entry = {
|
||||
# file_path missing
|
||||
'sha256': 'abc123def456',
|
||||
'file_name': 'test.safetensors',
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('file_path' in error for error in result.errors)
|
||||
|
||||
def test_validate_empty_required_field_sha256(self):
|
||||
"""Test validation fails when sha256 is empty string"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': '', # Empty string
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('sha256' in error for error in result.errors)
|
||||
|
||||
def test_validate_empty_required_field_file_path(self):
|
||||
"""Test validation fails when file_path is empty string"""
|
||||
entry = {
|
||||
'file_path': '', # Empty string
|
||||
'sha256': 'abc123def456',
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('file_path' in error for error in result.errors)
|
||||
|
||||
def test_validate_none_required_field(self):
|
||||
"""Test validation fails when required field is None"""
|
||||
entry = {
|
||||
'file_path': None,
|
||||
'sha256': 'abc123def456',
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('file_path' in error for error in result.errors)
|
||||
|
||||
def test_validate_none_entry(self):
|
||||
"""Test validation handles None entry"""
|
||||
result = CacheEntryValidator.validate(None, auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('None' in error for error in result.errors)
|
||||
assert result.entry is None
|
||||
|
||||
def test_validate_non_dict_entry(self):
|
||||
"""Test validation handles non-dict entry"""
|
||||
result = CacheEntryValidator.validate("not a dict", auto_repair=False)
|
||||
|
||||
assert result.is_valid is False
|
||||
assert result.repaired is False
|
||||
assert any('not a dict' in error for error in result.errors)
|
||||
assert result.entry is None
|
||||
|
||||
def test_auto_repair_missing_non_required_field(self):
|
||||
"""Test auto-repair adds missing non-required fields"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'abc123def456',
|
||||
# file_name, model_name, tags missing
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=True)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.repaired is True
|
||||
assert result.entry['file_name'] == ''
|
||||
assert result.entry['model_name'] == ''
|
||||
assert result.entry['tags'] == []
|
||||
|
||||
def test_auto_repair_wrong_type_field(self):
|
||||
"""Test auto-repair fixes fields with wrong type"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'abc123def456',
|
||||
'size': 'not a number', # Should be int
|
||||
'tags': 'not a list', # Should be list
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=True)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.repaired is True
|
||||
assert result.entry['size'] == 0 # Default value
|
||||
assert result.entry['tags'] == [] # Default value
|
||||
|
||||
def test_normalize_sha256_lowercase(self):
|
||||
"""Test sha256 is normalized to lowercase"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'ABC123DEF456', # Uppercase
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=True)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.entry['sha256'] == 'abc123def456'
|
||||
|
||||
def test_validate_batch_all_valid(self):
|
||||
"""Test batch validation with all valid entries"""
|
||||
entries = [
|
||||
{
|
||||
'file_path': '/models/test1.safetensors',
|
||||
'sha256': 'abc123',
|
||||
},
|
||||
{
|
||||
'file_path': '/models/test2.safetensors',
|
||||
'sha256': 'def456',
|
||||
},
|
||||
]
|
||||
|
||||
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=False)
|
||||
|
||||
assert len(valid) == 2
|
||||
assert len(invalid) == 0
|
||||
|
||||
def test_validate_batch_mixed_validity(self):
|
||||
"""Test batch validation with mixed valid/invalid entries"""
|
||||
entries = [
|
||||
{
|
||||
'file_path': '/models/test1.safetensors',
|
||||
'sha256': 'abc123',
|
||||
},
|
||||
{
|
||||
'file_path': '/models/test2.safetensors',
|
||||
# sha256 missing - invalid
|
||||
},
|
||||
{
|
||||
'file_path': '/models/test3.safetensors',
|
||||
'sha256': 'def456',
|
||||
},
|
||||
]
|
||||
|
||||
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=False)
|
||||
|
||||
assert len(valid) == 2
|
||||
assert len(invalid) == 1
|
||||
# invalid list contains the actual invalid entries (not by index)
|
||||
assert invalid[0]['file_path'] == '/models/test2.safetensors'
|
||||
|
||||
def test_validate_batch_empty_list(self):
|
||||
"""Test batch validation with empty list"""
|
||||
valid, invalid = CacheEntryValidator.validate_batch([], auto_repair=False)
|
||||
|
||||
assert len(valid) == 0
|
||||
assert len(invalid) == 0
|
||||
|
||||
def test_get_file_path_safe(self):
|
||||
"""Test safe file_path extraction"""
|
||||
entry = {'file_path': '/models/test.safetensors', 'sha256': 'abc123'}
|
||||
assert CacheEntryValidator.get_file_path_safe(entry) == '/models/test.safetensors'
|
||||
|
||||
def test_get_file_path_safe_missing(self):
|
||||
"""Test safe file_path extraction when missing"""
|
||||
entry = {'sha256': 'abc123'}
|
||||
assert CacheEntryValidator.get_file_path_safe(entry) == ''
|
||||
|
||||
def test_get_file_path_safe_not_dict(self):
|
||||
"""Test safe file_path extraction from non-dict"""
|
||||
assert CacheEntryValidator.get_file_path_safe(None) == ''
|
||||
assert CacheEntryValidator.get_file_path_safe('string') == ''
|
||||
|
||||
def test_get_sha256_safe(self):
|
||||
"""Test safe sha256 extraction"""
|
||||
entry = {'file_path': '/models/test.safetensors', 'sha256': 'ABC123'}
|
||||
assert CacheEntryValidator.get_sha256_safe(entry) == 'abc123'
|
||||
|
||||
def test_get_sha256_safe_missing(self):
|
||||
"""Test safe sha256 extraction when missing"""
|
||||
entry = {'file_path': '/models/test.safetensors'}
|
||||
assert CacheEntryValidator.get_sha256_safe(entry) == ''
|
||||
|
||||
def test_get_sha256_safe_not_dict(self):
|
||||
"""Test safe sha256 extraction from non-dict"""
|
||||
assert CacheEntryValidator.get_sha256_safe(None) == ''
|
||||
assert CacheEntryValidator.get_sha256_safe('string') == ''
|
||||
|
||||
def test_validate_with_all_optional_fields(self):
|
||||
"""Test validation with all optional fields present"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'abc123',
|
||||
'file_name': 'test.safetensors',
|
||||
'model_name': 'Test Model',
|
||||
'folder': 'test_folder',
|
||||
'size': 1024,
|
||||
'modified': 1234567890.0,
|
||||
'tags': ['tag1', 'tag2'],
|
||||
'preview_url': 'http://example.com/preview.jpg',
|
||||
'base_model': 'SD1.5',
|
||||
'from_civitai': True,
|
||||
'favorite': True,
|
||||
'exclude': False,
|
||||
'db_checked': True,
|
||||
'preview_nsfw_level': 1,
|
||||
'notes': 'Test notes',
|
||||
'usage_tips': 'Test tips',
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.repaired is False
|
||||
assert result.entry == entry
|
||||
|
||||
def test_validate_numeric_field_accepts_float_for_int(self):
|
||||
"""Test that numeric fields accept float for int type"""
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'abc123',
|
||||
'size': 1024.5, # Float for int field
|
||||
'modified': 1234567890.0,
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.repaired is False
|
||||
364
tests/services/test_cache_health_monitor.py
Normal file
364
tests/services/test_cache_health_monitor.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""
|
||||
Unit tests for CacheHealthMonitor
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.cache_health_monitor import (
|
||||
CacheHealthMonitor,
|
||||
CacheHealthStatus,
|
||||
HealthReport,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheHealthMonitor:
|
||||
"""Tests for CacheHealthMonitor class"""
|
||||
|
||||
def test_check_health_all_valid_entries(self):
|
||||
"""Test health check with 100% valid entries"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.HEALTHY
|
||||
assert report.total_entries == 100
|
||||
assert report.valid_entries == 100
|
||||
assert report.invalid_entries == 0
|
||||
assert report.repaired_entries == 0
|
||||
assert report.corruption_rate == 0.0
|
||||
assert report.message == "Cache is healthy"
|
||||
|
||||
def test_check_health_degraded_cache(self):
|
||||
"""Test health check with 1-5% invalid entries (degraded)"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
# Create 100 entries, 2 invalid (2%)
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(98)
|
||||
]
|
||||
# Add 2 invalid entries
|
||||
entries.append({'file_path': '/models/invalid1.safetensors'}) # Missing sha256
|
||||
entries.append({'file_path': '/models/invalid2.safetensors'}) # Missing sha256
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.DEGRADED
|
||||
assert report.total_entries == 100
|
||||
assert report.valid_entries == 98
|
||||
assert report.invalid_entries == 2
|
||||
assert report.corruption_rate == 0.02
|
||||
# Message describes the issue without necessarily containing the word "degraded"
|
||||
assert 'invalid entries' in report.message.lower()
|
||||
|
||||
def test_check_health_corrupted_cache(self):
|
||||
"""Test health check with >5% invalid entries (corrupted)"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
# Create 100 entries, 10 invalid (10%)
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(90)
|
||||
]
|
||||
# Add 10 invalid entries
|
||||
for i in range(10):
|
||||
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.CORRUPTED
|
||||
assert report.total_entries == 100
|
||||
assert report.valid_entries == 90
|
||||
assert report.invalid_entries == 10
|
||||
assert report.corruption_rate == 0.10
|
||||
assert 'corrupted' in report.message.lower()
|
||||
|
||||
def test_check_health_empty_cache(self):
|
||||
"""Test health check with empty cache"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
report = monitor.check_health([], auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.HEALTHY
|
||||
assert report.total_entries == 0
|
||||
assert report.valid_entries == 0
|
||||
assert report.invalid_entries == 0
|
||||
assert report.corruption_rate == 0.0
|
||||
assert report.message == "Cache is empty"
|
||||
|
||||
def test_check_health_single_invalid_entry(self):
|
||||
"""Test health check with 1 invalid entry out of 1 (100% corruption)"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
entries = [{'file_path': '/models/invalid.safetensors'}]
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.CORRUPTED
|
||||
assert report.total_entries == 1
|
||||
assert report.valid_entries == 0
|
||||
assert report.invalid_entries == 1
|
||||
assert report.corruption_rate == 1.0
|
||||
|
||||
def test_check_health_boundary_degraded_threshold(self):
|
||||
"""Test health check at degraded threshold (1%)"""
|
||||
monitor = CacheHealthMonitor(degraded_threshold=0.01)
|
||||
|
||||
# 100 entries, 1 invalid (exactly 1%)
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(99)
|
||||
]
|
||||
entries.append({'file_path': '/models/invalid.safetensors'})
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.DEGRADED
|
||||
assert report.corruption_rate == 0.01
|
||||
|
||||
def test_check_health_boundary_corrupted_threshold(self):
|
||||
"""Test health check at corrupted threshold (5%)"""
|
||||
monitor = CacheHealthMonitor(corrupted_threshold=0.05)
|
||||
|
||||
# 100 entries, 5 invalid (exactly 5%)
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(95)
|
||||
]
|
||||
for i in range(5):
|
||||
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.CORRUPTED
|
||||
assert report.corruption_rate == 0.05
|
||||
|
||||
def test_check_health_below_degraded_threshold(self):
|
||||
"""Test health check below degraded threshold (0%)"""
|
||||
monitor = CacheHealthMonitor(degraded_threshold=0.01)
|
||||
|
||||
# All entries valid
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(100)
|
||||
]
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert report.status == CacheHealthStatus.HEALTHY
|
||||
assert report.corruption_rate == 0.0
|
||||
|
||||
def test_check_health_auto_repair(self):
|
||||
"""Test health check with auto_repair enabled"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
# 1 entry with all fields (won't be repaired), 1 entry with missing non-required fields (will be repaired)
|
||||
complete_entry = {
|
||||
'file_path': '/models/test1.safetensors',
|
||||
'sha256': 'hash1',
|
||||
'file_name': 'test1.safetensors',
|
||||
'model_name': 'Model 1',
|
||||
'folder': '',
|
||||
'size': 0,
|
||||
'modified': 0.0,
|
||||
'tags': ['tag1'],
|
||||
'preview_url': '',
|
||||
'base_model': '',
|
||||
'from_civitai': True,
|
||||
'favorite': False,
|
||||
'exclude': False,
|
||||
'db_checked': False,
|
||||
'preview_nsfw_level': 0,
|
||||
'notes': '',
|
||||
'usage_tips': '',
|
||||
}
|
||||
incomplete_entry = {
|
||||
'file_path': '/models/test2.safetensors',
|
||||
'sha256': 'hash2',
|
||||
# Missing many optional fields (will be repaired)
|
||||
}
|
||||
|
||||
entries = [complete_entry, incomplete_entry]
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=True)
|
||||
|
||||
assert report.status == CacheHealthStatus.HEALTHY
|
||||
assert report.total_entries == 2
|
||||
assert report.valid_entries == 2
|
||||
assert report.invalid_entries == 0
|
||||
assert report.repaired_entries == 1
|
||||
|
||||
def test_should_notify_user_healthy(self):
|
||||
"""Test should_notify_user for healthy cache"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.HEALTHY,
|
||||
total_entries=100,
|
||||
valid_entries=100,
|
||||
invalid_entries=0,
|
||||
repaired_entries=0,
|
||||
message="Cache is healthy"
|
||||
)
|
||||
|
||||
assert monitor.should_notify_user(report) is False
|
||||
|
||||
def test_should_notify_user_degraded(self):
|
||||
"""Test should_notify_user for degraded cache"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.DEGRADED,
|
||||
total_entries=100,
|
||||
valid_entries=98,
|
||||
invalid_entries=2,
|
||||
repaired_entries=0,
|
||||
message="Cache is degraded"
|
||||
)
|
||||
|
||||
assert monitor.should_notify_user(report) is True
|
||||
|
||||
def test_should_notify_user_corrupted(self):
|
||||
"""Test should_notify_user for corrupted cache"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.CORRUPTED,
|
||||
total_entries=100,
|
||||
valid_entries=90,
|
||||
invalid_entries=10,
|
||||
repaired_entries=0,
|
||||
message="Cache is corrupted"
|
||||
)
|
||||
|
||||
assert monitor.should_notify_user(report) is True
|
||||
|
||||
def test_get_notification_severity_degraded(self):
|
||||
"""Test get_notification_severity for degraded cache"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.DEGRADED,
|
||||
total_entries=100,
|
||||
valid_entries=98,
|
||||
invalid_entries=2,
|
||||
repaired_entries=0,
|
||||
message="Cache is degraded"
|
||||
)
|
||||
|
||||
assert monitor.get_notification_severity(report) == 'warning'
|
||||
|
||||
def test_get_notification_severity_corrupted(self):
|
||||
"""Test get_notification_severity for corrupted cache"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.CORRUPTED,
|
||||
total_entries=100,
|
||||
valid_entries=90,
|
||||
invalid_entries=10,
|
||||
repaired_entries=0,
|
||||
message="Cache is corrupted"
|
||||
)
|
||||
|
||||
assert monitor.get_notification_severity(report) == 'error'
|
||||
|
||||
def test_report_to_dict(self):
|
||||
"""Test HealthReport to_dict conversion"""
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.DEGRADED,
|
||||
total_entries=100,
|
||||
valid_entries=98,
|
||||
invalid_entries=2,
|
||||
repaired_entries=1,
|
||||
invalid_paths=['/path1', '/path2'],
|
||||
message="Cache issues detected"
|
||||
)
|
||||
|
||||
result = report.to_dict()
|
||||
|
||||
assert result['status'] == 'degraded'
|
||||
assert result['total_entries'] == 100
|
||||
assert result['valid_entries'] == 98
|
||||
assert result['invalid_entries'] == 2
|
||||
assert result['repaired_entries'] == 1
|
||||
assert result['corruption_rate'] == '2.0%'
|
||||
assert len(result['invalid_paths']) == 2
|
||||
assert result['message'] == "Cache issues detected"
|
||||
|
||||
def test_report_corruption_rate_zero_division(self):
|
||||
"""Test corruption_rate calculation with zero entries"""
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.HEALTHY,
|
||||
total_entries=0,
|
||||
valid_entries=0,
|
||||
invalid_entries=0,
|
||||
repaired_entries=0,
|
||||
message="Cache is empty"
|
||||
)
|
||||
|
||||
assert report.corruption_rate == 0.0
|
||||
|
||||
def test_check_health_collects_invalid_paths(self):
|
||||
"""Test health check collects invalid entry paths"""
|
||||
monitor = CacheHealthMonitor()
|
||||
|
||||
entries = [
|
||||
{
|
||||
'file_path': '/models/valid.safetensors',
|
||||
'sha256': 'hash1',
|
||||
},
|
||||
{
|
||||
'file_path': '/models/invalid1.safetensors',
|
||||
},
|
||||
{
|
||||
'file_path': '/models/invalid2.safetensors',
|
||||
},
|
||||
]
|
||||
|
||||
report = monitor.check_health(entries, auto_repair=False)
|
||||
|
||||
assert len(report.invalid_paths) == 2
|
||||
assert '/models/invalid1.safetensors' in report.invalid_paths
|
||||
assert '/models/invalid2.safetensors' in report.invalid_paths
|
||||
|
||||
def test_report_to_dict_limits_invalid_paths(self):
|
||||
"""Test that to_dict limits invalid_paths to first 10"""
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.CORRUPTED,
|
||||
total_entries=15,
|
||||
valid_entries=0,
|
||||
invalid_entries=15,
|
||||
repaired_entries=0,
|
||||
invalid_paths=[f'/path{i}' for i in range(15)],
|
||||
message="Cache corrupted"
|
||||
)
|
||||
|
||||
result = report.to_dict()
|
||||
|
||||
assert len(result['invalid_paths']) == 10
|
||||
assert result['invalid_paths'][0] == '/path0'
|
||||
assert result['invalid_paths'][-1] == '/path9'
|
||||
368
tests/services/test_check_pending_models.py
Normal file
368
tests/services/test_check_pending_models.py
Normal file
@@ -0,0 +1,368 @@
|
||||
"""Tests for the check_pending_models lightweight pre-check functionality."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from py.services.settings_manager import get_settings_manager
|
||||
from py.utils import example_images_download_manager as download_module
|
||||
|
||||
|
||||
class StubScanner:
|
||||
"""Scanner double returning predetermined cache contents."""
|
||||
|
||||
def __init__(self, models: list[dict]) -> None:
|
||||
self._cache = SimpleNamespace(raw_data=models)
|
||||
|
||||
async def get_cached_data(self):
|
||||
return self._cache
|
||||
|
||||
|
||||
def _patch_scanners(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
lora_scanner: StubScanner | None = None,
|
||||
checkpoint_scanner: StubScanner | None = None,
|
||||
embedding_scanner: StubScanner | None = None,
|
||||
) -> None:
|
||||
"""Patch ServiceRegistry to return stub scanners."""
|
||||
|
||||
async def _get_lora_scanner(cls):
|
||||
return lora_scanner or StubScanner([])
|
||||
|
||||
async def _get_checkpoint_scanner(cls):
|
||||
return checkpoint_scanner or StubScanner([])
|
||||
|
||||
async def _get_embedding_scanner(cls):
|
||||
return embedding_scanner or StubScanner([])
|
||||
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_lora_scanner",
|
||||
classmethod(_get_lora_scanner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_checkpoint_scanner",
|
||||
classmethod(_get_checkpoint_scanner),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
download_module.ServiceRegistry,
|
||||
"get_embedding_scanner",
|
||||
classmethod(_get_embedding_scanner),
|
||||
)
|
||||
|
||||
|
||||
class RecordingWebSocketManager:
|
||||
"""Collects broadcast payloads for assertions."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.payloads: list[dict] = []
|
||||
|
||||
async def broadcast(self, payload: dict) -> None:
|
||||
self.payloads.append(payload)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_returns_zero_when_all_processed(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models returns 0 pending when all models are processed."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Create processed models
|
||||
processed_hashes = ["a" * 64, "b" * 64, "c" * 64]
|
||||
models = [
|
||||
{"sha256": h, "model_name": f"Model {i}"}
|
||||
for i, h in enumerate(processed_hashes)
|
||||
]
|
||||
|
||||
# Create progress file with all models processed
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text(
|
||||
json.dumps({"processed_models": processed_hashes, "failed_models": []}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Create model directories with files (simulating completed downloads)
|
||||
for h in processed_hashes:
|
||||
model_dir = tmp_path / h
|
||||
model_dir.mkdir()
|
||||
(model_dir / "image_0.png").write_text("data")
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["is_downloading"] is False
|
||||
assert result["total_models"] == 3
|
||||
assert result["pending_count"] == 0
|
||||
assert result["processed_count"] == 3
|
||||
assert result["needs_download"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_finds_unprocessed_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models correctly identifies unprocessed models."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Create models - some processed, some not
|
||||
processed_hash = "a" * 64
|
||||
unprocessed_hash = "b" * 64
|
||||
models = [
|
||||
{"sha256": processed_hash, "model_name": "Processed Model"},
|
||||
{"sha256": unprocessed_hash, "model_name": "Unprocessed Model"},
|
||||
]
|
||||
|
||||
# Create progress file with only one model processed
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text(
|
||||
json.dumps({"processed_models": [processed_hash], "failed_models": []}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# Create directory only for processed model
|
||||
processed_dir = tmp_path / processed_hash
|
||||
processed_dir.mkdir()
|
||||
(processed_dir / "image_0.png").write_text("data")
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 2
|
||||
assert result["pending_count"] == 1
|
||||
assert result["processed_count"] == 1
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_skips_models_without_hash(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that models without sha256 are not counted as pending."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Models - one with hash, one without
|
||||
models = [
|
||||
{"sha256": "a" * 64, "model_name": "Hashed Model"},
|
||||
{"sha256": None, "model_name": "No Hash Model"},
|
||||
{"model_name": "Missing Hash Model"}, # No sha256 key at all
|
||||
]
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 3
|
||||
assert result["pending_count"] == 1 # Only the one with hash
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_multiple_model_types(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models aggregates counts across multiple model types."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
lora_models = [
|
||||
{"sha256": "a" * 64, "model_name": "Lora 1"},
|
||||
{"sha256": "b" * 64, "model_name": "Lora 2"},
|
||||
]
|
||||
checkpoint_models = [
|
||||
{"sha256": "c" * 64, "model_name": "Checkpoint 1"},
|
||||
]
|
||||
embedding_models = [
|
||||
{"sha256": "d" * 64, "model_name": "Embedding 1"},
|
||||
{"sha256": "e" * 64, "model_name": "Embedding 2"},
|
||||
{"sha256": "f" * 64, "model_name": "Embedding 3"},
|
||||
]
|
||||
|
||||
_patch_scanners(
|
||||
monkeypatch,
|
||||
lora_scanner=StubScanner(lora_models),
|
||||
checkpoint_scanner=StubScanner(checkpoint_models),
|
||||
embedding_scanner=StubScanner(embedding_models),
|
||||
)
|
||||
|
||||
result = await manager.check_pending_models(["lora", "checkpoint", "embedding"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 6 # 2 + 1 + 3
|
||||
assert result["pending_count"] == 6 # All unprocessed
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_returns_error_when_download_in_progress(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models returns special response when download is running."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
# Simulate download in progress
|
||||
manager._is_downloading = True
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["is_downloading"] is True
|
||||
assert result["needs_download"] is False
|
||||
assert result["pending_count"] == 0
|
||||
assert "already in progress" in result["message"].lower()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_empty_library(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models handles empty model library."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner([]))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 0
|
||||
assert result["pending_count"] == 0
|
||||
assert result["processed_count"] == 0
|
||||
assert result["needs_download"] is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_reads_failed_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models correctly reports failed model count."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||
|
||||
# Create progress file with failed models
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text(
|
||||
json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}),
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["failed_count"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_missing_progress_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models works correctly when no progress file exists."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [
|
||||
{"sha256": "a" * 64, "model_name": "Model 1"},
|
||||
{"sha256": "b" * 64, "model_name": "Model 2"},
|
||||
]
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
# No progress file created
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 2
|
||||
assert result["pending_count"] == 2 # All pending since no progress
|
||||
assert result["processed_count"] == 0
|
||||
assert result["failed_count"] == 0
|
||||
assert result["needs_download"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.usefixtures("tmp_path")
|
||||
async def test_check_pending_models_handles_corrupted_progress_file(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
tmp_path,
|
||||
settings_manager,
|
||||
):
|
||||
"""Test that check_pending_models handles corrupted progress file gracefully."""
|
||||
ws_manager = RecordingWebSocketManager()
|
||||
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||
|
||||
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||
|
||||
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||
|
||||
# Create corrupted progress file
|
||||
progress_file = tmp_path / ".download_progress.json"
|
||||
progress_file.write_text("not valid json", encoding="utf-8")
|
||||
|
||||
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||
|
||||
result = await manager.check_pending_models(["lora"])
|
||||
|
||||
# Should still succeed, treating all as unprocessed
|
||||
assert result["success"] is True
|
||||
assert result["total_models"] == 1
|
||||
assert result["pending_count"] == 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def settings_manager():
|
||||
return get_settings_manager()
|
||||
167
tests/services/test_model_scanner_cache_validation.py
Normal file
167
tests/services/test_model_scanner_cache_validation.py
Normal file
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Integration tests for cache validation in ModelScanner
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
|
||||
from py.services.model_scanner import ModelScanner
|
||||
from py.services.cache_entry_validator import CacheEntryValidator
|
||||
from py.services.cache_health_monitor import CacheHealthMonitor, CacheHealthStatus
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_scanner_validates_cache_entries(tmp_path_factory):
|
||||
"""Test that ModelScanner validates cache entries during initialization"""
|
||||
# Create temporary test data
|
||||
tmp_dir = tmp_path_factory.mktemp("test_loras")
|
||||
|
||||
# Create test files
|
||||
test_file = tmp_dir / "test_model.safetensors"
|
||||
test_file.write_bytes(b"fake model data" * 100)
|
||||
|
||||
# Mock model scanner (we can't easily instantiate a full scanner in tests)
|
||||
# Instead, test the validation logic directly
|
||||
entries = [
|
||||
{
|
||||
'file_path': str(test_file),
|
||||
'sha256': 'abc123def456',
|
||||
'file_name': 'test_model.safetensors',
|
||||
},
|
||||
{
|
||||
'file_path': str(tmp_dir / 'invalid.safetensors'),
|
||||
# Missing sha256 - invalid
|
||||
},
|
||||
]
|
||||
|
||||
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=True)
|
||||
|
||||
assert len(valid) == 1
|
||||
assert len(invalid) == 1
|
||||
assert valid[0]['sha256'] == 'abc123def456'
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_scanner_detects_degraded_cache():
|
||||
"""Test that ModelScanner detects degraded cache health"""
|
||||
# Create 100 entries with 2% corruption
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(98)
|
||||
]
|
||||
# Add 2 invalid entries
|
||||
entries.append({'file_path': '/models/invalid1.safetensors'})
|
||||
entries.append({'file_path': '/models/invalid2.safetensors'})
|
||||
|
||||
monitor = CacheHealthMonitor()
|
||||
report = monitor.check_health(entries, auto_repair=True)
|
||||
|
||||
assert report.status == CacheHealthStatus.DEGRADED
|
||||
assert report.invalid_entries == 2
|
||||
assert report.valid_entries == 98
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_scanner_detects_corrupted_cache():
|
||||
"""Test that ModelScanner detects corrupted cache health"""
|
||||
# Create 100 entries with 10% corruption
|
||||
entries = [
|
||||
{
|
||||
'file_path': f'/models/test{i}.safetensors',
|
||||
'sha256': f'hash{i}',
|
||||
}
|
||||
for i in range(90)
|
||||
]
|
||||
# Add 10 invalid entries
|
||||
for i in range(10):
|
||||
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
|
||||
|
||||
monitor = CacheHealthMonitor()
|
||||
report = monitor.check_health(entries, auto_repair=True)
|
||||
|
||||
assert report.status == CacheHealthStatus.CORRUPTED
|
||||
assert report.invalid_entries == 10
|
||||
assert report.valid_entries == 90
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_model_scanner_removes_invalid_from_hash_index():
|
||||
"""Test that ModelScanner removes invalid entries from hash index"""
|
||||
from py.services.model_hash_index import ModelHashIndex
|
||||
|
||||
# Create a hash index with some entries
|
||||
hash_index = ModelHashIndex()
|
||||
valid_entry = {
|
||||
'file_path': '/models/valid.safetensors',
|
||||
'sha256': 'abc123',
|
||||
}
|
||||
invalid_entry = {
|
||||
'file_path': '/models/invalid.safetensors',
|
||||
'sha256': '', # Empty sha256
|
||||
}
|
||||
|
||||
# Add entries to hash index
|
||||
hash_index.add_entry(valid_entry['sha256'], valid_entry['file_path'])
|
||||
hash_index.add_entry(invalid_entry['sha256'], invalid_entry['file_path'])
|
||||
|
||||
# Verify both entries are in the index (using get_hash method)
|
||||
assert hash_index.get_hash(valid_entry['file_path']) == valid_entry['sha256']
|
||||
# Invalid entry won't be added due to empty sha256
|
||||
assert hash_index.get_hash(invalid_entry['file_path']) is None
|
||||
|
||||
# Simulate removing invalid entry (it's not actually there, but let's test the method)
|
||||
hash_index.remove_by_path(
|
||||
CacheEntryValidator.get_file_path_safe(invalid_entry),
|
||||
CacheEntryValidator.get_sha256_safe(invalid_entry)
|
||||
)
|
||||
|
||||
# Verify valid entry remains
|
||||
assert hash_index.get_hash(valid_entry['file_path']) == valid_entry['sha256']
|
||||
|
||||
|
||||
def test_cache_entry_validator_handles_various_field_types():
|
||||
"""Test that validator handles various field types correctly"""
|
||||
# Test with different field types
|
||||
entry = {
|
||||
'file_path': '/models/test.safetensors',
|
||||
'sha256': 'abc123',
|
||||
'size': 1024, # int
|
||||
'modified': 1234567890.0, # float
|
||||
'favorite': True, # bool
|
||||
'tags': ['tag1', 'tag2'], # list
|
||||
'exclude': False, # bool
|
||||
}
|
||||
|
||||
result = CacheEntryValidator.validate(entry, auto_repair=False)
|
||||
|
||||
assert result.is_valid is True
|
||||
assert result.repaired is False
|
||||
|
||||
|
||||
def test_cache_health_report_serialization():
|
||||
"""Test that HealthReport can be serialized to dict"""
|
||||
from py.services.cache_health_monitor import HealthReport
|
||||
|
||||
report = HealthReport(
|
||||
status=CacheHealthStatus.DEGRADED,
|
||||
total_entries=100,
|
||||
valid_entries=98,
|
||||
invalid_entries=2,
|
||||
repaired_entries=1,
|
||||
invalid_paths=['/path1', '/path2'],
|
||||
message="Cache issues detected"
|
||||
)
|
||||
|
||||
result = report.to_dict()
|
||||
|
||||
assert result['status'] == 'degraded'
|
||||
assert result['total_entries'] == 100
|
||||
assert result['valid_entries'] == 98
|
||||
assert result['invalid_entries'] == 2
|
||||
assert result['repaired_entries'] == 1
|
||||
assert result['corruption_rate'] == '2.0%'
|
||||
assert len(result['invalid_paths']) == 2
|
||||
assert result['message'] == "Cache issues detected"
|
||||
@@ -242,6 +242,148 @@ async def test_bulk_metadata_refresh_reports_errors() -> None:
|
||||
assert progress.events[-1]["error"] == "boom"
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_skips_confirmed_not_found_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Models marked as from_civitai=False and civitai_deleted=True should be skipped."""
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash1",
|
||||
"from_civitai": False,
|
||||
"civitai_deleted": True,
|
||||
"model_name": "NotOnCivitAI",
|
||||
},
|
||||
{
|
||||
"file_path": "model2.safetensors",
|
||||
"sha256": "hash2",
|
||||
"from_civitai": True,
|
||||
"model_name": "OnCivitAI",
|
||||
},
|
||||
]
|
||||
service = MockModelService(scanner)
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings(enable_metadata_archive_db=False)
|
||||
progress = ProgressCollector()
|
||||
|
||||
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Preserve the original data (simulating no metadata file on disk)
|
||||
return model_data
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert result["success"] is True
|
||||
# Only model2 should be processed (model1 is skipped)
|
||||
assert result["processed"] == 1
|
||||
assert result["updated"] == 1
|
||||
assert len(metadata_sync.calls) == 1
|
||||
assert metadata_sync.calls[0]["file_path"] == "model2.safetensors"
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_skips_when_archive_checked(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Models with db_checked=True should be skipped even if archive DB is enabled."""
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash1",
|
||||
"from_civitai": False,
|
||||
"civitai_deleted": True,
|
||||
"db_checked": True,
|
||||
"model_name": "ArchiveChecked",
|
||||
},
|
||||
{
|
||||
"file_path": "model2.safetensors",
|
||||
"sha256": "hash2",
|
||||
"from_civitai": False,
|
||||
"civitai_deleted": True,
|
||||
"db_checked": False,
|
||||
"model_name": "ArchiveNotChecked",
|
||||
},
|
||||
]
|
||||
service = MockModelService(scanner)
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings(enable_metadata_archive_db=True)
|
||||
progress = ProgressCollector()
|
||||
|
||||
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return model_data
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert result["success"] is True
|
||||
# Only model2 should be processed (model1 has db_checked=True)
|
||||
assert result["processed"] == 1
|
||||
assert result["updated"] == 1
|
||||
assert len(metadata_sync.calls) == 1
|
||||
assert metadata_sync.calls[0]["file_path"] == "model2.safetensors"
|
||||
|
||||
|
||||
async def test_bulk_metadata_refresh_processes_never_fetched_models(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
"""Models that have never been fetched (from_civitai=None) should be processed."""
|
||||
scanner = MockScanner()
|
||||
scanner._cache.raw_data = [
|
||||
{
|
||||
"file_path": "model1.safetensors",
|
||||
"sha256": "hash1",
|
||||
"from_civitai": None,
|
||||
"model_name": "NeverFetched",
|
||||
},
|
||||
{
|
||||
"file_path": "model2.safetensors",
|
||||
"sha256": "hash2",
|
||||
"model_name": "NoFromCivitaiField",
|
||||
},
|
||||
]
|
||||
service = MockModelService(scanner)
|
||||
metadata_sync = StubMetadataSync()
|
||||
settings = StubSettings(enable_metadata_archive_db=False)
|
||||
progress = ProgressCollector()
|
||||
|
||||
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
|
||||
return model_data
|
||||
|
||||
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
|
||||
|
||||
use_case = BulkMetadataRefreshUseCase(
|
||||
service=service,
|
||||
metadata_sync=metadata_sync,
|
||||
settings_service=settings,
|
||||
logger=logging.getLogger("test"),
|
||||
)
|
||||
|
||||
result = await use_case.execute_with_error_handling(progress_callback=progress)
|
||||
|
||||
assert result["success"] is True
|
||||
# Both models should be processed
|
||||
assert result["processed"] == 2
|
||||
assert result["updated"] == 2
|
||||
assert len(metadata_sync.calls) == 2
|
||||
|
||||
|
||||
async def test_download_model_use_case_raises_validation_error() -> None:
|
||||
coordinator = StubDownloadCoordinator(error="validation")
|
||||
use_case = DownloadModelUseCase(download_coordinator=coordinator)
|
||||
|
||||
@@ -75,6 +75,31 @@ def test_get_file_extension_defaults_to_jpg() -> None:
|
||||
assert ext == ".jpg"
|
||||
|
||||
|
||||
def test_get_file_extension_from_media_type_hint_video() -> None:
|
||||
"""Test that media_type_hint='video' returns .mp4 when other methods fail"""
|
||||
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||
b"", {}, "https://c.genur.art/536be3c9-e506-4365-b078-bfbc5df9ceec", "video"
|
||||
)
|
||||
assert ext == ".mp4"
|
||||
|
||||
|
||||
def test_get_file_extension_from_media_type_hint_image() -> None:
|
||||
"""Test that media_type_hint='image' falls back to .jpg"""
|
||||
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||
b"", {}, "https://example.com/no-extension", "image"
|
||||
)
|
||||
assert ext == ".jpg"
|
||||
|
||||
|
||||
def test_get_file_extension_media_type_hint_low_priority() -> None:
|
||||
"""Test that media_type_hint is only used as last resort (after URL extension)"""
|
||||
# URL has extension, should use that instead of media_type_hint
|
||||
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
|
||||
b"", {}, "https://example.com/video.mp4", "image"
|
||||
)
|
||||
assert ext == ".mp4"
|
||||
|
||||
|
||||
class StubScanner:
|
||||
def __init__(self, models: list[Dict[str, Any]]) -> None:
|
||||
self._cache = SimpleNamespace(raw_data=models)
|
||||
|
||||
@@ -6,7 +6,8 @@ export default defineConfig({
|
||||
globals: true,
|
||||
setupFiles: ['tests/frontend/setup.js'],
|
||||
include: [
|
||||
'tests/frontend/**/*.test.js'
|
||||
'tests/frontend/**/*.test.js',
|
||||
'tests/frontend/**/*.test.ts'
|
||||
],
|
||||
coverage: {
|
||||
enabled: process.env.VITEST_COVERAGE === 'true',
|
||||
|
||||
1865
vue-widgets/package-lock.json
generated
1865
vue-widgets/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -12,9 +12,13 @@
|
||||
"@comfyorg/comfyui-frontend-types": "^1.35.4",
|
||||
"@types/node": "^22.10.1",
|
||||
"@vitejs/plugin-vue": "^5.2.3",
|
||||
"@vitest/coverage-v8": "^3.2.4",
|
||||
"@vue/test-utils": "^2.4.6",
|
||||
"jsdom": "^26.0.0",
|
||||
"typescript": "^5.7.2",
|
||||
"vite": "^6.3.5",
|
||||
"vite-plugin-css-injected-by-js": "^3.5.2",
|
||||
"vitest": "^3.0.0",
|
||||
"vue-tsc": "^2.1.10"
|
||||
},
|
||||
"scripts": {
|
||||
@@ -24,6 +28,9 @@
|
||||
"typecheck": "vue-tsc --noEmit",
|
||||
"clean": "rm -rf ../web/comfyui/vue-widgets",
|
||||
"rebuild": "npm run clean && npm run build",
|
||||
"prepare": "npm run build"
|
||||
"prepare": "npm run build",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest",
|
||||
"test:coverage": "vitest run --coverage"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,11 +10,28 @@
|
||||
:use-custom-clip-range="state.useCustomClipRange.value"
|
||||
:is-clip-strength-disabled="state.isClipStrengthDisabled.value"
|
||||
:is-loading="state.isLoading.value"
|
||||
:repeat-count="state.repeatCount.value"
|
||||
:repeat-used="state.displayRepeatUsed.value"
|
||||
:is-paused="state.isPaused.value"
|
||||
:is-pause-disabled="hasQueuedPrompts"
|
||||
:is-workflow-executing="state.isWorkflowExecuting.value"
|
||||
:executing-repeat-step="state.executingRepeatStep.value"
|
||||
@update:current-index="handleIndexUpdate"
|
||||
@update:model-strength="state.modelStrength.value = $event"
|
||||
@update:clip-strength="state.clipStrength.value = $event"
|
||||
@update:use-custom-clip-range="handleUseCustomClipRangeChange"
|
||||
@refresh="handleRefresh"
|
||||
@update:repeat-count="handleRepeatCountChange"
|
||||
@toggle-pause="handleTogglePause"
|
||||
@reset-index="handleResetIndex"
|
||||
@open-lora-selector="isModalOpen = true"
|
||||
/>
|
||||
|
||||
<LoraListModal
|
||||
:visible="isModalOpen"
|
||||
:lora-list="cachedLoraList"
|
||||
:current-index="state.currentIndex.value"
|
||||
@close="isModalOpen = false"
|
||||
@select="handleModalSelect"
|
||||
/>
|
||||
</div>
|
||||
</template>
|
||||
@@ -22,8 +39,9 @@
|
||||
<script setup lang="ts">
|
||||
import { onMounted, ref } from 'vue'
|
||||
import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue'
|
||||
import LoraListModal from './lora-cycler/LoraListModal.vue'
|
||||
import { useLoraCyclerState } from '../composables/useLoraCyclerState'
|
||||
import type { ComponentWidget, CyclerConfig, LoraPoolConfig } from '../composables/types'
|
||||
import type { ComponentWidget, CyclerConfig, LoraPoolConfig, LoraItem } from '../composables/types'
|
||||
|
||||
type CyclerWidget = ComponentWidget<CyclerConfig>
|
||||
|
||||
@@ -31,6 +49,7 @@ type CyclerWidget = ComponentWidget<CyclerConfig>
|
||||
const props = defineProps<{
|
||||
widget: CyclerWidget
|
||||
node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any }
|
||||
api?: any // ComfyUI API for execution events
|
||||
}>()
|
||||
|
||||
// State management
|
||||
@@ -39,12 +58,50 @@ const state = useLoraCyclerState(props.widget)
|
||||
// Symbol to track if the widget has been executed at least once
|
||||
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
|
||||
|
||||
// Execution context queue for batch queue synchronization
|
||||
// In batch queue mode, all beforeQueued calls happen BEFORE any onExecuted calls,
|
||||
// so we need to snapshot the state at queue time and replay it during execution
|
||||
interface ExecutionContext {
|
||||
isPaused: boolean
|
||||
repeatUsed: number
|
||||
repeatCount: number
|
||||
shouldAdvanceDisplay: boolean
|
||||
displayRepeatUsed: number // Value to show in UI after completion
|
||||
}
|
||||
const executionQueue: ExecutionContext[] = []
|
||||
|
||||
// Reactive flag to track if there are queued prompts (for disabling pause button)
|
||||
const hasQueuedPrompts = ref(false)
|
||||
|
||||
// Track pending executions for batch queue support (deferred UI updates)
|
||||
// Uses FIFO order since executions are processed in the order they were queued
|
||||
interface PendingExecution {
|
||||
repeatUsed: number
|
||||
repeatCount: number
|
||||
shouldAdvanceDisplay: boolean
|
||||
displayRepeatUsed: number // Value to show in UI after completion
|
||||
output?: {
|
||||
nextIndex: number
|
||||
nextLoraName: string
|
||||
nextLoraFilename: string
|
||||
currentLoraName: string
|
||||
currentLoraFilename: string
|
||||
}
|
||||
}
|
||||
const pendingExecutions: PendingExecution[] = []
|
||||
|
||||
// Track last known pool config hash
|
||||
const lastPoolConfigHash = ref('')
|
||||
|
||||
// Track if component is mounted
|
||||
const isMounted = ref(false)
|
||||
|
||||
// Modal state
|
||||
const isModalOpen = ref(false)
|
||||
|
||||
// Cache for LoRA list (used by modal)
|
||||
const cachedLoraList = ref<LoraItem[]>([])
|
||||
|
||||
// Get pool config from connected node
|
||||
const getPoolConfig = (): LoraPoolConfig | null => {
|
||||
// Check if getPoolConfig method exists on node (added by main.ts)
|
||||
@@ -54,27 +111,47 @@ const getPoolConfig = (): LoraPoolConfig | null => {
|
||||
return null
|
||||
}
|
||||
|
||||
// Update display from LoRA list and index
|
||||
const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
|
||||
if (loraList.length > 0 && index > 0 && index <= loraList.length) {
|
||||
const currentLora = loraList[index - 1]
|
||||
if (currentLora) {
|
||||
state.currentLoraName.value = currentLora.file_name
|
||||
state.currentLoraFilename.value = currentLora.file_name
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handle index update from user
|
||||
const handleIndexUpdate = async (newIndex: number) => {
|
||||
// Reset execution state when user manually changes index
|
||||
// This ensures the next execution starts from the user-set index
|
||||
;(props.widget as any)[HAS_EXECUTED] = false
|
||||
state.executionIndex.value = null
|
||||
state.nextIndex.value = null
|
||||
|
||||
// Clear execution queue since user is manually changing state
|
||||
executionQueue.length = 0
|
||||
hasQueuedPrompts.value = false
|
||||
|
||||
state.setIndex(newIndex)
|
||||
|
||||
// Refresh list to update current LoRA display
|
||||
try {
|
||||
const poolConfig = getPoolConfig()
|
||||
const loraList = await state.fetchCyclerList(poolConfig)
|
||||
|
||||
if (loraList.length > 0 && newIndex > 0 && newIndex <= loraList.length) {
|
||||
const currentLora = loraList[newIndex - 1]
|
||||
if (currentLora) {
|
||||
state.currentLoraName.value = currentLora.file_name
|
||||
state.currentLoraFilename.value = currentLora.file_name
|
||||
}
|
||||
}
|
||||
cachedLoraList.value = loraList
|
||||
updateDisplayFromLoraList(loraList, newIndex)
|
||||
} catch (error) {
|
||||
console.error('[LoraCyclerWidget] Error updating index:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle LoRA selection from modal
|
||||
const handleModalSelect = (index: number) => {
|
||||
handleIndexUpdate(index)
|
||||
}
|
||||
|
||||
// Handle use custom clip range toggle
|
||||
const handleUseCustomClipRangeChange = (newValue: boolean) => {
|
||||
state.useCustomClipRange.value = newValue
|
||||
@@ -84,13 +161,41 @@ const handleUseCustomClipRangeChange = (newValue: boolean) => {
|
||||
}
|
||||
}
|
||||
|
||||
// Handle refresh button click
|
||||
const handleRefresh = async () => {
|
||||
// Handle repeat count change
|
||||
const handleRepeatCountChange = (newValue: number) => {
|
||||
state.repeatCount.value = newValue
|
||||
// Reset repeatUsed when changing repeat count
|
||||
state.repeatUsed.value = 0
|
||||
state.displayRepeatUsed.value = 0
|
||||
}
|
||||
|
||||
// Handle pause toggle
|
||||
const handleTogglePause = () => {
|
||||
state.togglePause()
|
||||
}
|
||||
|
||||
// Handle reset index
|
||||
const handleResetIndex = async () => {
|
||||
// Reset execution state
|
||||
;(props.widget as any)[HAS_EXECUTED] = false
|
||||
state.executionIndex.value = null
|
||||
state.nextIndex.value = null
|
||||
|
||||
// Clear execution queue since user is resetting state
|
||||
executionQueue.length = 0
|
||||
hasQueuedPrompts.value = false
|
||||
|
||||
// Reset index and repeat state
|
||||
state.resetIndex()
|
||||
|
||||
// Refresh list to update current LoRA display
|
||||
try {
|
||||
const poolConfig = getPoolConfig()
|
||||
await state.refreshList(poolConfig)
|
||||
const loraList = await state.fetchCyclerList(poolConfig)
|
||||
cachedLoraList.value = loraList
|
||||
updateDisplayFromLoraList(loraList, 1)
|
||||
} catch (error) {
|
||||
console.error('[LoraCyclerWidget] Error refreshing:', error)
|
||||
console.error('[LoraCyclerWidget] Error resetting index:', error)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -106,6 +211,9 @@ const checkPoolConfigChanges = async () => {
|
||||
lastPoolConfigHash.value = newHash
|
||||
try {
|
||||
await state.refreshList(poolConfig)
|
||||
// Update cached list when pool config changes
|
||||
const loraList = await state.fetchCyclerList(poolConfig)
|
||||
cachedLoraList.value = loraList
|
||||
} catch (error) {
|
||||
console.error('[LoraCyclerWidget] Error on pool config change:', error)
|
||||
}
|
||||
@@ -129,17 +237,68 @@ onMounted(async () => {
|
||||
|
||||
// Add beforeQueued hook to handle index shifting for batch queue synchronization
|
||||
// This ensures each execution uses a different LoRA in the cycle
|
||||
// Now with support for repeat count and pause features
|
||||
//
|
||||
// IMPORTANT: In batch queue mode, ALL beforeQueued calls happen BEFORE any execution.
|
||||
// We push an "execution context" snapshot to a queue so that onExecuted can use the
|
||||
// correct state values that were captured at queue time (not the live state).
|
||||
;(props.widget as any).beforeQueued = () => {
|
||||
if (state.isPaused.value) {
|
||||
// When paused: use current index, don't advance, don't count toward repeat limit
|
||||
// Push context indicating this execution should NOT advance display
|
||||
executionQueue.push({
|
||||
isPaused: true,
|
||||
repeatUsed: state.repeatUsed.value,
|
||||
repeatCount: state.repeatCount.value,
|
||||
shouldAdvanceDisplay: false,
|
||||
displayRepeatUsed: state.displayRepeatUsed.value // Keep current display value when paused
|
||||
})
|
||||
hasQueuedPrompts.value = true
|
||||
// CRITICAL: Clear execution_index when paused to force backend to use current_index
|
||||
// This ensures paused executions use the same LoRA regardless of any
|
||||
// execution_index set by previous non-paused beforeQueued calls
|
||||
const pausedConfig = state.buildConfig()
|
||||
pausedConfig.execution_index = null
|
||||
props.widget.value = pausedConfig
|
||||
return
|
||||
}
|
||||
|
||||
if ((props.widget as any)[HAS_EXECUTED]) {
|
||||
// After first execution: shift indices (previous next_index becomes execution_index)
|
||||
state.generateNextIndex()
|
||||
// After first execution: check repeat logic
|
||||
if (state.repeatUsed.value < state.repeatCount.value) {
|
||||
// Still repeating: increment repeatUsed, use same index
|
||||
state.repeatUsed.value++
|
||||
} else {
|
||||
// Repeat complete: reset repeatUsed to 1, advance to next index
|
||||
state.repeatUsed.value = 1
|
||||
state.generateNextIndex()
|
||||
}
|
||||
} else {
|
||||
// First execution: just initialize next_index (execution_index stays null)
|
||||
// This means first execution uses current_index from widget
|
||||
// First execution: initialize
|
||||
state.repeatUsed.value = 1
|
||||
state.initializeNextIndex()
|
||||
;(props.widget as any)[HAS_EXECUTED] = true
|
||||
}
|
||||
|
||||
// Determine if this execution should advance the display
|
||||
// (only when repeat cycle is complete for this queued item)
|
||||
const shouldAdvanceDisplay = state.repeatUsed.value >= state.repeatCount.value
|
||||
|
||||
// Calculate the display value to show after this execution completes
|
||||
// When advancing to a new LoRA: reset to 0 (fresh start for new LoRA)
|
||||
// When repeating same LoRA: show current repeat step
|
||||
const displayRepeatUsed = shouldAdvanceDisplay ? 0 : state.repeatUsed.value
|
||||
|
||||
// Push execution context snapshot to queue
|
||||
executionQueue.push({
|
||||
isPaused: false,
|
||||
repeatUsed: state.repeatUsed.value,
|
||||
repeatCount: state.repeatCount.value,
|
||||
shouldAdvanceDisplay,
|
||||
displayRepeatUsed
|
||||
})
|
||||
hasQueuedPrompts.value = true
|
||||
|
||||
// Update the widget value so the indices are included in the serialized config
|
||||
props.widget.value = state.buildConfig()
|
||||
}
|
||||
@@ -152,40 +311,71 @@ onMounted(async () => {
|
||||
const poolConfig = getPoolConfig()
|
||||
lastPoolConfigHash.value = state.hashPoolConfig(poolConfig)
|
||||
await state.refreshList(poolConfig)
|
||||
// Cache the initial LoRA list for modal
|
||||
const loraList = await state.fetchCyclerList(poolConfig)
|
||||
cachedLoraList.value = loraList
|
||||
} catch (error) {
|
||||
console.error('[LoraCyclerWidget] Error on initial load:', error)
|
||||
}
|
||||
|
||||
// Override onExecuted to handle backend UI updates
|
||||
// This defers the UI update until workflow completes (via API events)
|
||||
const originalOnExecuted = (props.node as any).onExecuted?.bind(props.node)
|
||||
|
||||
;(props.node as any).onExecuted = function(output: any) {
|
||||
console.log("[LoraCyclerWidget] Node executed with output:", output)
|
||||
|
||||
// Update state from backend response (values are wrapped in arrays)
|
||||
if (output?.next_index !== undefined) {
|
||||
const val = Array.isArray(output.next_index) ? output.next_index[0] : output.next_index
|
||||
state.currentIndex.value = val
|
||||
}
|
||||
// Pop execution context from queue (FIFO order)
|
||||
const context = executionQueue.shift()
|
||||
hasQueuedPrompts.value = executionQueue.length > 0
|
||||
|
||||
// Determine if we should advance the display index
|
||||
const shouldAdvanceDisplay = context
|
||||
? context.shouldAdvanceDisplay
|
||||
: (!state.isPaused.value && state.repeatUsed.value >= state.repeatCount.value)
|
||||
|
||||
// Extract output values
|
||||
const nextIndex = output?.next_index !== undefined
|
||||
? (Array.isArray(output.next_index) ? output.next_index[0] : output.next_index)
|
||||
: state.currentIndex.value
|
||||
const nextLoraName = output?.next_lora_name !== undefined
|
||||
? (Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name)
|
||||
: ''
|
||||
const nextLoraFilename = output?.next_lora_filename !== undefined
|
||||
? (Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename)
|
||||
: ''
|
||||
const currentLoraName = output?.current_lora_name !== undefined
|
||||
? (Array.isArray(output.current_lora_name) ? output.current_lora_name[0] : output.current_lora_name)
|
||||
: ''
|
||||
const currentLoraFilename = output?.current_lora_filename !== undefined
|
||||
? (Array.isArray(output.current_lora_filename) ? output.current_lora_filename[0] : output.current_lora_filename)
|
||||
: ''
|
||||
|
||||
// Update total count immediately (doesn't need to wait for workflow completion)
|
||||
if (output?.total_count !== undefined) {
|
||||
const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
|
||||
state.totalCount.value = val
|
||||
}
|
||||
if (output?.current_lora_name !== undefined) {
|
||||
const val = Array.isArray(output.current_lora_name) ? output.current_lora_name[0] : output.current_lora_name
|
||||
state.currentLoraName.value = val
|
||||
}
|
||||
if (output?.current_lora_filename !== undefined) {
|
||||
const val = Array.isArray(output.current_lora_filename) ? output.current_lora_filename[0] : output.current_lora_filename
|
||||
state.currentLoraFilename.value = val
|
||||
}
|
||||
if (output?.next_lora_name !== undefined) {
|
||||
const val = Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name
|
||||
state.currentLoraName.value = val
|
||||
}
|
||||
if (output?.next_lora_filename !== undefined) {
|
||||
const val = Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename
|
||||
state.currentLoraFilename.value = val
|
||||
|
||||
// Store pending update (will be applied on workflow completion)
|
||||
if (context) {
|
||||
pendingExecutions.push({
|
||||
repeatUsed: context.repeatUsed,
|
||||
repeatCount: context.repeatCount,
|
||||
shouldAdvanceDisplay,
|
||||
displayRepeatUsed: context.displayRepeatUsed,
|
||||
output: {
|
||||
nextIndex,
|
||||
nextLoraName,
|
||||
nextLoraFilename,
|
||||
currentLoraName,
|
||||
currentLoraFilename
|
||||
}
|
||||
})
|
||||
|
||||
// Update visual feedback state (don't update displayRepeatUsed yet - wait for workflow completion)
|
||||
state.executingRepeatStep.value = context.repeatUsed
|
||||
state.isWorkflowExecuting.value = true
|
||||
}
|
||||
|
||||
// Call original onExecuted if it exists
|
||||
@@ -194,11 +384,69 @@ onMounted(async () => {
|
||||
}
|
||||
}
|
||||
|
||||
// Set up execution tracking via API events
|
||||
if (props.api) {
|
||||
// Handle workflow completion events using FIFO order
|
||||
// Note: The 'executing' event doesn't contain prompt_id (only node ID as string),
|
||||
// so we use FIFO order instead of prompt_id matching since executions are processed
|
||||
// in the order they were queued
|
||||
const handleExecutionComplete = () => {
|
||||
// Process the first pending execution (FIFO order)
|
||||
if (pendingExecutions.length === 0) {
|
||||
return
|
||||
}
|
||||
|
||||
const pending = pendingExecutions.shift()!
|
||||
|
||||
// Apply UI update now that workflow is complete
|
||||
// Update repeat display (deferred like index updates)
|
||||
state.displayRepeatUsed.value = pending.displayRepeatUsed
|
||||
|
||||
if (pending.output) {
|
||||
if (pending.shouldAdvanceDisplay) {
|
||||
state.currentIndex.value = pending.output.nextIndex
|
||||
state.currentLoraName.value = pending.output.nextLoraName
|
||||
state.currentLoraFilename.value = pending.output.nextLoraFilename
|
||||
} else {
|
||||
// When not advancing, show current LoRA info
|
||||
state.currentLoraName.value = pending.output.currentLoraName
|
||||
state.currentLoraFilename.value = pending.output.currentLoraFilename
|
||||
}
|
||||
}
|
||||
|
||||
// Reset visual feedback if no more pending
|
||||
if (pendingExecutions.length === 0) {
|
||||
state.isWorkflowExecuting.value = false
|
||||
state.executingRepeatStep.value = 0
|
||||
}
|
||||
}
|
||||
|
||||
props.api.addEventListener('execution_success', handleExecutionComplete)
|
||||
props.api.addEventListener('execution_error', handleExecutionComplete)
|
||||
props.api.addEventListener('execution_interrupted', handleExecutionComplete)
|
||||
|
||||
// Store cleanup function for API listeners
|
||||
const apiCleanup = () => {
|
||||
props.api.removeEventListener('execution_success', handleExecutionComplete)
|
||||
props.api.removeEventListener('execution_error', handleExecutionComplete)
|
||||
props.api.removeEventListener('execution_interrupted', handleExecutionComplete)
|
||||
}
|
||||
|
||||
// Extend existing cleanup
|
||||
const existingCleanup = (props.widget as any).onRemoveCleanup
|
||||
;(props.widget as any).onRemoveCleanup = () => {
|
||||
existingCleanup?.()
|
||||
apiCleanup()
|
||||
}
|
||||
}
|
||||
|
||||
// Watch for connection changes by polling (since ComfyUI doesn't provide connection events)
|
||||
const checkInterval = setInterval(checkPoolConfigChanges, 1000)
|
||||
|
||||
// Cleanup on unmount (handled by Vue's effect scope)
|
||||
const existingCleanupForInterval = (props.widget as any).onRemoveCleanup
|
||||
;(props.widget as any).onRemoveCleanup = () => {
|
||||
existingCleanupForInterval?.()
|
||||
clearInterval(checkInterval)
|
||||
}
|
||||
})
|
||||
|
||||
@@ -6,57 +6,111 @@
|
||||
|
||||
<!-- Progress Display -->
|
||||
<div class="setting-section progress-section">
|
||||
<div class="progress-display">
|
||||
<div class="progress-info">
|
||||
<span class="progress-label">Next LoRA:</span>
|
||||
<span class="progress-name" :title="currentLoraFilename">{{ currentLoraName || 'None' }}</span>
|
||||
<div class="progress-display" :class="{ executing: isWorkflowExecuting }">
|
||||
<div
|
||||
class="progress-info"
|
||||
:class="{ disabled: isPauseDisabled }"
|
||||
@click="handleOpenSelector"
|
||||
>
|
||||
<span class="progress-label">{{ isWorkflowExecuting ? 'Using LoRA:' : 'Next LoRA:' }}</span>
|
||||
<span class="progress-name clickable" :class="{ disabled: isPauseDisabled }" :title="currentLoraFilename">
|
||||
{{ currentLoraName || 'None' }}
|
||||
<svg class="selector-icon" viewBox="0 0 24 24" fill="currentColor">
|
||||
<path d="M7 10l5 5 5-5z"/>
|
||||
</svg>
|
||||
</span>
|
||||
</div>
|
||||
<div class="progress-counter">
|
||||
<span class="progress-index">{{ currentIndex }}</span>
|
||||
<span class="progress-separator">/</span>
|
||||
<span class="progress-total">{{ totalCount }}</span>
|
||||
<button
|
||||
class="refresh-button"
|
||||
:disabled="isLoading"
|
||||
@click="$emit('refresh')"
|
||||
title="Refresh list"
|
||||
>
|
||||
<svg
|
||||
class="refresh-icon"
|
||||
:class="{ spinning: isLoading }"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-width="2"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
>
|
||||
<path d="M21 12a9 9 0 1 1-6.219-8.56"/>
|
||||
<path d="M21 3v5h-5"/>
|
||||
</svg>
|
||||
</button>
|
||||
|
||||
<!-- Repeat progress indicator (only shown when repeatCount > 1) -->
|
||||
<div v-if="repeatCount > 1" class="repeat-progress">
|
||||
<div class="repeat-progress-track">
|
||||
<div
|
||||
class="repeat-progress-fill"
|
||||
:style="{ width: `${(repeatUsed / repeatCount) * 100}%` }"
|
||||
:class="{ 'is-complete': repeatUsed >= repeatCount }"
|
||||
></div>
|
||||
</div>
|
||||
<span class="repeat-progress-text">{{ repeatUsed }}/{{ repeatCount }}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Starting Index -->
|
||||
<!-- Starting Index with Advanced Controls -->
|
||||
<div class="setting-section">
|
||||
<label class="setting-label">Starting Index</label>
|
||||
<div class="index-input-container">
|
||||
<input
|
||||
type="number"
|
||||
class="index-input"
|
||||
:min="1"
|
||||
:max="totalCount || 1"
|
||||
:value="currentIndex"
|
||||
:disabled="totalCount === 0"
|
||||
@input="onIndexInput"
|
||||
@blur="onIndexBlur"
|
||||
@pointerdown.stop
|
||||
@pointermove.stop
|
||||
@pointerup.stop
|
||||
/>
|
||||
<span class="index-hint">1 - {{ totalCount || 1 }}</span>
|
||||
<div class="index-controls-row">
|
||||
<!-- Left: Index group -->
|
||||
<div class="control-group">
|
||||
<label class="control-group-label">Starting Index</label>
|
||||
<div class="control-group-content">
|
||||
<input
|
||||
type="number"
|
||||
class="index-input"
|
||||
:min="1"
|
||||
:max="totalCount || 1"
|
||||
:value="currentIndex"
|
||||
:disabled="totalCount === 0"
|
||||
@input="onIndexInput"
|
||||
@blur="onIndexBlur"
|
||||
@pointerdown.stop
|
||||
@pointermove.stop
|
||||
@pointerup.stop
|
||||
/>
|
||||
<span class="index-hint">/ {{ totalCount || 1 }}</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Right: Repeat group -->
|
||||
<div class="control-group">
|
||||
<label class="control-group-label">Repeat</label>
|
||||
<div class="control-group-content">
|
||||
<input
|
||||
type="number"
|
||||
class="repeat-input"
|
||||
min="1"
|
||||
max="99"
|
||||
:value="repeatCount"
|
||||
@input="onRepeatInput"
|
||||
@blur="onRepeatBlur"
|
||||
@pointerdown.stop
|
||||
@pointermove.stop
|
||||
@pointerup.stop
|
||||
title="Each LoRA will be used this many times before moving to the next"
|
||||
/>
|
||||
<span class="repeat-suffix">×</span>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Action buttons -->
|
||||
<div class="action-buttons">
|
||||
<button
|
||||
class="control-btn"
|
||||
:class="{ active: isPaused }"
|
||||
:disabled="isPauseDisabled"
|
||||
@click="$emit('toggle-pause')"
|
||||
:title="isPauseDisabled ? 'Cannot pause while prompts are queued' : (isPaused ? 'Continue iteration' : 'Pause iteration')"
|
||||
>
|
||||
<svg v-if="isPaused" viewBox="0 0 24 24" fill="currentColor" class="control-icon">
|
||||
<path d="M8 5v14l11-7z"/>
|
||||
</svg>
|
||||
<svg v-else viewBox="0 0 24 24" fill="currentColor" class="control-icon">
|
||||
<path d="M6 4h4v16H6zm8 0h4v16h-4z"/>
|
||||
</svg>
|
||||
</button>
|
||||
<button
|
||||
class="control-btn"
|
||||
@click="$emit('reset-index')"
|
||||
title="Reset to index 1"
|
||||
>
|
||||
<svg viewBox="0 0 24 24" fill="currentColor" class="control-icon">
|
||||
<path d="M12 5V1L7 6l5 5V7c3.31 0 6 2.69 6 6s-2.69 6-6 6-6-2.69-6-6H4c0 4.42 3.58 8 8 8s8-3.58 8-8-3.58-8-8-8z"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -122,7 +176,12 @@ const props = defineProps<{
|
||||
clipStrength: number
|
||||
useCustomClipRange: boolean
|
||||
isClipStrengthDisabled: boolean
|
||||
isLoading: boolean
|
||||
repeatCount: number
|
||||
repeatUsed: number
|
||||
isPaused: boolean
|
||||
isPauseDisabled: boolean
|
||||
isWorkflowExecuting: boolean
|
||||
executingRepeatStep: number
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
@@ -130,11 +189,22 @@ const emit = defineEmits<{
|
||||
'update:modelStrength': [value: number]
|
||||
'update:clipStrength': [value: number]
|
||||
'update:useCustomClipRange': [value: boolean]
|
||||
'refresh': []
|
||||
'update:repeatCount': [value: number]
|
||||
'toggle-pause': []
|
||||
'reset-index': []
|
||||
'open-lora-selector': []
|
||||
}>()
|
||||
|
||||
// Temporary value for input while typing
|
||||
const tempIndex = ref<string>('')
|
||||
const tempRepeat = ref<string>('')
|
||||
|
||||
const handleOpenSelector = () => {
|
||||
if (props.isPauseDisabled) {
|
||||
return
|
||||
}
|
||||
emit('open-lora-selector')
|
||||
}
|
||||
|
||||
const onIndexInput = (event: Event) => {
|
||||
const input = event.target as HTMLInputElement
|
||||
@@ -154,6 +224,25 @@ const onIndexBlur = (event: Event) => {
|
||||
}
|
||||
tempIndex.value = ''
|
||||
}
|
||||
|
||||
const onRepeatInput = (event: Event) => {
|
||||
const input = event.target as HTMLInputElement
|
||||
tempRepeat.value = input.value
|
||||
}
|
||||
|
||||
const onRepeatBlur = (event: Event) => {
|
||||
const input = event.target as HTMLInputElement
|
||||
const value = parseInt(input.value, 10)
|
||||
|
||||
if (!isNaN(value)) {
|
||||
const clampedValue = Math.max(1, Math.min(value, 99))
|
||||
emit('update:repeatCount', clampedValue)
|
||||
input.value = clampedValue.toString()
|
||||
} else {
|
||||
input.value = props.repeatCount.toString()
|
||||
}
|
||||
tempRepeat.value = ''
|
||||
}
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
@@ -203,6 +292,17 @@ const onIndexBlur = (event: Event) => {
|
||||
display: flex;
|
||||
justify-content: space-between;
|
||||
align-items: center;
|
||||
transition: border-color 0.3s ease;
|
||||
}
|
||||
|
||||
.progress-display.executing {
|
||||
border-color: rgba(66, 153, 225, 0.5);
|
||||
animation: pulse 2s ease-in-out infinite;
|
||||
}
|
||||
|
||||
@keyframes pulse {
|
||||
0%, 100% { border-color: rgba(66, 153, 225, 0.3); }
|
||||
50% { border-color: rgba(66, 153, 225, 0.7); }
|
||||
}
|
||||
|
||||
.progress-info {
|
||||
@@ -230,6 +330,42 @@ const onIndexBlur = (event: Event) => {
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.progress-name.clickable {
|
||||
cursor: pointer;
|
||||
padding: 2px 6px;
|
||||
margin: -2px -6px;
|
||||
border-radius: 4px;
|
||||
transition: all 0.2s;
|
||||
display: inline-flex;
|
||||
align-items: center;
|
||||
gap: 4px;
|
||||
}
|
||||
|
||||
.progress-name.clickable:hover:not(.disabled) {
|
||||
background: rgba(66, 153, 225, 0.2);
|
||||
color: rgba(191, 219, 254, 1);
|
||||
}
|
||||
|
||||
.progress-name.clickable.disabled {
|
||||
cursor: not-allowed;
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.progress-info.disabled {
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.selector-icon {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
opacity: 0.5;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
|
||||
.progress-name.clickable:hover .selector-icon {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.progress-counter {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
@@ -243,6 +379,9 @@ const onIndexBlur = (event: Event) => {
|
||||
font-weight: 600;
|
||||
color: rgba(66, 153, 225, 1);
|
||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||
min-width: 4ch;
|
||||
text-align: right;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
.progress-separator {
|
||||
@@ -256,69 +395,92 @@ const onIndexBlur = (event: Event) => {
|
||||
font-weight: 500;
|
||||
color: rgba(226, 232, 240, 0.6);
|
||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||
min-width: 4ch;
|
||||
text-align: left;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
.refresh-button {
|
||||
/* Repeat Progress */
|
||||
.repeat-progress {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
gap: 6px;
|
||||
margin-left: 8px;
|
||||
padding: 0;
|
||||
background: transparent;
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
padding: 2px 6px;
|
||||
background: rgba(26, 32, 44, 0.6);
|
||||
border: 1px solid rgba(226, 232, 240, 0.1);
|
||||
border-radius: 4px;
|
||||
color: rgba(226, 232, 240, 0.6);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.refresh-button:hover:not(:disabled) {
|
||||
background: rgba(66, 153, 225, 0.2);
|
||||
border-color: rgba(66, 153, 225, 0.4);
|
||||
color: rgba(191, 219, 254, 1);
|
||||
.repeat-progress-track {
|
||||
width: 32px;
|
||||
height: 4px;
|
||||
background: rgba(226, 232, 240, 0.15);
|
||||
border-radius: 2px;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
.refresh-button:disabled {
|
||||
opacity: 0.4;
|
||||
cursor: not-allowed;
|
||||
.repeat-progress-fill {
|
||||
height: 100%;
|
||||
background: linear-gradient(90deg, #f59e0b, #fbbf24);
|
||||
border-radius: 2px;
|
||||
transition: width 0.3s ease;
|
||||
}
|
||||
|
||||
.refresh-icon {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
.repeat-progress-fill.is-complete {
|
||||
background: linear-gradient(90deg, #10b981, #34d399);
|
||||
}
|
||||
|
||||
.refresh-icon.spinning {
|
||||
animation: spin 1s linear infinite;
|
||||
.repeat-progress-text {
|
||||
font-size: 10px;
|
||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||
color: rgba(253, 230, 138, 0.9);
|
||||
min-width: 3ch;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
from {
|
||||
transform: rotate(0deg);
|
||||
}
|
||||
to {
|
||||
transform: rotate(360deg);
|
||||
}
|
||||
}
|
||||
|
||||
/* Index Input */
|
||||
.index-input-container {
|
||||
/* Index Controls Row - Grouped Layout */
|
||||
.index-controls-row {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 8px;
|
||||
align-items: flex-end;
|
||||
gap: 16px;
|
||||
}
|
||||
|
||||
/* Control Group */
|
||||
.control-group {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 6px;
|
||||
}
|
||||
|
||||
.control-group-label {
|
||||
font-size: 11px;
|
||||
font-weight: 500;
|
||||
color: rgba(226, 232, 240, 0.5);
|
||||
text-transform: uppercase;
|
||||
letter-spacing: 0.03em;
|
||||
line-height: 1;
|
||||
}
|
||||
|
||||
.control-group-content {
|
||||
display: flex;
|
||||
align-items: baseline;
|
||||
gap: 4px;
|
||||
height: 32px;
|
||||
}
|
||||
|
||||
.index-input {
|
||||
width: 80px;
|
||||
padding: 6px 10px;
|
||||
width: 50px;
|
||||
height: 32px;
|
||||
padding: 0 8px;
|
||||
background: rgba(26, 32, 44, 0.9);
|
||||
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||
border-radius: 6px;
|
||||
color: #e4e4e7;
|
||||
font-size: 13px;
|
||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||
line-height: 32px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.index-input:focus {
|
||||
@@ -332,8 +494,89 @@ const onIndexBlur = (event: Event) => {
|
||||
}
|
||||
|
||||
.index-hint {
|
||||
font-size: 11px;
|
||||
font-size: 12px;
|
||||
color: rgba(226, 232, 240, 0.4);
|
||||
font-variant-numeric: tabular-nums;
|
||||
line-height: 32px;
|
||||
}
|
||||
|
||||
/* Repeat Controls */
|
||||
.repeat-input {
|
||||
width: 40px;
|
||||
height: 32px;
|
||||
padding: 0 6px;
|
||||
background: rgba(26, 32, 44, 0.9);
|
||||
border: 1px solid rgba(226, 232, 240, 0.2);
|
||||
border-radius: 6px;
|
||||
color: #e4e4e7;
|
||||
font-size: 13px;
|
||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||
text-align: center;
|
||||
line-height: 32px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.repeat-input:focus {
|
||||
outline: none;
|
||||
border-color: rgba(66, 153, 225, 0.6);
|
||||
}
|
||||
|
||||
.repeat-suffix {
|
||||
font-size: 13px;
|
||||
color: rgba(226, 232, 240, 0.4);
|
||||
font-weight: 500;
|
||||
line-height: 32px;
|
||||
}
|
||||
|
||||
/* Action Buttons */
|
||||
.action-buttons {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 6px;
|
||||
margin-left: auto;
|
||||
}
|
||||
|
||||
/* Control Buttons */
|
||||
.control-btn {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 24px;
|
||||
height: 24px;
|
||||
padding: 0;
|
||||
background: transparent;
|
||||
border: 1px solid rgba(255, 255, 255, 0.1);
|
||||
border-radius: 4px;
|
||||
color: rgba(226, 232, 240, 0.6);
|
||||
cursor: pointer;
|
||||
transition: all 0.2s;
|
||||
}
|
||||
|
||||
.control-btn:hover:not(:disabled) {
|
||||
background: rgba(66, 153, 225, 0.2);
|
||||
border-color: rgba(66, 153, 225, 0.4);
|
||||
color: rgba(191, 219, 254, 1);
|
||||
}
|
||||
|
||||
.control-btn:disabled {
|
||||
opacity: 0.4;
|
||||
cursor: not-allowed;
|
||||
}
|
||||
|
||||
.control-btn.active {
|
||||
background: rgba(245, 158, 11, 0.2);
|
||||
border-color: rgba(245, 158, 11, 0.5);
|
||||
color: rgba(253, 230, 138, 1);
|
||||
}
|
||||
|
||||
.control-btn.active:hover {
|
||||
background: rgba(245, 158, 11, 0.3);
|
||||
border-color: rgba(245, 158, 11, 0.6);
|
||||
}
|
||||
|
||||
.control-icon {
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
}
|
||||
|
||||
/* Slider Container */
|
||||
|
||||
313
vue-widgets/src/components/lora-cycler/LoraListModal.vue
Normal file
313
vue-widgets/src/components/lora-cycler/LoraListModal.vue
Normal file
@@ -0,0 +1,313 @@
|
||||
<template>
|
||||
<ModalWrapper
|
||||
:visible="visible"
|
||||
title="Select LoRA"
|
||||
:subtitle="subtitleText"
|
||||
@close="$emit('close')"
|
||||
>
|
||||
<template #search>
|
||||
<div class="search-container">
|
||||
<svg class="search-icon" viewBox="0 0 16 16" fill="currentColor">
|
||||
<path d="M11.742 10.344a6.5 6.5 0 1 0-1.397 1.398h-.001c.03.04.062.078.098.115l3.85 3.85a1 1 0 0 0 1.415-1.414l-3.85-3.85a1.007 1.007 0 0 0-.115-.1zM12 6.5a5.5 5.5 0 1 1-11 0 5.5 5.5 0 0 1 11 0z"/>
|
||||
</svg>
|
||||
<input
|
||||
ref="searchInputRef"
|
||||
v-model="searchQuery"
|
||||
type="text"
|
||||
class="search-input"
|
||||
placeholder="Search LoRAs..."
|
||||
/>
|
||||
<button
|
||||
v-if="searchQuery"
|
||||
type="button"
|
||||
class="clear-button"
|
||||
@click="clearSearch"
|
||||
>
|
||||
<svg viewBox="0 0 16 16" fill="currentColor">
|
||||
<path d="M4.646 4.646a.5.5 0 0 1 .708 0L8 7.293l2.646-2.647a.5.5 0 0 1 .708.708L8.707 8l2.647 2.646a.5.5 0 0 1-.708.708L8 8.707l-2.646 2.647a.5.5 0 0 1-.708-.708L7.293 8 4.646 5.354a.5.5 0 0 1 0-.708z"/>
|
||||
</svg>
|
||||
</button>
|
||||
</div>
|
||||
</template>
|
||||
|
||||
<div class="lora-list">
|
||||
<div
|
||||
v-for="item in filteredList"
|
||||
:key="item.index"
|
||||
class="lora-item"
|
||||
:class="{ active: currentIndex === item.index }"
|
||||
@mouseenter="showPreview(item.lora.file_name, $event)"
|
||||
@mouseleave="hidePreview"
|
||||
@click="selectLora(item.index)"
|
||||
>
|
||||
<span class="lora-index">{{ item.index }}</span>
|
||||
<span class="lora-name" :title="item.lora.file_name">{{ item.lora.file_name }}</span>
|
||||
<span v-if="currentIndex === item.index" class="current-badge">Current</span>
|
||||
</div>
|
||||
<div v-if="filteredList.length === 0" class="no-results">
|
||||
No LoRAs found
|
||||
</div>
|
||||
</div>
|
||||
</ModalWrapper>
|
||||
</template>
|
||||
|
||||
<script setup lang="ts">
|
||||
import { ref, computed, watch, nextTick, onUnmounted } from 'vue'
|
||||
import ModalWrapper from '../lora-pool/modals/ModalWrapper.vue'
|
||||
import type { LoraItem } from '../../composables/types'
|
||||
|
||||
interface LoraListItem {
|
||||
index: number
|
||||
lora: LoraItem
|
||||
}
|
||||
|
||||
const props = defineProps<{
|
||||
visible: boolean
|
||||
loraList: LoraItem[]
|
||||
currentIndex: number
|
||||
}>()
|
||||
|
||||
const emit = defineEmits<{
|
||||
close: []
|
||||
select: [index: number]
|
||||
}>()
|
||||
|
||||
const searchQuery = ref('')
|
||||
const searchInputRef = ref<HTMLInputElement | null>(null)
|
||||
|
||||
// Preview tooltip instance (lazy init)
|
||||
let previewTooltip: any = null
|
||||
|
||||
const subtitleText = computed(() => {
|
||||
const total = props.loraList.length
|
||||
const filtered = filteredList.value.length
|
||||
if (filtered === total) {
|
||||
return `Total: ${total} LoRA${total !== 1 ? 's' : ''}`
|
||||
}
|
||||
return `Showing ${filtered} of ${total} LoRA${total !== 1 ? 's' : ''}`
|
||||
})
|
||||
|
||||
const filteredList = computed<LoraListItem[]>(() => {
|
||||
const list = props.loraList.map((lora, idx) => ({
|
||||
index: idx + 1,
|
||||
lora
|
||||
}))
|
||||
|
||||
if (!searchQuery.value.trim()) {
|
||||
return list
|
||||
}
|
||||
|
||||
const query = searchQuery.value.toLowerCase()
|
||||
return list.filter(item =>
|
||||
item.lora.file_name.toLowerCase().includes(query)
|
||||
)
|
||||
})
|
||||
|
||||
const clearSearch = () => {
|
||||
searchQuery.value = ''
|
||||
searchInputRef.value?.focus()
|
||||
}
|
||||
|
||||
const selectLora = (index: number) => {
|
||||
emit('select', index)
|
||||
emit('close')
|
||||
}
|
||||
|
||||
// Custom preview URL resolver for Vue widgets environment
|
||||
// The default preview_tooltip.js uses api.fetchApi which is mocked as native fetch
|
||||
// in the Vue widgets build, so we need to use the full path with /api prefix
|
||||
const customPreviewUrlResolver = async (modelName: string) => {
|
||||
const response = await fetch(
|
||||
`/api/lm/loras/preview-url?name=${encodeURIComponent(modelName)}&license_flags=true`
|
||||
)
|
||||
if (!response.ok) {
|
||||
throw new Error('Failed to fetch preview URL')
|
||||
}
|
||||
const data = await response.json()
|
||||
if (!data.success || !data.preview_url) {
|
||||
throw new Error('No preview available')
|
||||
}
|
||||
return {
|
||||
previewUrl: data.preview_url,
|
||||
displayName: data.display_name ?? modelName,
|
||||
licenseFlags: data.license_flags
|
||||
}
|
||||
}
|
||||
|
||||
// Lazy load PreviewTooltip to avoid loading it unnecessarily
|
||||
const getPreviewTooltip = async () => {
|
||||
if (!previewTooltip) {
|
||||
const { PreviewTooltip } = await import(/* @vite-ignore */ `${'../preview_tooltip.js'}`)
|
||||
previewTooltip = new PreviewTooltip({
|
||||
modelType: 'loras',
|
||||
displayNameFormatter: (name: string) => name,
|
||||
previewUrlResolver: customPreviewUrlResolver
|
||||
})
|
||||
}
|
||||
return previewTooltip
|
||||
}
|
||||
|
||||
const showPreview = async (loraName: string, event: MouseEvent) => {
|
||||
const tooltip = await getPreviewTooltip()
|
||||
const rect = (event.target as HTMLElement).getBoundingClientRect()
|
||||
// Position to the right of the item, centered vertically
|
||||
tooltip.show(loraName, rect.right + 10, rect.top + rect.height / 2)
|
||||
}
|
||||
|
||||
const hidePreview = async () => {
|
||||
if (previewTooltip) {
|
||||
previewTooltip.hide()
|
||||
}
|
||||
}
|
||||
|
||||
// Focus search input when modal opens
|
||||
watch(() => props.visible, (isVisible) => {
|
||||
if (isVisible) {
|
||||
searchQuery.value = ''
|
||||
nextTick(() => {
|
||||
searchInputRef.value?.focus()
|
||||
})
|
||||
} else {
|
||||
// Hide preview when modal closes
|
||||
hidePreview()
|
||||
}
|
||||
})
|
||||
|
||||
// Cleanup on unmount
|
||||
onUnmounted(() => {
|
||||
if (previewTooltip) {
|
||||
previewTooltip.cleanup()
|
||||
previewTooltip = null
|
||||
}
|
||||
})
|
||||
</script>
|
||||
|
||||
<style scoped>
|
||||
.search-container {
|
||||
position: relative;
|
||||
}
|
||||
|
||||
.search-icon {
|
||||
position: absolute;
|
||||
left: 10px;
|
||||
top: 50%;
|
||||
transform: translateY(-50%);
|
||||
width: 14px;
|
||||
height: 14px;
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.5;
|
||||
}
|
||||
|
||||
.search-input {
|
||||
width: 100%;
|
||||
padding: 8px 32px;
|
||||
background: var(--comfy-input-bg, #333);
|
||||
border: 1px solid var(--border-color, #444);
|
||||
border-radius: 6px;
|
||||
color: var(--fg-color, #fff);
|
||||
font-size: 13px;
|
||||
outline: none;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
.search-input:focus {
|
||||
border-color: rgba(66, 153, 225, 0.6);
|
||||
}
|
||||
|
||||
.search-input::placeholder {
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.4;
|
||||
}
|
||||
|
||||
.clear-button {
|
||||
position: absolute;
|
||||
right: 8px;
|
||||
top: 50%;
|
||||
transform: translateY(-50%);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
width: 20px;
|
||||
height: 20px;
|
||||
background: transparent;
|
||||
border: none;
|
||||
cursor: pointer;
|
||||
padding: 0;
|
||||
opacity: 0.5;
|
||||
transition: opacity 0.15s;
|
||||
}
|
||||
|
||||
.clear-button:hover {
|
||||
opacity: 0.8;
|
||||
}
|
||||
|
||||
.clear-button svg {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
color: var(--fg-color, #fff);
|
||||
}
|
||||
|
||||
.lora-list {
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
gap: 2px;
|
||||
max-height: 400px;
|
||||
overflow-y: auto;
|
||||
}
|
||||
|
||||
.lora-item {
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 12px;
|
||||
padding: 10px 12px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
transition: all 0.15s;
|
||||
border-left: 3px solid transparent;
|
||||
}
|
||||
|
||||
.lora-item:hover {
|
||||
background: rgba(66, 153, 225, 0.15);
|
||||
}
|
||||
|
||||
.lora-item.active {
|
||||
background: rgba(66, 153, 225, 0.25);
|
||||
border-left-color: rgba(66, 153, 225, 0.8);
|
||||
}
|
||||
|
||||
.lora-index {
|
||||
font-family: 'SF Mono', 'Roboto Mono', monospace;
|
||||
font-size: 12px;
|
||||
color: rgba(226, 232, 240, 0.5);
|
||||
min-width: 3ch;
|
||||
text-align: right;
|
||||
font-variant-numeric: tabular-nums;
|
||||
}
|
||||
|
||||
.lora-name {
|
||||
flex: 1;
|
||||
font-size: 13px;
|
||||
color: var(--fg-color, #fff);
|
||||
overflow: hidden;
|
||||
text-overflow: ellipsis;
|
||||
white-space: nowrap;
|
||||
}
|
||||
|
||||
.current-badge {
|
||||
font-size: 11px;
|
||||
padding: 2px 8px;
|
||||
background: rgba(66, 153, 225, 0.3);
|
||||
border: 1px solid rgba(66, 153, 225, 0.5);
|
||||
border-radius: 4px;
|
||||
color: rgba(191, 219, 254, 1);
|
||||
font-weight: 500;
|
||||
}
|
||||
|
||||
.no-results {
|
||||
padding: 32px 20px;
|
||||
text-align: center;
|
||||
color: var(--fg-color, #fff);
|
||||
opacity: 0.5;
|
||||
font-size: 13px;
|
||||
}
|
||||
</style>
|
||||
@@ -81,7 +81,7 @@ watch(() => props.visible, (isVisible) => {
|
||||
.lora-pool-modal-backdrop {
|
||||
position: fixed;
|
||||
inset: 0;
|
||||
z-index: 10000;
|
||||
z-index: 9998;
|
||||
background: rgba(0, 0, 0, 0.6);
|
||||
display: flex;
|
||||
align-items: center;
|
||||
|
||||
@@ -206,7 +206,9 @@ const stepToDecimals = (step: number): number => {
|
||||
const snapToStep = (value: number, segmentMultiplier?: number): number => {
|
||||
const effectiveStep = segmentMultiplier ? props.step * segmentMultiplier : props.step
|
||||
const steps = Math.round((value - props.min) / effectiveStep)
|
||||
return Math.max(props.min, Math.min(props.max, props.min + steps * effectiveStep))
|
||||
const rawValue = Math.max(props.min, Math.min(props.max, props.min + steps * effectiveStep))
|
||||
// Fix floating point precision issues, limit to 2 decimal places
|
||||
return Math.round(rawValue * 100) / 100
|
||||
}
|
||||
|
||||
const startDrag = (handle: 'min' | 'max', event: PointerEvent) => {
|
||||
|
||||
@@ -82,7 +82,9 @@ const stepToDecimals = (step: number): number => {
|
||||
|
||||
const snapToStep = (value: number): number => {
|
||||
const steps = Math.round((value - props.min) / props.step)
|
||||
return Math.max(props.min, Math.min(props.max, props.min + steps * props.step))
|
||||
const rawValue = Math.max(props.min, Math.min(props.max, props.min + steps * props.step))
|
||||
// Fix floating point precision issues, limit to 2 decimal places
|
||||
return Math.round(rawValue * 100) / 100
|
||||
}
|
||||
|
||||
const startDrag = (event: PointerEvent) => {
|
||||
|
||||
@@ -80,6 +80,10 @@ export interface CyclerConfig {
|
||||
// Dual-index mechanism for batch queue synchronization
|
||||
execution_index?: number | null // Index to use for current execution
|
||||
next_index?: number | null // Index for display after execution
|
||||
// Advanced index control features
|
||||
repeat_count: number // How many times each LoRA should repeat (default: 1)
|
||||
repeat_used: number // How many times current index has been used
|
||||
is_paused: boolean // Whether iteration is paused
|
||||
}
|
||||
|
||||
// Widget config union type
|
||||
|
||||
@@ -29,6 +29,16 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
const executionIndex = ref<number | null>(null)
|
||||
const nextIndex = ref<number | null>(null)
|
||||
|
||||
// Advanced index control features
|
||||
const repeatCount = ref(1) // How many times each LoRA should repeat
|
||||
const repeatUsed = ref(0) // How many times current index has been used (internal tracking)
|
||||
const displayRepeatUsed = ref(0) // For UI display, deferred updates like currentIndex
|
||||
const isPaused = ref(false) // Whether iteration is paused
|
||||
|
||||
// Execution progress tracking (visual feedback)
|
||||
const isWorkflowExecuting = ref(false) // Workflow is currently running
|
||||
const executingRepeatStep = ref(0) // Which repeat step (1-based, 0 = not executing)
|
||||
|
||||
// Build config object from current state
|
||||
const buildConfig = (): CyclerConfig => {
|
||||
// Skip updating widget.value during restoration to prevent infinite loops
|
||||
@@ -45,6 +55,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
current_lora_filename: currentLoraFilename.value,
|
||||
execution_index: executionIndex.value,
|
||||
next_index: nextIndex.value,
|
||||
repeat_count: repeatCount.value,
|
||||
repeat_used: repeatUsed.value,
|
||||
is_paused: isPaused.value,
|
||||
}
|
||||
}
|
||||
return {
|
||||
@@ -59,6 +72,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
current_lora_filename: currentLoraFilename.value,
|
||||
execution_index: executionIndex.value,
|
||||
next_index: nextIndex.value,
|
||||
repeat_count: repeatCount.value,
|
||||
repeat_used: repeatUsed.value,
|
||||
is_paused: isPaused.value,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,6 +93,10 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
sortBy.value = config.sort_by || 'filename'
|
||||
currentLoraName.value = config.current_lora_name || ''
|
||||
currentLoraFilename.value = config.current_lora_filename || ''
|
||||
// Advanced index control features
|
||||
repeatCount.value = config.repeat_count ?? 1
|
||||
repeatUsed.value = config.repeat_used ?? 0
|
||||
isPaused.value = config.is_paused ?? false
|
||||
// Note: execution_index and next_index are not restored from config
|
||||
// as they are transient values used only during batch execution
|
||||
} finally {
|
||||
@@ -215,6 +235,19 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
}
|
||||
}
|
||||
|
||||
// Reset index to 1 and clear repeat state
|
||||
const resetIndex = () => {
|
||||
currentIndex.value = 1
|
||||
repeatUsed.value = 0
|
||||
displayRepeatUsed.value = 0
|
||||
// Note: isPaused is intentionally not reset - user may want to stay paused after reset
|
||||
}
|
||||
|
||||
// Toggle pause state
|
||||
const togglePause = () => {
|
||||
isPaused.value = !isPaused.value
|
||||
}
|
||||
|
||||
// Computed property to check if clip strength is disabled
|
||||
const isClipStrengthDisabled = computed(() => !useCustomClipRange.value)
|
||||
|
||||
@@ -236,6 +269,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
sortBy,
|
||||
currentLoraName,
|
||||
currentLoraFilename,
|
||||
repeatCount,
|
||||
repeatUsed,
|
||||
isPaused,
|
||||
], () => {
|
||||
widget.value = buildConfig()
|
||||
}, { deep: true })
|
||||
@@ -254,6 +290,12 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
isLoading,
|
||||
executionIndex,
|
||||
nextIndex,
|
||||
repeatCount,
|
||||
repeatUsed,
|
||||
displayRepeatUsed,
|
||||
isPaused,
|
||||
isWorkflowExecuting,
|
||||
executingRepeatStep,
|
||||
|
||||
// Computed
|
||||
isClipStrengthDisabled,
|
||||
@@ -267,5 +309,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
|
||||
setIndex,
|
||||
generateNextIndex,
|
||||
initializeNextIndex,
|
||||
resetIndex,
|
||||
togglePause,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,8 @@ const AUTOCOMPLETE_TEXT_WIDGET_MAX_HEIGHT = 100
|
||||
|
||||
// @ts-ignore - ComfyUI external module
|
||||
import { app } from '../../../scripts/app.js'
|
||||
// @ts-ignore - ComfyUI external module
|
||||
import { api } from '../../../scripts/api.js'
|
||||
// @ts-ignore
|
||||
import { getPoolConfigFromConnectedNode, getActiveLorasFromNode, updateConnectedTriggerWords, updateDownstreamLoaders } from '../../web/comfyui/utils.js'
|
||||
|
||||
@@ -255,7 +257,8 @@ function createLoraCyclerWidget(node) {
|
||||
|
||||
const vueApp = createApp(LoraCyclerWidget, {
|
||||
widget,
|
||||
node
|
||||
node,
|
||||
api
|
||||
})
|
||||
|
||||
vueApp.use(PrimeVue, {
|
||||
|
||||
634
vue-widgets/tests/composables/useLoraCyclerState.test.ts
Normal file
634
vue-widgets/tests/composables/useLoraCyclerState.test.ts
Normal file
@@ -0,0 +1,634 @@
|
||||
/**
|
||||
* Unit tests for useLoraCyclerState composable
|
||||
*
|
||||
* Tests pure state transitions and index calculations in isolation.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest'
|
||||
import { useLoraCyclerState } from '@/composables/useLoraCyclerState'
|
||||
import {
|
||||
createMockWidget,
|
||||
createMockCyclerConfig,
|
||||
createMockPoolConfig
|
||||
} from '../fixtures/mockConfigs'
|
||||
import { setupFetchMock, resetFetchMock } from '../setup'
|
||||
|
||||
describe('useLoraCyclerState', () => {
|
||||
beforeEach(() => {
|
||||
resetFetchMock()
|
||||
})
|
||||
|
||||
describe('Initial State', () => {
|
||||
it('should initialize with default values', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
expect(state.currentIndex.value).toBe(1)
|
||||
expect(state.totalCount.value).toBe(0)
|
||||
expect(state.poolConfigHash.value).toBe('')
|
||||
expect(state.modelStrength.value).toBe(1.0)
|
||||
expect(state.clipStrength.value).toBe(1.0)
|
||||
expect(state.useCustomClipRange.value).toBe(false)
|
||||
expect(state.sortBy.value).toBe('filename')
|
||||
expect(state.executionIndex.value).toBeNull()
|
||||
expect(state.nextIndex.value).toBeNull()
|
||||
expect(state.repeatCount.value).toBe(1)
|
||||
expect(state.repeatUsed.value).toBe(0)
|
||||
expect(state.displayRepeatUsed.value).toBe(0)
|
||||
expect(state.isPaused.value).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('restoreFromConfig', () => {
|
||||
it('should restore state from config object', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const config = createMockCyclerConfig({
|
||||
current_index: 3,
|
||||
total_count: 10,
|
||||
model_strength: 0.8,
|
||||
clip_strength: 0.6,
|
||||
use_same_clip_strength: false,
|
||||
repeat_count: 2,
|
||||
repeat_used: 1,
|
||||
is_paused: true
|
||||
})
|
||||
|
||||
state.restoreFromConfig(config)
|
||||
|
||||
expect(state.currentIndex.value).toBe(3)
|
||||
expect(state.totalCount.value).toBe(10)
|
||||
expect(state.modelStrength.value).toBe(0.8)
|
||||
expect(state.clipStrength.value).toBe(0.6)
|
||||
expect(state.useCustomClipRange.value).toBe(true) // inverted from use_same_clip_strength
|
||||
expect(state.repeatCount.value).toBe(2)
|
||||
expect(state.repeatUsed.value).toBe(1)
|
||||
expect(state.isPaused.value).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle missing optional fields with defaults', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
// Minimal config
|
||||
state.restoreFromConfig({
|
||||
current_index: 5,
|
||||
total_count: 10,
|
||||
pool_config_hash: '',
|
||||
model_strength: 1.0,
|
||||
clip_strength: 1.0,
|
||||
use_same_clip_strength: true,
|
||||
sort_by: 'filename',
|
||||
current_lora_name: '',
|
||||
current_lora_filename: '',
|
||||
repeat_count: 1,
|
||||
repeat_used: 0,
|
||||
is_paused: false
|
||||
})
|
||||
|
||||
expect(state.currentIndex.value).toBe(5)
|
||||
expect(state.repeatCount.value).toBe(1)
|
||||
expect(state.isPaused.value).toBe(false)
|
||||
})
|
||||
|
||||
it('should not restore execution_index and next_index (transient values)', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
// Set execution indices
|
||||
state.executionIndex.value = 2
|
||||
state.nextIndex.value = 3
|
||||
|
||||
// Restore from config (these fields in config should be ignored)
|
||||
state.restoreFromConfig(createMockCyclerConfig({
|
||||
execution_index: 5,
|
||||
next_index: 6
|
||||
}))
|
||||
|
||||
// Execution indices should remain unchanged
|
||||
expect(state.executionIndex.value).toBe(2)
|
||||
expect(state.nextIndex.value).toBe(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildConfig', () => {
|
||||
it('should build config object from current state', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.currentIndex.value = 3
|
||||
state.totalCount.value = 10
|
||||
state.modelStrength.value = 0.8
|
||||
state.repeatCount.value = 2
|
||||
state.repeatUsed.value = 1
|
||||
state.isPaused.value = true
|
||||
|
||||
const config = state.buildConfig()
|
||||
|
||||
expect(config.current_index).toBe(3)
|
||||
expect(config.total_count).toBe(10)
|
||||
expect(config.model_strength).toBe(0.8)
|
||||
expect(config.repeat_count).toBe(2)
|
||||
expect(config.repeat_used).toBe(1)
|
||||
expect(config.is_paused).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('setIndex', () => {
|
||||
it('should set index within valid range', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 10
|
||||
|
||||
state.setIndex(5)
|
||||
expect(state.currentIndex.value).toBe(5)
|
||||
|
||||
state.setIndex(1)
|
||||
expect(state.currentIndex.value).toBe(1)
|
||||
|
||||
state.setIndex(10)
|
||||
expect(state.currentIndex.value).toBe(10)
|
||||
})
|
||||
|
||||
it('should not set index outside valid range', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 10
|
||||
state.currentIndex.value = 5
|
||||
|
||||
state.setIndex(0)
|
||||
expect(state.currentIndex.value).toBe(5) // unchanged
|
||||
|
||||
state.setIndex(11)
|
||||
expect(state.currentIndex.value).toBe(5) // unchanged
|
||||
|
||||
state.setIndex(-1)
|
||||
expect(state.currentIndex.value).toBe(5) // unchanged
|
||||
})
|
||||
})
|
||||
|
||||
describe('resetIndex', () => {
|
||||
it('should reset index to 1 and clear repeatUsed and displayRepeatUsed', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.currentIndex.value = 5
|
||||
state.repeatUsed.value = 2
|
||||
state.displayRepeatUsed.value = 2
|
||||
state.isPaused.value = true
|
||||
|
||||
state.resetIndex()
|
||||
|
||||
expect(state.currentIndex.value).toBe(1)
|
||||
expect(state.repeatUsed.value).toBe(0)
|
||||
expect(state.displayRepeatUsed.value).toBe(0)
|
||||
expect(state.isPaused.value).toBe(true) // isPaused should NOT be reset
|
||||
})
|
||||
})
|
||||
|
||||
describe('togglePause', () => {
|
||||
it('should toggle pause state', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
expect(state.isPaused.value).toBe(false)
|
||||
|
||||
state.togglePause()
|
||||
expect(state.isPaused.value).toBe(true)
|
||||
|
||||
state.togglePause()
|
||||
expect(state.isPaused.value).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('generateNextIndex', () => {
|
||||
it('should shift indices correctly', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 5
|
||||
state.currentIndex.value = 1
|
||||
state.nextIndex.value = 2
|
||||
|
||||
// First call: executionIndex becomes 2 (previous nextIndex), nextIndex becomes 3
|
||||
state.generateNextIndex()
|
||||
|
||||
expect(state.executionIndex.value).toBe(2)
|
||||
expect(state.nextIndex.value).toBe(3)
|
||||
|
||||
// Second call: executionIndex becomes 3, nextIndex becomes 4
|
||||
state.generateNextIndex()
|
||||
|
||||
expect(state.executionIndex.value).toBe(3)
|
||||
expect(state.nextIndex.value).toBe(4)
|
||||
})
|
||||
|
||||
it('should wrap index from totalCount to 1', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 5
|
||||
state.nextIndex.value = 5 // At the last index
|
||||
|
||||
state.generateNextIndex()
|
||||
|
||||
expect(state.executionIndex.value).toBe(5)
|
||||
expect(state.nextIndex.value).toBe(1) // Wrapped to 1
|
||||
})
|
||||
|
||||
it('should use currentIndex when nextIndex is null', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 5
|
||||
state.currentIndex.value = 3
|
||||
state.nextIndex.value = null
|
||||
|
||||
state.generateNextIndex()
|
||||
|
||||
// executionIndex becomes previous nextIndex (null)
|
||||
expect(state.executionIndex.value).toBeNull()
|
||||
// nextIndex is calculated from currentIndex (3) -> 4
|
||||
expect(state.nextIndex.value).toBe(4)
|
||||
})
|
||||
})
|
||||
|
||||
describe('initializeNextIndex', () => {
|
||||
it('should initialize nextIndex to currentIndex + 1 when null', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 5
|
||||
state.currentIndex.value = 1
|
||||
state.nextIndex.value = null
|
||||
|
||||
state.initializeNextIndex()
|
||||
|
||||
expect(state.nextIndex.value).toBe(2)
|
||||
})
|
||||
|
||||
it('should wrap nextIndex when currentIndex is at totalCount', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 5
|
||||
state.currentIndex.value = 5
|
||||
state.nextIndex.value = null
|
||||
|
||||
state.initializeNextIndex()
|
||||
|
||||
expect(state.nextIndex.value).toBe(1) // Wrapped
|
||||
})
|
||||
|
||||
it('should not change nextIndex if already set', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 5
|
||||
state.currentIndex.value = 1
|
||||
state.nextIndex.value = 4
|
||||
|
||||
state.initializeNextIndex()
|
||||
|
||||
expect(state.nextIndex.value).toBe(4) // Unchanged
|
||||
})
|
||||
})
|
||||
|
||||
describe('Index Wrapping Edge Cases', () => {
|
||||
it('should handle single item pool', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 1
|
||||
state.currentIndex.value = 1
|
||||
state.nextIndex.value = null
|
||||
|
||||
state.initializeNextIndex()
|
||||
|
||||
expect(state.nextIndex.value).toBe(1) // Wraps back to 1
|
||||
})
|
||||
|
||||
it('should handle zero total count gracefully', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.totalCount.value = 0
|
||||
state.currentIndex.value = 1
|
||||
state.nextIndex.value = null
|
||||
|
||||
state.initializeNextIndex()
|
||||
|
||||
// Should still calculate, even if totalCount is 0
|
||||
expect(state.nextIndex.value).toBe(2) // No wrapping since totalCount <= 0
|
||||
})
|
||||
})
|
||||
|
||||
describe('hashPoolConfig', () => {
|
||||
it('should generate consistent hash for same config', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const config1 = createMockPoolConfig()
|
||||
const config2 = createMockPoolConfig()
|
||||
|
||||
const hash1 = state.hashPoolConfig(config1)
|
||||
const hash2 = state.hashPoolConfig(config2)
|
||||
|
||||
expect(hash1).toBe(hash2)
|
||||
})
|
||||
|
||||
it('should generate different hash for different configs', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const config1 = createMockPoolConfig({
|
||||
filters: {
|
||||
baseModels: ['SD 1.5'],
|
||||
tags: { include: [], exclude: [] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: { noCreditRequired: false, allowSelling: false }
|
||||
}
|
||||
})
|
||||
|
||||
const config2 = createMockPoolConfig({
|
||||
filters: {
|
||||
baseModels: ['SDXL'],
|
||||
tags: { include: [], exclude: [] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: { noCreditRequired: false, allowSelling: false }
|
||||
}
|
||||
})
|
||||
|
||||
const hash1 = state.hashPoolConfig(config1)
|
||||
const hash2 = state.hashPoolConfig(config2)
|
||||
|
||||
expect(hash1).not.toBe(hash2)
|
||||
})
|
||||
|
||||
it('should return empty string for null config', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
expect(state.hashPoolConfig(null)).toBe('')
|
||||
})
|
||||
|
||||
it('should return empty string for config without filters', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const config = { version: 1, preview: { matchCount: 0, lastUpdated: 0 } } as any
|
||||
|
||||
expect(state.hashPoolConfig(config)).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Clip Strength Synchronization', () => {
|
||||
it('should sync clipStrength with modelStrength when useCustomClipRange is false', async () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.useCustomClipRange.value = false
|
||||
state.modelStrength.value = 0.5
|
||||
|
||||
// Wait for Vue reactivity
|
||||
await vi.waitFor(() => {
|
||||
expect(state.clipStrength.value).toBe(0.5)
|
||||
})
|
||||
})
|
||||
|
||||
it('should not sync clipStrength when useCustomClipRange is true', async () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.useCustomClipRange.value = true
|
||||
state.clipStrength.value = 0.7
|
||||
state.modelStrength.value = 0.5
|
||||
|
||||
// clipStrength should remain unchanged
|
||||
await vi.waitFor(() => {
|
||||
expect(state.clipStrength.value).toBe(0.7)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Widget Value Synchronization', () => {
|
||||
it('should update widget.value when state changes', async () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.currentIndex.value = 3
|
||||
state.repeatCount.value = 2
|
||||
|
||||
// Wait for Vue reactivity
|
||||
await vi.waitFor(() => {
|
||||
expect(widget.value?.current_index).toBe(3)
|
||||
expect(widget.value?.repeat_count).toBe(2)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Repeat Logic State', () => {
|
||||
it('should track repeatUsed correctly', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.repeatCount.value = 3
|
||||
expect(state.repeatUsed.value).toBe(0)
|
||||
|
||||
state.repeatUsed.value = 1
|
||||
expect(state.repeatUsed.value).toBe(1)
|
||||
|
||||
state.repeatUsed.value = 3
|
||||
expect(state.repeatUsed.value).toBe(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('fetchCyclerList', () => {
|
||||
it('should call API and return lora list', async () => {
|
||||
const mockLoras = [
|
||||
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' }
|
||||
]
|
||||
|
||||
setupFetchMock({ success: true, loras: mockLoras })
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const result = await state.fetchCyclerList(null)
|
||||
|
||||
expect(result).toEqual(mockLoras)
|
||||
expect(state.isLoading.value).toBe(false)
|
||||
})
|
||||
|
||||
it('should include pool config filters in request', async () => {
|
||||
const mockFetch = setupFetchMock({ success: true, loras: [] })
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const poolConfig = createMockPoolConfig()
|
||||
await state.fetchCyclerList(poolConfig)
|
||||
|
||||
expect(mockFetch).toHaveBeenCalledWith(
|
||||
'/api/lm/loras/cycler-list',
|
||||
expect.objectContaining({
|
||||
method: 'POST',
|
||||
body: expect.stringContaining('pool_config')
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should set isLoading during fetch', async () => {
|
||||
let resolvePromise: (value: unknown) => void
|
||||
const pendingPromise = new Promise(resolve => {
|
||||
resolvePromise = resolve
|
||||
})
|
||||
|
||||
// Use mockFetch from setup instead of overriding global
|
||||
const { mockFetch } = await import('../setup')
|
||||
mockFetch.mockReset()
|
||||
mockFetch.mockReturnValue(pendingPromise)
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
const fetchPromise = state.fetchCyclerList(null)
|
||||
|
||||
expect(state.isLoading.value).toBe(true)
|
||||
|
||||
// Resolve the fetch
|
||||
resolvePromise!({
|
||||
ok: true,
|
||||
json: () => Promise.resolve({ success: true, loras: [] })
|
||||
})
|
||||
|
||||
await fetchPromise
|
||||
|
||||
expect(state.isLoading.value).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('refreshList', () => {
|
||||
it('should update totalCount from API response', async () => {
|
||||
const mockLoras = [
|
||||
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' },
|
||||
{ file_name: 'lora3.safetensors', model_name: 'LoRA 3' }
|
||||
]
|
||||
|
||||
// Reset and setup fresh mock
|
||||
resetFetchMock()
|
||||
setupFetchMock({ success: true, loras: mockLoras })
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
await state.refreshList(null)
|
||||
|
||||
expect(state.totalCount.value).toBe(3)
|
||||
})
|
||||
|
||||
it('should reset index to 1 when pool config hash changes', async () => {
|
||||
resetFetchMock()
|
||||
setupFetchMock({ success: true, loras: [{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' }] })
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
// Set initial state
|
||||
state.currentIndex.value = 5
|
||||
state.poolConfigHash.value = 'old-hash'
|
||||
|
||||
// Refresh with new config (different hash)
|
||||
const newConfig = createMockPoolConfig({
|
||||
filters: {
|
||||
baseModels: ['SDXL'],
|
||||
tags: { include: [], exclude: [] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: { noCreditRequired: false, allowSelling: false }
|
||||
}
|
||||
})
|
||||
|
||||
await state.refreshList(newConfig)
|
||||
|
||||
expect(state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should clamp index when totalCount decreases', async () => {
|
||||
// Setup mock first, then create state
|
||||
resetFetchMock()
|
||||
setupFetchMock({
|
||||
success: true,
|
||||
loras: [
|
||||
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' },
|
||||
{ file_name: 'lora3.safetensors', model_name: 'LoRA 3' }
|
||||
]
|
||||
})
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
// Set initial state with high index
|
||||
state.currentIndex.value = 10
|
||||
state.totalCount.value = 10
|
||||
|
||||
await state.refreshList(null)
|
||||
|
||||
expect(state.totalCount.value).toBe(3)
|
||||
expect(state.currentIndex.value).toBe(3) // Clamped to max
|
||||
})
|
||||
|
||||
it('should update currentLoraName and currentLoraFilename', async () => {
|
||||
resetFetchMock()
|
||||
setupFetchMock({
|
||||
success: true,
|
||||
loras: [
|
||||
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
|
||||
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' }
|
||||
]
|
||||
})
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
// Set totalCount first so setIndex works, then set index
|
||||
state.totalCount.value = 2
|
||||
state.currentIndex.value = 2
|
||||
|
||||
await state.refreshList(null)
|
||||
|
||||
expect(state.currentLoraFilename.value).toBe('lora2.safetensors')
|
||||
})
|
||||
|
||||
it('should handle empty list gracefully', async () => {
|
||||
resetFetchMock()
|
||||
setupFetchMock({ success: true, loras: [] })
|
||||
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
state.currentIndex.value = 5
|
||||
state.totalCount.value = 5
|
||||
|
||||
await state.refreshList(null)
|
||||
|
||||
expect(state.totalCount.value).toBe(0)
|
||||
// When totalCount is 0, Math.max(1, 0) = 1, but if currentIndex > totalCount it gets clamped to max(1, totalCount)
|
||||
// Looking at the actual code: Math.max(1, totalCount) where totalCount=0 gives 1
|
||||
expect(state.currentIndex.value).toBe(1)
|
||||
expect(state.currentLoraName.value).toBe('')
|
||||
expect(state.currentLoraFilename.value).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('isClipStrengthDisabled computed', () => {
|
||||
it('should return true when useCustomClipRange is false', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.useCustomClipRange.value = false
|
||||
expect(state.isClipStrengthDisabled.value).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when useCustomClipRange is true', () => {
|
||||
const widget = createMockWidget()
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
state.useCustomClipRange.value = true
|
||||
expect(state.isClipStrengthDisabled.value).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
175
vue-widgets/tests/fixtures/mockConfigs.ts
vendored
Normal file
175
vue-widgets/tests/fixtures/mockConfigs.ts
vendored
Normal file
@@ -0,0 +1,175 @@
|
||||
/**
|
||||
* Test fixtures for LoRA Cycler testing
|
||||
*/
|
||||
|
||||
import type { CyclerConfig, LoraPoolConfig } from '@/composables/types'
|
||||
import type { CyclerLoraItem } from '@/composables/useLoraCyclerState'
|
||||
|
||||
/**
|
||||
* Creates a default CyclerConfig for testing
|
||||
*/
|
||||
export function createMockCyclerConfig(overrides: Partial<CyclerConfig> = {}): CyclerConfig {
|
||||
return {
|
||||
current_index: 1,
|
||||
total_count: 5,
|
||||
pool_config_hash: '',
|
||||
model_strength: 1.0,
|
||||
clip_strength: 1.0,
|
||||
use_same_clip_strength: true,
|
||||
sort_by: 'filename',
|
||||
current_lora_name: 'lora1.safetensors',
|
||||
current_lora_filename: 'lora1.safetensors',
|
||||
execution_index: null,
|
||||
next_index: null,
|
||||
repeat_count: 1,
|
||||
repeat_used: 0,
|
||||
is_paused: false,
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock LoraPoolConfig for testing
|
||||
*/
|
||||
export function createMockPoolConfig(overrides: Partial<LoraPoolConfig> = {}): LoraPoolConfig {
|
||||
return {
|
||||
version: 1,
|
||||
filters: {
|
||||
baseModels: ['SD 1.5'],
|
||||
tags: { include: [], exclude: [] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: {
|
||||
noCreditRequired: false,
|
||||
allowSelling: false
|
||||
}
|
||||
},
|
||||
preview: { matchCount: 10, lastUpdated: Date.now() },
|
||||
...overrides
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a list of mock LoRA items for testing
|
||||
*/
|
||||
export function createMockLoraList(count: number = 5): CyclerLoraItem[] {
|
||||
return Array.from({ length: count }, (_, i) => ({
|
||||
file_name: `lora${i + 1}.safetensors`,
|
||||
model_name: `LoRA Model ${i + 1}`
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock widget object for testing useLoraCyclerState
|
||||
*/
|
||||
export function createMockWidget(initialValue?: CyclerConfig) {
|
||||
return {
|
||||
value: initialValue,
|
||||
callback: undefined as ((v: CyclerConfig) => void) | undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock node object for testing component integration
|
||||
*/
|
||||
export function createMockNode(options: {
|
||||
id?: number
|
||||
poolConfig?: LoraPoolConfig | null
|
||||
} = {}) {
|
||||
const { id = 1, poolConfig = null } = options
|
||||
return {
|
||||
id,
|
||||
inputs: [],
|
||||
widgets: [],
|
||||
graph: null,
|
||||
getPoolConfig: () => poolConfig,
|
||||
onExecuted: undefined as ((output: unknown) => void) | undefined
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates mock execution output from the backend
|
||||
*/
|
||||
export function createMockExecutionOutput(options: {
|
||||
nextIndex?: number
|
||||
totalCount?: number
|
||||
nextLoraName?: string
|
||||
nextLoraFilename?: string
|
||||
currentLoraName?: string
|
||||
currentLoraFilename?: string
|
||||
} = {}) {
|
||||
const {
|
||||
nextIndex = 2,
|
||||
totalCount = 5,
|
||||
nextLoraName = 'lora2.safetensors',
|
||||
nextLoraFilename = 'lora2.safetensors',
|
||||
currentLoraName = 'lora1.safetensors',
|
||||
currentLoraFilename = 'lora1.safetensors'
|
||||
} = options
|
||||
|
||||
return {
|
||||
next_index: [nextIndex],
|
||||
total_count: [totalCount],
|
||||
next_lora_name: [nextLoraName],
|
||||
next_lora_filename: [nextLoraFilename],
|
||||
current_lora_name: [currentLoraName],
|
||||
current_lora_filename: [currentLoraFilename]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Sample LoRA lists for specific test scenarios
|
||||
*/
|
||||
export const SAMPLE_LORA_LISTS = {
|
||||
// 3 LoRAs for simple cycling tests
|
||||
small: createMockLoraList(3),
|
||||
|
||||
// 5 LoRAs for standard tests
|
||||
medium: createMockLoraList(5),
|
||||
|
||||
// 10 LoRAs for larger tests
|
||||
large: createMockLoraList(10),
|
||||
|
||||
// Empty list for edge case testing
|
||||
empty: [] as CyclerLoraItem[],
|
||||
|
||||
// Single LoRA for edge case testing
|
||||
single: createMockLoraList(1)
|
||||
}
|
||||
|
||||
/**
|
||||
* Sample pool configs for testing
|
||||
*/
|
||||
export const SAMPLE_POOL_CONFIGS = {
|
||||
// Default SD 1.5 filter
|
||||
sd15: createMockPoolConfig({
|
||||
filters: {
|
||||
baseModels: ['SD 1.5'],
|
||||
tags: { include: [], exclude: [] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: { noCreditRequired: false, allowSelling: false }
|
||||
}
|
||||
}),
|
||||
|
||||
// SDXL filter
|
||||
sdxl: createMockPoolConfig({
|
||||
filters: {
|
||||
baseModels: ['SDXL'],
|
||||
tags: { include: [], exclude: [] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: { noCreditRequired: false, allowSelling: false }
|
||||
}
|
||||
}),
|
||||
|
||||
// Filter with tags
|
||||
withTags: createMockPoolConfig({
|
||||
filters: {
|
||||
baseModels: ['SD 1.5'],
|
||||
tags: { include: ['anime', 'style'], exclude: ['realistic'] },
|
||||
folders: { include: [], exclude: [] },
|
||||
license: { noCreditRequired: false, allowSelling: false }
|
||||
}
|
||||
}),
|
||||
|
||||
// Empty/null config
|
||||
empty: null as LoraPoolConfig | null
|
||||
}
|
||||
885
vue-widgets/tests/integration/batchQueue.test.ts
Normal file
885
vue-widgets/tests/integration/batchQueue.test.ts
Normal file
@@ -0,0 +1,885 @@
|
||||
/**
|
||||
* Integration tests for batch queue execution scenarios
|
||||
*
|
||||
* These tests simulate ComfyUI's execution modes to verify correct LoRA cycling behavior.
|
||||
*/
|
||||
|
||||
import { describe, it, expect, beforeEach, vi } from 'vitest'
|
||||
import { useLoraCyclerState } from '@/composables/useLoraCyclerState'
|
||||
import type { CyclerConfig } from '@/composables/types'
|
||||
import {
|
||||
createMockWidget,
|
||||
createMockCyclerConfig,
|
||||
createMockLoraList,
|
||||
createMockPoolConfig
|
||||
} from '../fixtures/mockConfigs'
|
||||
import { setupFetchMock, resetFetchMock } from '../setup'
|
||||
import { BatchQueueSimulator, IndexTracker } from '../utils/BatchQueueSimulator'
|
||||
|
||||
/**
|
||||
* Creates a test harness that mimics the LoraCyclerWidget's behavior
|
||||
*/
|
||||
function createTestHarness(options: {
|
||||
totalCount?: number
|
||||
initialIndex?: number
|
||||
repeatCount?: number
|
||||
isPaused?: boolean
|
||||
} = {}) {
|
||||
const {
|
||||
totalCount = 5,
|
||||
initialIndex = 1,
|
||||
repeatCount = 1,
|
||||
isPaused = false
|
||||
} = options
|
||||
|
||||
const widget = createMockWidget() as any
|
||||
const state = useLoraCyclerState(widget)
|
||||
|
||||
// Initialize state
|
||||
state.totalCount.value = totalCount
|
||||
state.currentIndex.value = initialIndex
|
||||
state.repeatCount.value = repeatCount
|
||||
state.isPaused.value = isPaused
|
||||
|
||||
// Track if first execution
|
||||
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
|
||||
widget[HAS_EXECUTED] = false
|
||||
|
||||
// Execution queue for batch synchronization
|
||||
interface ExecutionContext {
|
||||
isPaused: boolean
|
||||
repeatUsed: number
|
||||
repeatCount: number
|
||||
shouldAdvanceDisplay: boolean
|
||||
displayRepeatUsed: number // Value to show in UI after completion
|
||||
}
|
||||
const executionQueue: ExecutionContext[] = []
|
||||
|
||||
// beforeQueued hook (mirrors LoraCyclerWidget.vue logic)
|
||||
widget.beforeQueued = () => {
|
||||
if (state.isPaused.value) {
|
||||
executionQueue.push({
|
||||
isPaused: true,
|
||||
repeatUsed: state.repeatUsed.value,
|
||||
repeatCount: state.repeatCount.value,
|
||||
shouldAdvanceDisplay: false,
|
||||
displayRepeatUsed: state.displayRepeatUsed.value // Keep current display value when paused
|
||||
})
|
||||
// CRITICAL: Clear execution_index when paused to force backend to use current_index
|
||||
const pausedConfig = state.buildConfig()
|
||||
pausedConfig.execution_index = null
|
||||
widget.value = pausedConfig
|
||||
return
|
||||
}
|
||||
|
||||
if (widget[HAS_EXECUTED]) {
|
||||
if (state.repeatUsed.value < state.repeatCount.value) {
|
||||
state.repeatUsed.value++
|
||||
} else {
|
||||
state.repeatUsed.value = 1
|
||||
state.generateNextIndex()
|
||||
}
|
||||
} else {
|
||||
state.repeatUsed.value = 1
|
||||
state.initializeNextIndex()
|
||||
widget[HAS_EXECUTED] = true
|
||||
}
|
||||
|
||||
const shouldAdvanceDisplay = state.repeatUsed.value >= state.repeatCount.value
|
||||
// Calculate the display value to show after this execution completes
|
||||
// When advancing to a new LoRA: reset to 0 (fresh start for new LoRA)
|
||||
// When repeating same LoRA: show current repeat step
|
||||
const displayRepeatUsed = shouldAdvanceDisplay ? 0 : state.repeatUsed.value
|
||||
|
||||
executionQueue.push({
|
||||
isPaused: false,
|
||||
repeatUsed: state.repeatUsed.value,
|
||||
repeatCount: state.repeatCount.value,
|
||||
shouldAdvanceDisplay,
|
||||
displayRepeatUsed
|
||||
})
|
||||
|
||||
widget.value = state.buildConfig()
|
||||
}
|
||||
|
||||
// Mock node with onExecuted
|
||||
const node = {
|
||||
id: 1,
|
||||
onExecuted: (output: any) => {
|
||||
const context = executionQueue.shift()
|
||||
|
||||
const shouldAdvanceDisplay = context
|
||||
? context.shouldAdvanceDisplay
|
||||
: (!state.isPaused.value && state.repeatUsed.value >= state.repeatCount.value)
|
||||
|
||||
// Update displayRepeatUsed (deferred like index updates)
|
||||
if (context) {
|
||||
state.displayRepeatUsed.value = context.displayRepeatUsed
|
||||
}
|
||||
|
||||
if (shouldAdvanceDisplay && output?.next_index !== undefined) {
|
||||
const val = Array.isArray(output.next_index) ? output.next_index[0] : output.next_index
|
||||
state.currentIndex.value = val
|
||||
}
|
||||
if (output?.total_count !== undefined) {
|
||||
const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
|
||||
state.totalCount.value = val
|
||||
}
|
||||
if (shouldAdvanceDisplay) {
|
||||
if (output?.next_lora_name !== undefined) {
|
||||
const val = Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name
|
||||
state.currentLoraName.value = val
|
||||
}
|
||||
if (output?.next_lora_filename !== undefined) {
|
||||
const val = Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename
|
||||
state.currentLoraFilename.value = val
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset execution state (mimics manual index change)
|
||||
const resetExecutionState = () => {
|
||||
widget[HAS_EXECUTED] = false
|
||||
state.executionIndex.value = null
|
||||
state.nextIndex.value = null
|
||||
executionQueue.length = 0
|
||||
}
|
||||
|
||||
return {
|
||||
widget,
|
||||
state,
|
||||
node,
|
||||
executionQueue,
|
||||
resetExecutionState,
|
||||
getConfig: () => state.buildConfig(),
|
||||
HAS_EXECUTED
|
||||
}
|
||||
}
|
||||
|
||||
describe('Batch Queue Integration Tests', () => {
|
||||
beforeEach(() => {
|
||||
resetFetchMock()
|
||||
})
|
||||
|
||||
describe('Basic Cycling', () => {
|
||||
it('should cycle through N LoRAs in batch of N (batch queue mode)', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
// Simulate batch queue of 3 prompts
|
||||
await simulator.runBatchQueue(
|
||||
3,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// After cycling through all 3, currentIndex should wrap back to 1
|
||||
// First execution: index 1, next becomes 2
|
||||
// Second execution: index 2, next becomes 3
|
||||
// Third execution: index 3, next becomes 1
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should cycle through N LoRAs in batch of N (sequential mode)', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
// Simulate sequential execution of 3 prompts
|
||||
await simulator.runSequential(
|
||||
3,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Same result as batch mode
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should handle partial cycle (batch of 2 in pool of 5)', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5, initialIndex: 1 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
await simulator.runBatchQueue(
|
||||
2,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// After 2 executions starting from 1: 1 -> 2 -> 3
|
||||
expect(harness.state.currentIndex.value).toBe(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Repeat Functionality', () => {
|
||||
it('should repeat each LoRA repeatCount times', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
// With repeatCount=2, need 6 executions to cycle through 3 LoRAs
|
||||
await simulator.runBatchQueue(
|
||||
6,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Should have cycled back to beginning
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should track repeatUsed correctly during batch', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, repeatCount: 3 })
|
||||
|
||||
// First beforeQueued: repeatUsed = 1
|
||||
harness.widget.beforeQueued()
|
||||
expect(harness.state.repeatUsed.value).toBe(1)
|
||||
|
||||
// Second beforeQueued: repeatUsed = 2
|
||||
harness.widget.beforeQueued()
|
||||
expect(harness.state.repeatUsed.value).toBe(2)
|
||||
|
||||
// Third beforeQueued: repeatUsed = 3 (will advance on next)
|
||||
harness.widget.beforeQueued()
|
||||
expect(harness.state.repeatUsed.value).toBe(3)
|
||||
|
||||
// Fourth beforeQueued: repeatUsed resets to 1, index advances
|
||||
harness.widget.beforeQueued()
|
||||
expect(harness.state.repeatUsed.value).toBe(1)
|
||||
expect(harness.state.nextIndex.value).toBe(3) // Advanced from 2 to 3
|
||||
})
|
||||
|
||||
it('should not advance display until repeat cycle completes', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5, repeatCount: 2 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// First execution: repeatUsed=1 < repeatCount=2, shouldAdvanceDisplay=false
|
||||
// Second execution: repeatUsed=2 >= repeatCount=2, shouldAdvanceDisplay=true
|
||||
|
||||
const indexHistory: number[] = []
|
||||
|
||||
// Override onExecuted to track index changes
|
||||
const originalOnExecuted = harness.node.onExecuted
|
||||
harness.node.onExecuted = (output: any) => {
|
||||
originalOnExecuted(output)
|
||||
indexHistory.push(harness.state.currentIndex.value)
|
||||
}
|
||||
|
||||
await simulator.runBatchQueue(
|
||||
4,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Index should only change on 2nd and 4th execution
|
||||
// Starting at 1: stay 1, advance to 2, stay 2, advance to 3
|
||||
expect(indexHistory).toEqual([1, 2, 2, 3])
|
||||
})
|
||||
|
||||
it('should defer displayRepeatUsed updates until workflow completion', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, repeatCount: 3 })
|
||||
|
||||
// Initial state
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(0)
|
||||
|
||||
// Queue 3 executions in batch mode (all beforeQueued before any onExecuted)
|
||||
harness.widget.beforeQueued() // repeatUsed = 1
|
||||
harness.widget.beforeQueued() // repeatUsed = 2
|
||||
harness.widget.beforeQueued() // repeatUsed = 3
|
||||
|
||||
// displayRepeatUsed should NOT have changed yet (still 0)
|
||||
// because no onExecuted has been called
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(0)
|
||||
|
||||
// Now simulate workflow completions
|
||||
harness.node.onExecuted({ next_index: 1 })
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||
|
||||
harness.node.onExecuted({ next_index: 1 })
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(2)
|
||||
|
||||
harness.node.onExecuted({ next_index: 2 })
|
||||
// After completing repeat cycle, displayRepeatUsed resets to 0
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(0)
|
||||
})
|
||||
|
||||
it('should reset displayRepeatUsed to 0 when advancing to new LoRA', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
const displayHistory: number[] = []
|
||||
|
||||
const originalOnExecuted = harness.node.onExecuted
|
||||
harness.node.onExecuted = (output: any) => {
|
||||
originalOnExecuted(output)
|
||||
displayHistory.push(harness.state.displayRepeatUsed.value)
|
||||
}
|
||||
|
||||
// Run 4 executions: 2 repeats of LoRA 1, 2 repeats of LoRA 2
|
||||
await simulator.runBatchQueue(
|
||||
4,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// displayRepeatUsed should show:
|
||||
// 1st exec: 1 (first repeat of LoRA 1)
|
||||
// 2nd exec: 0 (complete, reset for next LoRA)
|
||||
// 3rd exec: 1 (first repeat of LoRA 2)
|
||||
// 4th exec: 0 (complete, reset for next LoRA)
|
||||
expect(displayHistory).toEqual([1, 0, 1, 0])
|
||||
})
|
||||
|
||||
it('should show current repeat step when not advancing', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, repeatCount: 4 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
const displayHistory: number[] = []
|
||||
|
||||
const originalOnExecuted = harness.node.onExecuted
|
||||
harness.node.onExecuted = (output: any) => {
|
||||
originalOnExecuted(output)
|
||||
displayHistory.push(harness.state.displayRepeatUsed.value)
|
||||
}
|
||||
|
||||
// Run 4 executions: all 4 repeats of the same LoRA
|
||||
await simulator.runBatchQueue(
|
||||
4,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// displayRepeatUsed should show:
|
||||
// 1st exec: 1 (repeat 1/4, not advancing)
|
||||
// 2nd exec: 2 (repeat 2/4, not advancing)
|
||||
// 3rd exec: 3 (repeat 3/4, not advancing)
|
||||
// 4th exec: 0 (repeat 4/4, complete, reset for next LoRA)
|
||||
expect(displayHistory).toEqual([1, 2, 3, 0])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Pause Functionality', () => {
|
||||
it('should maintain index when paused', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5, isPaused: true })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
await simulator.runBatchQueue(
|
||||
3,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Index should not advance when paused
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should not count paused executions toward repeat limit', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5, repeatCount: 2 })
|
||||
|
||||
// Run 2 executions while paused
|
||||
harness.state.isPaused.value = true
|
||||
harness.widget.beforeQueued()
|
||||
harness.widget.beforeQueued()
|
||||
|
||||
// repeatUsed should still be 0 (paused executions don't count)
|
||||
expect(harness.state.repeatUsed.value).toBe(0)
|
||||
|
||||
// Unpause and run
|
||||
harness.state.isPaused.value = false
|
||||
harness.widget.beforeQueued()
|
||||
expect(harness.state.repeatUsed.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should preserve displayRepeatUsed when paused', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5, repeatCount: 3 })
|
||||
|
||||
// Run one execution to set displayRepeatUsed
|
||||
harness.widget.beforeQueued()
|
||||
harness.node.onExecuted({ next_index: 1 })
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||
|
||||
// Pause
|
||||
harness.state.isPaused.value = true
|
||||
|
||||
// Queue and execute while paused
|
||||
harness.widget.beforeQueued()
|
||||
harness.node.onExecuted({ next_index: 1 })
|
||||
|
||||
// displayRepeatUsed should remain at 1 (paused executions don't change it)
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||
|
||||
// Queue another paused execution
|
||||
harness.widget.beforeQueued()
|
||||
harness.node.onExecuted({ next_index: 1 })
|
||||
|
||||
// Still should be 1
|
||||
expect(harness.state.displayRepeatUsed.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should use same LoRA when pause is toggled mid-batch', async () => {
|
||||
// This tests the critical bug scenario:
|
||||
// 1. User queues multiple prompts (not paused)
|
||||
// 2. All beforeQueued calls complete, each advancing execution_index
|
||||
// 3. User clicks pause
|
||||
// 4. onExecuted starts firing - paused executions should use current_index, not execution_index
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
|
||||
// Queue first prompt (not paused) - this sets up execution_index
|
||||
harness.widget.beforeQueued()
|
||||
const config1 = harness.getConfig()
|
||||
expect(config1.execution_index).toBeNull() // First execution uses current_index
|
||||
|
||||
// User clicks pause mid-batch
|
||||
harness.state.isPaused.value = true
|
||||
|
||||
// Queue subsequent prompts while paused
|
||||
harness.widget.beforeQueued()
|
||||
const config2 = harness.getConfig()
|
||||
// CRITICAL: execution_index should be null when paused to force backend to use current_index
|
||||
expect(config2.execution_index).toBeNull()
|
||||
|
||||
harness.widget.beforeQueued()
|
||||
const config3 = harness.getConfig()
|
||||
expect(config3.execution_index).toBeNull()
|
||||
|
||||
// Verify execution queue has correct context
|
||||
expect(harness.executionQueue.length).toBe(3)
|
||||
expect(harness.executionQueue[0].isPaused).toBe(false)
|
||||
expect(harness.executionQueue[1].isPaused).toBe(true)
|
||||
expect(harness.executionQueue[2].isPaused).toBe(true)
|
||||
})
|
||||
|
||||
it('should have null execution_index in widget.value when paused even after non-paused queues', async () => {
|
||||
// More detailed test for the execution_index clearing behavior
|
||||
// This tests that widget.value (what backend receives) has null execution_index
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
|
||||
// Queue 3 prompts while not paused
|
||||
harness.widget.beforeQueued()
|
||||
harness.widget.beforeQueued()
|
||||
harness.widget.beforeQueued()
|
||||
|
||||
// Verify execution_index was set by non-paused queues in widget.value
|
||||
expect(harness.widget.value.execution_index).not.toBeNull()
|
||||
|
||||
// User pauses
|
||||
harness.state.isPaused.value = true
|
||||
|
||||
// Queue while paused - should clear execution_index in widget.value
|
||||
// This is the value that gets sent to the backend
|
||||
harness.widget.beforeQueued()
|
||||
expect(harness.widget.value.execution_index).toBeNull()
|
||||
|
||||
// State's executionIndex may still have the old value (that's fine)
|
||||
// What matters is widget.value which is what the backend uses
|
||||
})
|
||||
|
||||
it('should have hasQueuedPrompts true when execution queue has items', async () => {
|
||||
// This tests the pause button disabled state
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
|
||||
// Initially no queued prompts
|
||||
expect(harness.executionQueue.length).toBe(0)
|
||||
|
||||
// Queue some prompts
|
||||
harness.widget.beforeQueued()
|
||||
harness.widget.beforeQueued()
|
||||
harness.widget.beforeQueued()
|
||||
|
||||
// Execution queue should have items
|
||||
expect(harness.executionQueue.length).toBe(3)
|
||||
})
|
||||
|
||||
it('should have empty execution queue after all executions complete', async () => {
|
||||
// This tests that pause button becomes enabled after executions complete
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// Run batch queue execution
|
||||
await simulator.runBatchQueue(
|
||||
3,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// After all executions, queue should be empty
|
||||
expect(harness.executionQueue.length).toBe(0)
|
||||
})
|
||||
|
||||
it('should resume cycling after unpause', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, initialIndex: 2 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
// Execute once while not paused
|
||||
await simulator.runSingle(
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Pause
|
||||
harness.state.isPaused.value = true
|
||||
|
||||
// Execute twice while paused
|
||||
await simulator.runBatchQueue(
|
||||
2,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Unpause and execute
|
||||
harness.state.isPaused.value = false
|
||||
|
||||
await simulator.runSingle(
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Should continue from where it left off (index 3 -> 1)
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Manual Index Change', () => {
|
||||
it('should reset execution state on manual index change', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
|
||||
// Execute a few times
|
||||
harness.widget.beforeQueued()
|
||||
harness.widget.beforeQueued()
|
||||
|
||||
expect(harness.widget[harness.HAS_EXECUTED]).toBe(true)
|
||||
expect(harness.executionQueue.length).toBe(2)
|
||||
|
||||
// User manually changes index (mimics handleIndexUpdate)
|
||||
harness.resetExecutionState()
|
||||
harness.state.setIndex(4)
|
||||
|
||||
expect(harness.widget[harness.HAS_EXECUTED]).toBe(false)
|
||||
expect(harness.state.executionIndex.value).toBeNull()
|
||||
expect(harness.state.nextIndex.value).toBeNull()
|
||||
expect(harness.executionQueue.length).toBe(0)
|
||||
})
|
||||
|
||||
it('should start fresh cycle from manual index', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// Execute twice starting from 1
|
||||
await simulator.runBatchQueue(
|
||||
2,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
expect(harness.state.currentIndex.value).toBe(3)
|
||||
|
||||
// User manually sets index to 1
|
||||
harness.resetExecutionState()
|
||||
harness.state.setIndex(1)
|
||||
|
||||
// Execute again - should start fresh from 1
|
||||
await simulator.runBatchQueue(
|
||||
2,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
expect(harness.state.currentIndex.value).toBe(3)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Execution Queue Mismatch', () => {
|
||||
it('should handle interrupted execution (queue > executed)', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// Queue 5 but only execute 2 (simulates cancel)
|
||||
await simulator.runInterrupted(
|
||||
5, // queued
|
||||
2, // executed
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// 3 contexts remain in queue
|
||||
expect(harness.executionQueue.length).toBe(3)
|
||||
|
||||
// Index should reflect only the 2 executions that completed
|
||||
expect(harness.state.currentIndex.value).toBe(3)
|
||||
})
|
||||
|
||||
it('should recover from mismatch on next manual index change', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// Create mismatch
|
||||
await simulator.runInterrupted(
|
||||
5,
|
||||
2,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
expect(harness.executionQueue.length).toBe(3)
|
||||
|
||||
// Manual index change clears queue
|
||||
harness.resetExecutionState()
|
||||
harness.state.setIndex(1)
|
||||
|
||||
expect(harness.executionQueue.length).toBe(0)
|
||||
|
||||
// Can execute normally again
|
||||
await simulator.runSingle(
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
expect(harness.state.currentIndex.value).toBe(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle single item pool', async () => {
|
||||
const harness = createTestHarness({ totalCount: 1 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 1 })
|
||||
|
||||
await simulator.runBatchQueue(
|
||||
3,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// Should always stay at index 1
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should handle empty pool gracefully', async () => {
|
||||
const harness = createTestHarness({ totalCount: 0 })
|
||||
|
||||
// beforeQueued should still work without errors
|
||||
expect(() => harness.widget.beforeQueued()).not.toThrow()
|
||||
})
|
||||
|
||||
it('should handle rapid sequential executions', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// Run 20 sequential executions
|
||||
await simulator.runSequential(
|
||||
20,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
// 20 % 5 = 0, so should wrap back to 1
|
||||
// But first execution uses index 1, so after 20 executions we're at 21 % 5 = 1
|
||||
expect(harness.state.currentIndex.value).toBe(1)
|
||||
})
|
||||
|
||||
it('should preserve state consistency across many cycles', async () => {
|
||||
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
// Run 100 executions in batches
|
||||
for (let batch = 0; batch < 10; batch++) {
|
||||
await simulator.runBatchQueue(
|
||||
10,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
}
|
||||
|
||||
// Verify state is still valid
|
||||
expect(harness.state.currentIndex.value).toBeGreaterThanOrEqual(1)
|
||||
expect(harness.state.currentIndex.value).toBeLessThanOrEqual(3)
|
||||
expect(harness.state.repeatUsed.value).toBeGreaterThanOrEqual(1)
|
||||
expect(harness.state.repeatUsed.value).toBeLessThanOrEqual(2)
|
||||
expect(harness.executionQueue.length).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Invariant Assertions', () => {
|
||||
it('should always have valid index (1 <= currentIndex <= totalCount)', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
const checkInvariant = () => {
|
||||
const { currentIndex, totalCount } = harness.state
|
||||
if (totalCount.value > 0) {
|
||||
expect(currentIndex.value).toBeGreaterThanOrEqual(1)
|
||||
expect(currentIndex.value).toBeLessThanOrEqual(totalCount.value)
|
||||
}
|
||||
}
|
||||
|
||||
// Override onExecuted to check invariant after each execution
|
||||
const originalOnExecuted = harness.node.onExecuted
|
||||
harness.node.onExecuted = (output: any) => {
|
||||
originalOnExecuted(output)
|
||||
checkInvariant()
|
||||
}
|
||||
|
||||
await simulator.runBatchQueue(
|
||||
20,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
})
|
||||
|
||||
it('should always have repeatUsed <= repeatCount', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5, repeatCount: 3 })
|
||||
|
||||
const checkInvariant = () => {
|
||||
expect(harness.state.repeatUsed.value).toBeLessThanOrEqual(harness.state.repeatCount.value)
|
||||
}
|
||||
|
||||
// Check after each beforeQueued
|
||||
for (let i = 0; i < 20; i++) {
|
||||
harness.widget.beforeQueued()
|
||||
checkInvariant()
|
||||
}
|
||||
})
|
||||
|
||||
it('should consume all execution contexts (queue empty after matching executions)', async () => {
|
||||
const harness = createTestHarness({ totalCount: 5 })
|
||||
const simulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
await simulator.runBatchQueue(
|
||||
7,
|
||||
{
|
||||
beforeQueued: () => harness.widget.beforeQueued(),
|
||||
onExecuted: (output) => harness.node.onExecuted(output)
|
||||
},
|
||||
() => harness.getConfig()
|
||||
)
|
||||
|
||||
expect(harness.executionQueue.length).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Batch vs Sequential Mode Equivalence', () => {
|
||||
it('should produce same final state in both modes (basic cycle)', async () => {
|
||||
// Create two identical harnesses
|
||||
const batchHarness = createTestHarness({ totalCount: 5 })
|
||||
const seqHarness = createTestHarness({ totalCount: 5 })
|
||||
|
||||
const batchSimulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
const seqSimulator = new BatchQueueSimulator({ totalCount: 5 })
|
||||
|
||||
// Run same number of executions in different modes
|
||||
await batchSimulator.runBatchQueue(
|
||||
7,
|
||||
{
|
||||
beforeQueued: () => batchHarness.widget.beforeQueued(),
|
||||
onExecuted: (output) => batchHarness.node.onExecuted(output)
|
||||
},
|
||||
() => batchHarness.getConfig()
|
||||
)
|
||||
|
||||
await seqSimulator.runSequential(
|
||||
7,
|
||||
{
|
||||
beforeQueued: () => seqHarness.widget.beforeQueued(),
|
||||
onExecuted: (output) => seqHarness.node.onExecuted(output)
|
||||
},
|
||||
() => seqHarness.getConfig()
|
||||
)
|
||||
|
||||
// Final state should be identical
|
||||
expect(batchHarness.state.currentIndex.value).toBe(seqHarness.state.currentIndex.value)
|
||||
expect(batchHarness.state.repeatUsed.value).toBe(seqHarness.state.repeatUsed.value)
|
||||
expect(batchHarness.state.displayRepeatUsed.value).toBe(seqHarness.state.displayRepeatUsed.value)
|
||||
})
|
||||
|
||||
it('should produce same final state in both modes (with repeat)', async () => {
|
||||
const batchHarness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||
const seqHarness = createTestHarness({ totalCount: 3, repeatCount: 2 })
|
||||
|
||||
const batchSimulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
const seqSimulator = new BatchQueueSimulator({ totalCount: 3 })
|
||||
|
||||
await batchSimulator.runBatchQueue(
|
||||
10,
|
||||
{
|
||||
beforeQueued: () => batchHarness.widget.beforeQueued(),
|
||||
onExecuted: (output) => batchHarness.node.onExecuted(output)
|
||||
},
|
||||
() => batchHarness.getConfig()
|
||||
)
|
||||
|
||||
await seqSimulator.runSequential(
|
||||
10,
|
||||
{
|
||||
beforeQueued: () => seqHarness.widget.beforeQueued(),
|
||||
onExecuted: (output) => seqHarness.node.onExecuted(output)
|
||||
},
|
||||
() => seqHarness.getConfig()
|
||||
)
|
||||
|
||||
expect(batchHarness.state.currentIndex.value).toBe(seqHarness.state.currentIndex.value)
|
||||
expect(batchHarness.state.repeatUsed.value).toBe(seqHarness.state.repeatUsed.value)
|
||||
expect(batchHarness.state.displayRepeatUsed.value).toBe(seqHarness.state.displayRepeatUsed.value)
|
||||
})
|
||||
})
|
||||
})
|
||||
75
vue-widgets/tests/setup.ts
Normal file
75
vue-widgets/tests/setup.ts
Normal file
@@ -0,0 +1,75 @@
|
||||
/**
|
||||
* Vitest test setup file
|
||||
* Configures global mocks for ComfyUI modules and browser APIs
|
||||
*/
|
||||
|
||||
import { vi } from 'vitest'
|
||||
|
||||
// Mock ComfyUI app module
|
||||
vi.mock('../../../scripts/app.js', () => ({
|
||||
app: {
|
||||
graph: {
|
||||
_nodes: []
|
||||
},
|
||||
registerExtension: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock ComfyUI loras_widget module
|
||||
vi.mock('../loras_widget.js', () => ({
|
||||
addLoraCard: vi.fn(),
|
||||
removeLoraCard: vi.fn()
|
||||
}))
|
||||
|
||||
// Mock ComfyUI autocomplete module
|
||||
vi.mock('../autocomplete.js', () => ({
|
||||
setupAutocomplete: vi.fn()
|
||||
}))
|
||||
|
||||
// Global fetch mock - exported so tests can access it directly
|
||||
export const mockFetch = vi.fn()
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
|
||||
// Helper to reset fetch mock between tests
|
||||
export function resetFetchMock() {
|
||||
mockFetch.mockReset()
|
||||
// Re-stub global to ensure it's the same mock
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
}
|
||||
|
||||
// Helper to setup fetch mock with default success response
|
||||
export function setupFetchMock(response: unknown = { success: true, loras: [] }) {
|
||||
// Ensure we're using the same mock
|
||||
mockFetch.mockReset()
|
||||
mockFetch.mockResolvedValue({
|
||||
ok: true,
|
||||
json: () => Promise.resolve(response)
|
||||
})
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
return mockFetch
|
||||
}
|
||||
|
||||
// Helper to setup fetch mock with error response
|
||||
export function setupFetchErrorMock(error: string = 'Network error') {
|
||||
mockFetch.mockReset()
|
||||
mockFetch.mockRejectedValue(new Error(error))
|
||||
vi.stubGlobal('fetch', mockFetch)
|
||||
return mockFetch
|
||||
}
|
||||
|
||||
// Mock btoa for hashing (jsdom should have this, but just in case)
|
||||
if (typeof global.btoa === 'undefined') {
|
||||
vi.stubGlobal('btoa', (str: string) => Buffer.from(str).toString('base64'))
|
||||
}
|
||||
|
||||
// Mock console methods to reduce noise in tests
|
||||
vi.spyOn(console, 'log').mockImplementation(() => {})
|
||||
vi.spyOn(console, 'error').mockImplementation(() => {})
|
||||
vi.spyOn(console, 'warn').mockImplementation(() => {})
|
||||
|
||||
// Re-enable console for debugging when needed
|
||||
export function enableConsole() {
|
||||
vi.spyOn(console, 'log').mockRestore()
|
||||
vi.spyOn(console, 'error').mockRestore()
|
||||
vi.spyOn(console, 'warn').mockRestore()
|
||||
}
|
||||
230
vue-widgets/tests/utils/BatchQueueSimulator.ts
Normal file
230
vue-widgets/tests/utils/BatchQueueSimulator.ts
Normal file
@@ -0,0 +1,230 @@
|
||||
/**
|
||||
* BatchQueueSimulator - Simulates ComfyUI's two execution modes
|
||||
*
|
||||
* ComfyUI has two distinct execution patterns:
|
||||
* 1. Batch Queue Mode: ALL beforeQueued calls happen BEFORE any onExecuted calls
|
||||
* 2. Sequential Mode: beforeQueued and onExecuted interleave for each prompt
|
||||
*
|
||||
* This simulator helps test how the widget behaves in both modes.
|
||||
*/
|
||||
|
||||
import type { CyclerConfig } from '@/composables/types'
|
||||
|
||||
export interface ExecutionHooks {
|
||||
/** Called when a prompt is queued (before execution) */
|
||||
beforeQueued: () => void
|
||||
/** Called when execution completes with output */
|
||||
onExecuted: (output: unknown) => void
|
||||
}
|
||||
|
||||
export interface SimulatorOptions {
|
||||
/** Total number of LoRAs in the pool */
|
||||
totalCount: number
|
||||
/** Function to generate output for each execution */
|
||||
generateOutput?: (executionIndex: number, config: CyclerConfig) => unknown
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates execution output based on the current state
|
||||
*/
|
||||
function defaultGenerateOutput(executionIndex: number, config: CyclerConfig) {
|
||||
// Calculate what the next index would be after this execution
|
||||
let nextIdx = (config.execution_index ?? config.current_index) + 1
|
||||
if (nextIdx > config.total_count) {
|
||||
nextIdx = 1
|
||||
}
|
||||
|
||||
return {
|
||||
next_index: [nextIdx],
|
||||
total_count: [config.total_count],
|
||||
next_lora_name: [`lora${nextIdx}.safetensors`],
|
||||
next_lora_filename: [`lora${nextIdx}.safetensors`],
|
||||
current_lora_name: [`lora${config.execution_index ?? config.current_index}.safetensors`],
|
||||
current_lora_filename: [`lora${config.execution_index ?? config.current_index}.safetensors`]
|
||||
}
|
||||
}
|
||||
|
||||
export class BatchQueueSimulator {
|
||||
private executionCount = 0
|
||||
private options: Required<SimulatorOptions>
|
||||
|
||||
constructor(options: SimulatorOptions) {
|
||||
this.options = {
|
||||
totalCount: options.totalCount,
|
||||
generateOutput: options.generateOutput ?? defaultGenerateOutput
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Reset the simulator state
|
||||
*/
|
||||
reset() {
|
||||
this.executionCount = 0
|
||||
}
|
||||
|
||||
/**
|
||||
* Simulates Batch Queue Mode execution
|
||||
*
|
||||
* In this mode, ComfyUI queues multiple prompts at once:
|
||||
* - ALL beforeQueued() calls happen first (for all prompts in the batch)
|
||||
* - THEN all onExecuted() calls happen (as each prompt completes)
|
||||
*
|
||||
* This is the mode used when queueing multiple prompts from the UI.
|
||||
*
|
||||
* @param count Number of prompts to simulate
|
||||
* @param hooks The widget's execution hooks
|
||||
* @param getConfig Function to get current widget config state
|
||||
*/
|
||||
async runBatchQueue(
|
||||
count: number,
|
||||
hooks: ExecutionHooks,
|
||||
getConfig: () => CyclerConfig
|
||||
): Promise<void> {
|
||||
// Phase 1: All beforeQueued calls (snapshot configs)
|
||||
const snapshotConfigs: CyclerConfig[] = []
|
||||
|
||||
for (let i = 0; i < count; i++) {
|
||||
hooks.beforeQueued()
|
||||
// Snapshot the config after beforeQueued updates it
|
||||
snapshotConfigs.push({ ...getConfig() })
|
||||
}
|
||||
|
||||
// Phase 2: All onExecuted calls (in order)
|
||||
for (let i = 0; i < count; i++) {
|
||||
const config = snapshotConfigs[i]
|
||||
const output = this.options.generateOutput(this.executionCount, config)
|
||||
hooks.onExecuted(output)
|
||||
this.executionCount++
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simulates Sequential Mode execution
|
||||
*
|
||||
* In this mode, execution is one-at-a-time:
|
||||
* - beforeQueued() is called
|
||||
* - onExecuted() is called
|
||||
* - Then the next prompt's beforeQueued() is called
|
||||
* - And so on...
|
||||
*
|
||||
* This is the mode used in API-driven execution or single prompt queuing.
|
||||
*
|
||||
* @param count Number of prompts to simulate
|
||||
* @param hooks The widget's execution hooks
|
||||
* @param getConfig Function to get current widget config state
|
||||
*/
|
||||
async runSequential(
|
||||
count: number,
|
||||
hooks: ExecutionHooks,
|
||||
getConfig: () => CyclerConfig
|
||||
): Promise<void> {
|
||||
for (let i = 0; i < count; i++) {
|
||||
// Queue the prompt
|
||||
hooks.beforeQueued()
|
||||
const config = { ...getConfig() }
|
||||
|
||||
// Execute it immediately
|
||||
const output = this.options.generateOutput(this.executionCount, config)
|
||||
hooks.onExecuted(output)
|
||||
this.executionCount++
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Simulates a single execution (queue + execute)
|
||||
*/
|
||||
async runSingle(
|
||||
hooks: ExecutionHooks,
|
||||
getConfig: () => CyclerConfig
|
||||
): Promise<void> {
|
||||
return this.runSequential(1, hooks, getConfig)
|
||||
}
|
||||
|
||||
/**
|
||||
* Simulates interrupted execution (some beforeQueued calls without matching onExecuted)
|
||||
*
|
||||
* This can happen if the user cancels execution mid-batch.
|
||||
*
|
||||
* @param queuedCount Number of prompts queued (beforeQueued called)
|
||||
* @param executedCount Number of prompts that actually executed
|
||||
* @param hooks The widget's execution hooks
|
||||
* @param getConfig Function to get current widget config state
|
||||
*/
|
||||
async runInterrupted(
|
||||
queuedCount: number,
|
||||
executedCount: number,
|
||||
hooks: ExecutionHooks,
|
||||
getConfig: () => CyclerConfig
|
||||
): Promise<void> {
|
||||
if (executedCount > queuedCount) {
|
||||
throw new Error('executedCount cannot be greater than queuedCount')
|
||||
}
|
||||
|
||||
// Phase 1: All beforeQueued calls
|
||||
const snapshotConfigs: CyclerConfig[] = []
|
||||
for (let i = 0; i < queuedCount; i++) {
|
||||
hooks.beforeQueued()
|
||||
snapshotConfigs.push({ ...getConfig() })
|
||||
}
|
||||
|
||||
// Phase 2: Only some onExecuted calls
|
||||
for (let i = 0; i < executedCount; i++) {
|
||||
const config = snapshotConfigs[i]
|
||||
const output = this.options.generateOutput(this.executionCount, config)
|
||||
hooks.onExecuted(output)
|
||||
this.executionCount++
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper to create execution hooks from a widget-like object
|
||||
*/
|
||||
export function createHooksFromWidget(widget: {
|
||||
beforeQueued?: () => void
|
||||
}, node: {
|
||||
onExecuted?: (output: unknown) => void
|
||||
}): ExecutionHooks {
|
||||
return {
|
||||
beforeQueued: () => widget.beforeQueued?.(),
|
||||
onExecuted: (output) => node.onExecuted?.(output)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracks index history during simulation for assertions
|
||||
*/
|
||||
export class IndexTracker {
|
||||
public indexHistory: number[] = []
|
||||
public repeatHistory: number[] = []
|
||||
public pauseHistory: boolean[] = []
|
||||
|
||||
reset() {
|
||||
this.indexHistory = []
|
||||
this.repeatHistory = []
|
||||
this.pauseHistory = []
|
||||
}
|
||||
|
||||
record(config: CyclerConfig) {
|
||||
this.indexHistory.push(config.current_index)
|
||||
this.repeatHistory.push(config.repeat_used)
|
||||
this.pauseHistory.push(config.is_paused)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the sequence of indices that were actually used for execution
|
||||
*/
|
||||
getExecutionIndices(): number[] {
|
||||
return this.indexHistory
|
||||
}
|
||||
|
||||
/**
|
||||
* Verify that indices cycle correctly through totalCount
|
||||
*/
|
||||
verifyCyclePattern(expectedPattern: number[]): boolean {
|
||||
if (this.indexHistory.length !== expectedPattern.length) {
|
||||
return false
|
||||
}
|
||||
return this.indexHistory.every((idx, i) => idx === expectedPattern[i])
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,6 @@
|
||||
"@/*": ["./src/*"]
|
||||
}
|
||||
},
|
||||
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"],
|
||||
"include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue", "tests/**/*.ts"],
|
||||
"references": [{ "path": "./tsconfig.node.json" }]
|
||||
}
|
||||
|
||||
@@ -22,8 +22,10 @@ export default defineConfig({
|
||||
rollupOptions: {
|
||||
external: [
|
||||
'../../../scripts/app.js',
|
||||
'../../../scripts/api.js',
|
||||
'../loras_widget.js',
|
||||
'../autocomplete.js'
|
||||
'../autocomplete.js',
|
||||
'../preview_tooltip.js'
|
||||
],
|
||||
output: {
|
||||
dir: '../web/comfyui/vue-widgets',
|
||||
|
||||
25
vue-widgets/vitest.config.ts
Normal file
25
vue-widgets/vitest.config.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
import vue from '@vitejs/plugin-vue'
|
||||
import { resolve } from 'path'
|
||||
|
||||
export default defineConfig({
|
||||
plugins: [vue()],
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': resolve(__dirname, './src')
|
||||
}
|
||||
},
|
||||
test: {
|
||||
globals: true,
|
||||
environment: 'jsdom',
|
||||
setupFiles: ['./tests/setup.ts'],
|
||||
include: ['tests/**/*.test.ts'],
|
||||
coverage: {
|
||||
provider: 'v8',
|
||||
reporter: ['text', 'html', 'json'],
|
||||
reportsDirectory: './coverage',
|
||||
include: ['src/**/*.ts', 'src/**/*.vue'],
|
||||
exclude: ['src/main.ts', 'src/vite-env.d.ts']
|
||||
}
|
||||
}
|
||||
})
|
||||
@@ -2,6 +2,7 @@ import { api } from "../../scripts/api.js";
|
||||
import { app } from "../../scripts/app.js";
|
||||
import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
|
||||
import { getPromptTagAutocompletePreference, getTagSpaceReplacementPreference } from "./settings.js";
|
||||
import { showToast } from "./utils.js";
|
||||
|
||||
// Command definitions for category filtering
|
||||
const TAG_COMMANDS = {
|
||||
@@ -15,6 +16,21 @@ const TAG_COMMANDS = {
|
||||
'/lore': { categories: [15], label: 'Lore' },
|
||||
'/emb': { type: 'embedding', label: 'Embeddings' },
|
||||
'/embedding': { type: 'embedding', label: 'Embeddings' },
|
||||
// Autocomplete toggle commands - only show one based on current state
|
||||
'/ac': {
|
||||
type: 'toggle_setting',
|
||||
settingId: 'loramanager.prompt_tag_autocomplete',
|
||||
value: true,
|
||||
label: 'Autocomplete: ON',
|
||||
condition: () => !getPromptTagAutocompletePreference()
|
||||
},
|
||||
'/noac': {
|
||||
type: 'toggle_setting',
|
||||
settingId: 'loramanager.prompt_tag_autocomplete',
|
||||
value: false,
|
||||
label: 'Autocomplete: OFF',
|
||||
condition: () => getPromptTagAutocompletePreference()
|
||||
},
|
||||
};
|
||||
|
||||
// Category display information
|
||||
@@ -488,6 +504,10 @@ class AutoComplete {
|
||||
this.searchType = 'commands';
|
||||
this._showCommandList(commandResult.commandFilter);
|
||||
return;
|
||||
} else if (commandResult.command?.type === 'toggle_setting') {
|
||||
// Handle toggle setting command (/ac, /noac)
|
||||
this._handleToggleSettingCommand(commandResult.command);
|
||||
return;
|
||||
} else if (commandResult.command) {
|
||||
// Command is active, use filtered search
|
||||
this.showingCommands = false;
|
||||
@@ -509,7 +529,10 @@ class AutoComplete {
|
||||
this.showingCommands = false;
|
||||
this.activeCommand = null;
|
||||
endpoint = '/lm/custom-words/search?enriched=true';
|
||||
searchTerm = rawSearchTerm;
|
||||
// Extract last space-separated token for search
|
||||
// Tag names don't contain spaces, so we only need the last token
|
||||
// This allows "hello 1gi" to search for "1gi" and find "1girl"
|
||||
searchTerm = this._getLastSpaceToken(rawSearchTerm);
|
||||
this.searchType = 'custom_words';
|
||||
} else {
|
||||
// No command and setting disabled - no autocomplete for direct typing
|
||||
@@ -545,6 +568,17 @@ class AutoComplete {
|
||||
return lastSegment.trim();
|
||||
}
|
||||
|
||||
/**
|
||||
* Extract the last space-separated token from a search term
|
||||
* Tag names don't contain spaces, so for tag autocomplete we only need the last token
|
||||
* @param {string} term - The full search term (e.g., "hello 1gi")
|
||||
* @returns {string} - The last token (e.g., "1gi"), or the original term if no spaces
|
||||
*/
|
||||
_getLastSpaceToken(term) {
|
||||
const tokens = term.trim().split(/\s+/);
|
||||
return tokens[tokens.length - 1] || term;
|
||||
}
|
||||
|
||||
async search(term = '', endpoint = null) {
|
||||
try {
|
||||
this.currentSearchTerm = term;
|
||||
@@ -606,9 +640,14 @@ class AutoComplete {
|
||||
|
||||
// Check for exact command match
|
||||
if (TAG_COMMANDS[partialCommand]) {
|
||||
const cmd = TAG_COMMANDS[partialCommand];
|
||||
// Filter out toggle commands that don't meet their condition
|
||||
if (cmd.type === 'toggle_setting' && cmd.condition && !cmd.condition()) {
|
||||
return { showCommands: false, command: null, searchTerm: '' };
|
||||
}
|
||||
return {
|
||||
showCommands: false,
|
||||
command: TAG_COMMANDS[partialCommand],
|
||||
command: cmd,
|
||||
searchTerm: '',
|
||||
};
|
||||
}
|
||||
@@ -627,9 +666,14 @@ class AutoComplete {
|
||||
const searchPart = trimmed.slice(spaceIndex + 1).trim();
|
||||
|
||||
if (TAG_COMMANDS[commandPart]) {
|
||||
const cmd = TAG_COMMANDS[commandPart];
|
||||
// Filter out toggle commands that don't meet their condition
|
||||
if (cmd.type === 'toggle_setting' && cmd.condition && !cmd.condition()) {
|
||||
return { showCommands: false, command: null, searchTerm: trimmed };
|
||||
}
|
||||
return {
|
||||
showCommands: false,
|
||||
command: TAG_COMMANDS[commandPart],
|
||||
command: cmd,
|
||||
searchTerm: searchPart,
|
||||
};
|
||||
}
|
||||
@@ -652,6 +696,11 @@ class AutoComplete {
|
||||
for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
|
||||
if (seenLabels.has(info.label)) continue;
|
||||
|
||||
// Filter out toggle commands that don't meet their condition
|
||||
if (info.type === 'toggle_setting' && info.condition) {
|
||||
if (!info.condition()) continue;
|
||||
}
|
||||
|
||||
if (!filter || cmd.slice(1).startsWith(filterLower)) {
|
||||
seenLabels.add(info.label);
|
||||
commands.push({ command: cmd, ...info });
|
||||
@@ -1117,7 +1166,16 @@ class AutoComplete {
|
||||
|
||||
// Use getSearchTerm to get the current search term before cursor
|
||||
const beforeCursor = currentValue.substring(0, caretPos);
|
||||
const searchTerm = this.getSearchTerm(beforeCursor);
|
||||
const fullSearchTerm = this.getSearchTerm(beforeCursor);
|
||||
|
||||
// For regular tag autocomplete (no command), only replace the last space-separated token
|
||||
// This allows "hello 1gi" + selecting "1girl" to become "hello 1girl, "
|
||||
// Command mode (e.g., "/char miku") should replace the entire command+search
|
||||
let searchTerm = fullSearchTerm;
|
||||
if (this.modelType === 'prompt' && this.searchType === 'custom_words' && !this.activeCommand) {
|
||||
searchTerm = this._getLastSpaceToken(fullSearchTerm);
|
||||
}
|
||||
|
||||
const searchStartPos = caretPos - searchTerm.length;
|
||||
|
||||
// Only replace the search term, not everything after the last comma
|
||||
@@ -1175,6 +1233,119 @@ class AutoComplete {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle toggle setting command (/ac, /noac)
|
||||
* @param {Object} command - The toggle command with settingId and value
|
||||
*/
|
||||
async _handleToggleSettingCommand(command) {
|
||||
const { settingId, value } = command;
|
||||
|
||||
try {
|
||||
// Use ComfyUI's setting API to update global setting
|
||||
const settingManager = app?.extensionManager?.setting;
|
||||
if (settingManager && typeof settingManager.set === 'function') {
|
||||
await settingManager.set(settingId, value);
|
||||
this._showToggleFeedback(value);
|
||||
this._clearCurrentToken();
|
||||
} else {
|
||||
// Fallback: use legacy settings API
|
||||
const setting = app.ui.settings.settingsById?.[settingId];
|
||||
if (setting) {
|
||||
app.ui.settings.setSettingValue(settingId, value);
|
||||
this._showToggleFeedback(value);
|
||||
this._clearCurrentToken();
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[Lora Manager] Failed to toggle setting:', error);
|
||||
showToast({
|
||||
severity: 'error',
|
||||
summary: 'Error',
|
||||
detail: 'Failed to toggle autocomplete setting',
|
||||
life: 3000
|
||||
});
|
||||
}
|
||||
|
||||
this.hide();
|
||||
}
|
||||
|
||||
/**
|
||||
* Show visual feedback for toggle action using toast
|
||||
* @param {boolean} enabled - New autocomplete state
|
||||
*/
|
||||
_showToggleFeedback(enabled) {
|
||||
showToast({
|
||||
severity: enabled ? 'success' : 'secondary',
|
||||
summary: enabled ? 'Autocomplete Enabled' : 'Autocomplete Disabled',
|
||||
detail: enabled
|
||||
? 'Tag autocomplete is now ON. Type to see suggestions.'
|
||||
: 'Tag autocomplete is now OFF. Use /ac to re-enable.',
|
||||
life: 3000
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Clear the current command token from input
|
||||
* Preserves leading spaces after delimiters (e.g., "1girl, /ac" -> "1girl, ")
|
||||
*/
|
||||
_clearCurrentToken() {
|
||||
const currentValue = this.inputElement.value;
|
||||
const caretPos = this.inputElement.selectionStart;
|
||||
|
||||
// Find the command text before cursor
|
||||
const beforeCursor = currentValue.substring(0, caretPos);
|
||||
const segments = beforeCursor.split(/[,\>]+/);
|
||||
const lastSegment = segments[segments.length - 1] || '';
|
||||
|
||||
// Find the command start position, preserving leading spaces
|
||||
// lastSegment includes leading spaces (e.g., " /ac"), find where command actually starts
|
||||
const commandMatch = lastSegment.match(/^(\s*)(\/\w+)/);
|
||||
if (commandMatch) {
|
||||
// commandMatch[1] is leading spaces, commandMatch[2] is the command
|
||||
const leadingSpaces = commandMatch[1].length;
|
||||
// Keep the spaces by starting after them
|
||||
const commandStartPos = caretPos - lastSegment.length + leadingSpaces;
|
||||
|
||||
// Skip trailing spaces when deleting
|
||||
let endPos = caretPos;
|
||||
while (endPos < currentValue.length && currentValue[endPos] === ' ') {
|
||||
endPos++;
|
||||
}
|
||||
|
||||
const newValue = currentValue.substring(0, commandStartPos) + currentValue.substring(endPos);
|
||||
const newCaretPos = commandStartPos;
|
||||
|
||||
this.inputElement.value = newValue;
|
||||
|
||||
// Trigger input event to notify about the change
|
||||
const event = new Event('input', { bubbles: true });
|
||||
this.inputElement.dispatchEvent(event);
|
||||
|
||||
// Focus back to input and position cursor
|
||||
this.inputElement.focus();
|
||||
this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
|
||||
} else {
|
||||
// Fallback: delete the whole last segment (original behavior)
|
||||
const commandStartPos = caretPos - lastSegment.length;
|
||||
|
||||
let endPos = caretPos;
|
||||
while (endPos < currentValue.length && currentValue[endPos] === ' ') {
|
||||
endPos++;
|
||||
}
|
||||
|
||||
const newValue = currentValue.substring(0, commandStartPos) + currentValue.substring(endPos);
|
||||
const newCaretPos = commandStartPos;
|
||||
|
||||
this.inputElement.value = newValue;
|
||||
|
||||
const event = new Event('input', { bubbles: true });
|
||||
this.inputElement.dispatchEvent(event);
|
||||
|
||||
this.inputElement.focus();
|
||||
this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
|
||||
}
|
||||
}
|
||||
|
||||
destroy() {
|
||||
if (this.debounceTimer) {
|
||||
clearTimeout(this.debounceTimer);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/* Shared styling for the LoRA Manager frontend widgets */
|
||||
.lm-tooltip {
|
||||
position: fixed;
|
||||
z-index: 9999;
|
||||
z-index: 10001;
|
||||
background: rgba(0, 0, 0, 0.85);
|
||||
border-radius: 6px;
|
||||
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user