Compare commits

...

26 Commits

Author SHA1 Message Date
Will Miao
bd83f7520e chore: bump version to 0.9.14 2026-02-02 23:17:35 +08:00
Will Miao
b9a4e7a09b docs(release): add v0.9.14 release notes
- Add LoRA Cycler node with iteration support
- Enhance Prompt node with tag autocomplete (Danbooru + e621)
- Add command system (/char, /artist, /ac, /noac) for tag operations
- Reference Lora Cycler and Lora Manager Basic template workflows
- Bug fixes and stability improvements
2026-02-02 23:09:06 +08:00
Will Miao
c30e57ede8 fix(recipes): add data-folder attribute to RecipeCard for correct drag-drop path calculation 2026-02-02 22:18:13 +08:00
Will Miao
0dba1b336d feat(template): update prompt node usage in basic template workflow 2026-02-02 21:58:51 +08:00
Will Miao
820afe9319 feat(recipe_scanner): ensure cache initialization and improve type safety
- Initialize RecipeCache in scan_recipes to prevent None reference errors
- Import PersistedRecipeData directly instead of using string annotation
- Remove redundant import inside _reconcile_recipe_cache method
2026-02-02 21:57:44 +08:00
Will Miao
5a97f4bc75 feat(recipe_scanner): optimize recipe lookup performance
Refactor recipe lookup logic to improve efficiency from O(n²) to O(n + m):
- Build recipe_by_id dictionary for O(1) recipe ID lookups
- Simplify persisted_by_path construction using recipe_id extraction
- Add fallback lookup by recipe_id when path lookup fails
- Maintain same functionality while reducing computational complexity
2026-02-02 19:37:06 +08:00
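The lookup refactor described in this commit can be sketched in isolation. A minimal sketch — the field names (`id`, `recipe_id`, `path`) are illustrative, not the project's actual schema:

```python
def match_persisted(recipes: list[dict], persisted: list[dict]) -> dict:
    """Pair each persisted entry with its recipe via a one-pass index."""
    # One O(n) pass builds the index; each persisted entry then resolves
    # in O(1), giving O(n + m) overall instead of an O(n * m) nested scan.
    recipe_by_id = {r["id"]: r for r in recipes}
    return {p["path"]: recipe_by_id.get(p["recipe_id"]) for p in persisted}
```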
Will Miao
94da404cc5 fix: skip confirmed not-found models in bulk metadata refresh
When enable_metadata_archive_db=True, the previous filter logic would
repeatedly try to fetch metadata for models that were already confirmed
to not exist on CivitAI (from_civitai=False, civitai_deleted=True).

The fix adds a skip condition to exclude models that:
1. Are confirmed not from CivitAI (from_civitai=False)
2. Are marked as deleted/not found on CivitAI (civitai_deleted=True)
3. Either have no archive DB enabled, or have already been checked (db_checked=True)

This prevents unnecessary API calls to CivArchive for user-trained models
or models from non-CivitAI sources.

Fixes repeated "Error fetching version of CivArchive model by hash" logs
for models that will never be found on CivitAI/CivArchive.
2026-02-02 13:27:18 +08:00
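The skip condition from this commit, sketched as a standalone predicate. Field names follow the commit message; the surrounding model representation (a plain dict here) is an assumption:

```python
def should_skip(model: dict, archive_db_enabled: bool) -> bool:
    """Skip models already confirmed absent from CivitAI/CivArchive."""
    # Conditions 1 and 2: confirmed not-found on CivitAI
    confirmed_missing = (not model.get("from_civitai", True)
                         and model.get("civitai_deleted", False))
    # Condition 3: no archive DB, or the archive DB was already consulted
    already_checked = (not archive_db_enabled) or model.get("db_checked", False)
    return confirmed_missing and already_checked
```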
Will Miao
1da476d858 feat(example-images): add check pending models endpoint and improve async handling
- Add /api/example-images/check-pending endpoint to quickly check models needing downloads
- Improve DownloadManager.start_download() to return immediately without blocking
- Add _handle_download_task_done callback for proper error handling and progress saving
- Add check_pending_models() method for lightweight pre-download validation
- Update frontend ExampleImagesManager to use new check-pending endpoint
- Add comprehensive tests for new functionality
2026-02-02 12:31:07 +08:00
Will Miao
1daaff6bd4 feat: add LoRa Manager E2E testing skill documentation
Introduce comprehensive documentation for the new `lora-manager-e2e` skill, which provides end-to-end testing workflows for LoRa Manager. The skill enables automated validation of standalone mode, including server management, UI interaction via Chrome DevTools MCP, and frontend-to-backend integration testing.

Key additions:
- Detailed skill description and prerequisites
- Quick start workflow for server setup and browser debugging
- Common E2E test patterns for page load verification, server restart, and API testing
- Example test flows demonstrating step-by-step validation procedures
- Scripts and MCP command examples for practical implementation

This documentation supports automated testing of LoRa Manager's web interface and backend functionality, ensuring reliable end-to-end validation of features.
2026-02-02 12:15:58 +08:00
Will Miao
e252e44403 refactor(logging): replace print statements with logger for consistency 2026-02-02 10:47:17 +08:00
Will Miao
778ad8abd2 feat(cache): add cache health monitoring and validation system, see #730
- Add cache entry validator service for data integrity checks
- Add cache health monitor service for periodic health checks
- Enhance model cache and scanner with validation support
- Update websocket manager for health status broadcasting
- Add initialization banner service for cache health alerts
- Add comprehensive test coverage for new services
- Update translations across all locales
- Refactor sync translation keys script
2026-02-02 08:30:59 +08:00
Will Miao
68cf381b50 feat(autocomplete): improve tag search to use last token for multi-word prompts
- Modify custom words search to extract last space-separated token from search term
- Add `_getLastSpaceToken` helper method for token extraction
- Update selection replacement logic to only replace last token in multi-word prompts
- Enables searching "hello 1gi" to find "1girl" and replace only "1gi" with "1girl"
- Maintains full command replacement for command mode (e.g., "/char miku")
2026-02-01 22:09:21 +08:00
Will Miao
337f73e711 fix(slider): fix floating point precision issues in SingleSlider and DualRangeSlider
JavaScript floating point arithmetic causes values like 1.1 to become
1.1000000000000014. Add precision limiting to 2 decimal places in
snapToStep function for both sliders.
2026-02-01 21:03:04 +08:00
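The precision fix amounts to rounding the snapped value. A minimal Python sketch of the idea — the real `snapToStep` lives in the JavaScript sliders, and the two-decimal default mirrors the commit message:

```python
def snap_to_step(value: float, minimum: float, step: float, precision: int = 2) -> float:
    """Snap value to the nearest multiple of step above minimum,
    then round to limit accumulated floating point error."""
    # Plain arithmetic drifts: e.g. 11 * 0.1 is not exactly 1.1
    snapped = minimum + round((value - minimum) / step) * step
    return round(snapped, precision)
```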
Will Miao
04ba966a6e feat: Add LoRA selector modal to Cycler widget
- Add LoraListModal component with search and preview tooltip
- Make 'Next LoRA' name clickable to open selector modal
- Integrate PreviewTooltip with custom resolver for Vue widgets
- Disable selector when prompts are queued (consistent with pause button)
- Fix tooltip z-index to display above modal backdrop

Fixes issue: users couldn't easily identify which index corresponds
to a specific LoRA in large lists
2026-02-01 20:58:30 +08:00
Will Miao
71c8cf84e0 refactor(LoraCyclerWidget): UI/UX improvements
- Replace REP badge with segmented progress bar for repeat indicator
- Reorganize Starting Index & Repeat controls into aligned groups
- Change repeat format from '× [count] times' to '[count] ×' for better alignment
- Remove unnecessary refresh button and related logic
2026-02-01 20:00:30 +08:00
Will Miao
db1aec94e5 refactor(logging): replace print statements with logger in metadata_collector 2026-02-01 15:41:41 +08:00
Will Miao
553e1868e1 perf(config): limit symlink scan to first level for faster startup
Replace recursive directory traversal with first-level-only symlink scanning
to fix severe performance issues on large model collections (220K+ files).

- Rename _scan_directory_links to _scan_first_level_symlinks
- Only scan symlinks directly under each root directory
- Skip traversal of normal subdirectories entirely
- Update tests to reflect first-level behavior
- Add test_deep_symlink_not_scanned to document intentional limitation

Startup time reduced from 15+ minutes to seconds for affected users.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-01 12:42:35 +08:00
Will Miao
938ceb49b2 feat(autocomplete): add toggle commands for autocomplete setting
- Add `/ac` and `/noac` commands to toggle prompt tag autocomplete on/off
- Commands only appear when relevant (e.g., `/ac` shows when autocomplete is off)
- Show toast notification when toggling setting
- Use ComfyUI's setting API with fallback to legacy API
- Clear autocomplete token after toggling to provide clean UX
2026-02-01 12:34:38 +08:00
Will Miao
c0f03b79a8 feat(settings): change model card footer action default to replace_preview 2026-02-01 07:38:04 +08:00
Will Miao
a492638133 feat(lora-cycler): disable pause button when prompts are queued
- Add `hasQueuedPrompts` reactive flag to track queued executions
- Pass `is-pause-disabled` prop to settings view to disable pause button
- Update pause button title to indicate why it's disabled
- Remove server queue clearing logic from pause toggle handler
- Clear `hasQueuedPrompts` flag when manually changing index or resetting
- Set `hasQueuedPrompts` to true when adding prompts to execution queue
- Update flag when processing queued executions to reflect current queue state
2026-02-01 01:12:39 +08:00
Will Miao
e17d6c8ebf feat(testing): enhance test configuration and add Vue component tests
- Update package.json test script to run both JS and Vue tests
- Simplify LoraCyclerLM output by removing redundant lora name fallback
- Extend Vitest config to include TypeScript test files
- Add Vue testing dependencies and setup for component testing
- Implement comprehensive test suite for BatchQueueSimulator component
- Add test setup file with global mocks for ComfyUI modules
2026-02-01 00:59:50 +08:00
Will Miao
ffcfe5ea3e fix(metadata): rename model_type to sub_type and add embedding subtype, see #797
- Change `model_type` field to `sub_type` for checkpoint models to improve naming consistency
- Add `sub_type="embedding"` for embedding models to properly categorize model subtypes
- Maintain backward compatibility with existing metadata structure
2026-01-31 22:54:53 +08:00
Will Miao
719e18adb6 feat(media): add media type hint support for file extension detection, fixes #795 and fixes #751
- Add optional `media_type_hint` parameter to `_get_file_extension_from_content_or_headers` method
- When `media_type_hint` is "video" and no extension can be determined from content/headers/URL, default to `.mp4`
- Pass image metadata type as hint in both `process_example_images` and `process_example_images_batch` methods
- Add unit tests to verify media type hint behavior and priority
2026-01-31 19:39:37 +08:00
Will Miao
92d471daf5 feat(ui): hide model sub-type in compact density mode, see #793
Add CSS rules to hide the model sub-type and separator elements when the compact-density class is applied. This change saves visual space in compact mode by removing less critical information, improving the layout for dense interfaces.
2026-01-31 11:17:49 +08:00
Will Miao
66babf9ee1 feat(lora-cycler): reset execution state on manual index change
Reset execution state when user manually changes LoRA index to ensure next execution starts from the user-set index. This prevents stale execution state from interfering with user-initiated index changes.
2026-01-31 09:04:26 +08:00
Will Miao
60df2df324 feat: add new Flux Klein models, ZImageBase, and LTXV2 to constants, see #792
- Add Flux.2 Klein 9B, 9B-base, 4B, and 4B-base models to BASE_MODELS, BASE_MODEL_ABBREVIATIONS, and Flux Models category
- Include ZImageBase model and its abbreviation
- Add LTXV2 video model to BASE_MODELS, BASE_MODEL_ABBREVIATIONS, and Video Models category
- Update model categories to reflect new additions
2026-01-31 07:57:21 +08:00
80 changed files with 10276 additions and 775 deletions

View File

@@ -0,0 +1,201 @@
---
name: lora-manager-e2e
description: End-to-end testing and validation for LoRa Manager features. Use when performing automated E2E validation of LoRa Manager standalone mode, including starting/restarting the server, using Chrome DevTools MCP to interact with the web UI at http://127.0.0.1:8188/loras, and verifying frontend-to-backend functionality. Covers workflow validation, UI interaction testing, and integration testing between the standalone Python backend and the browser frontend.
---
# LoRa Manager E2E Testing
This skill provides workflows and utilities for end-to-end testing of LoRa Manager using Chrome DevTools MCP.
## Prerequisites
- LoRa Manager project cloned and dependencies installed (`pip install -r requirements.txt`)
- Chrome browser available for debugging
- Chrome DevTools MCP connected
## Quick Start Workflow
### 1. Start LoRa Manager Standalone
```python
# Use the provided script to start the server
python .agents/skills/lora-manager-e2e/scripts/start_server.py --port 8188
```
Or manually:
```bash
cd /home/miao/workspace/ComfyUI/custom_nodes/ComfyUI-Lora-Manager
python standalone.py --port 8188
```
Wait for server ready message before proceeding.
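One way to script that wait is to poll the port until a TCP connection succeeds. A minimal sketch — the bundled `wait_for_server.py` presumably does something similar; this helper is illustrative:

```python
import socket
import time

def wait_until_ready(host: str = "127.0.0.1", port: int = 8188, timeout: float = 30.0) -> bool:
    """Poll until a TCP connection to host:port succeeds, or give up."""
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            with socket.create_connection((host, port), timeout=1.0):
                return True  # something is accepting connections
        except OSError:
            time.sleep(0.2)  # not up yet; retry shortly
    return False
```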
### 2. Open Chrome Debug Mode
```bash
# Chrome with remote debugging on port 9222
google-chrome --remote-debugging-port=9222 --user-data-dir=/tmp/chrome-lora-manager http://127.0.0.1:8188/loras
```
### 3. Connect Chrome DevTools MCP
Ensure the MCP server is connected to Chrome at `http://localhost:9222`.
### 4. Navigate and Interact
Use Chrome DevTools MCP tools to:
- Take snapshots: `take_snapshot`
- Click elements: `click`
- Fill forms: `fill` or `fill_form`
- Evaluate scripts: `evaluate_script`
- Wait for elements: `wait_for`
## Common E2E Test Patterns
### Pattern: Full Page Load Verification
```python
# Navigate to LoRA list page
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
# Wait for page to load
wait_for(text="LoRAs", timeout=10000)
# Take snapshot to verify UI state
snapshot = take_snapshot()
```
### Pattern: Restart Server for Configuration Changes
```python
# Stop current server (if running)
# Start with new configuration
python .agents/skills/lora-manager-e2e/scripts/start_server.py --port 8188 --restart
# Wait and refresh browser
navigate_page(type="reload", ignoreCache=True)
wait_for(text="LoRAs", timeout=15000)
```
### Pattern: Verify Backend API via Frontend
```python
# Execute script in browser to call backend API
result = evaluate_script(function="""
    async () => {
        const response = await fetch('/loras/api/list');
        const data = await response.json();
        return { count: data.length, firstItem: data[0]?.name };
    }
""")
```
### Pattern: Form Submission Flow
```python
# Fill a form (e.g., search or filter)
fill_form(elements=[
    {"uid": "search-input", "value": "character"},
])
# Click submit button
click(uid="search-button")
# Wait for results
wait_for(text="Results", timeout=5000)
# Verify results via snapshot
snapshot = take_snapshot()
```
### Pattern: Modal Dialog Interaction
```python
# Open modal (e.g., add LoRA)
click(uid="add-lora-button")
# Wait for modal to appear
wait_for(text="Add LoRA", timeout=3000)
# Fill modal form
fill_form(elements=[
    {"uid": "lora-name", "value": "Test LoRA"},
    {"uid": "lora-path", "value": "/path/to/lora.safetensors"},
])
# Submit
click(uid="modal-submit-button")
# Wait for success message or close
wait_for(text="Success", timeout=5000)
```
## Available Scripts
### scripts/start_server.py
Starts or restarts the LoRa Manager standalone server.
```bash
python scripts/start_server.py [--port PORT] [--restart] [--wait]
```
Options:
- `--port`: Server port (default: 8188)
- `--restart`: Kill existing server before starting
- `--wait`: Wait for server to be ready before exiting
### scripts/wait_for_server.py
Polls server until ready or timeout.
```bash
python scripts/wait_for_server.py [--port PORT] [--timeout SECONDS]
```
## Test Scenarios Reference
See [references/test-scenarios.md](references/test-scenarios.md) for detailed test scenarios including:
- LoRA list display and filtering
- Model metadata editing
- Recipe creation and management
- Settings configuration
- Import/export functionality
## Network Request Verification
Use `list_network_requests` and `get_network_request` to verify API calls:
```python
# List recent XHR/fetch requests
requests = list_network_requests(resourceTypes=["xhr", "fetch"])
# Get details of specific request
details = get_network_request(reqid=123)
```
## Console Message Monitoring
```python
# Check for errors or warnings
messages = list_console_messages(types=["error", "warn"])
```
## Performance Testing
```python
# Start performance trace
performance_start_trace(reload=True, autoStop=False)
# Perform actions...
# Stop and analyze
results = performance_stop_trace()
```
## Cleanup
Always ensure proper cleanup after tests:
1. Stop the standalone server
2. Close browser pages (keep at least one open)
3. Clear temporary data if needed
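Step 1 can be scripted. A minimal teardown sketch that, like the bundled `start_server.py`, assumes `lsof` is available (and degrades to a no-op when it isn't):

```python
import os
import signal
import subprocess

def stop_server(port: int = 8188) -> list[int]:
    """Send SIGTERM to any process listening on the port; return signalled PIDs."""
    try:
        result = subprocess.run(
            ["lsof", "-ti", f":{port}"],
            capture_output=True, text=True, check=False,
        )
    except FileNotFoundError:
        return []  # lsof unavailable; nothing we can identify to stop
    pids = [int(p) for p in result.stdout.split() if p]
    for pid in pids:
        try:
            os.kill(pid, signal.SIGTERM)
        except ProcessLookupError:
            pass  # process exited between lookup and signal
    return pids
```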

View File

@@ -0,0 +1,324 @@
# Chrome DevTools MCP Cheatsheet for LoRa Manager
Quick reference for common MCP commands used in LoRa Manager E2E testing.
## Navigation
```python
# Navigate to LoRA list page
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
# Reload page with cache clear
navigate_page(type="reload", ignoreCache=True)
# Go back/forward
navigate_page(type="back")
navigate_page(type="forward")
```
## Waiting
```python
# Wait for text to appear
wait_for(text="LoRAs", timeout=10000)
# Wait for specific element (via evaluate_script)
evaluate_script(function="""
    () => {
        return new Promise((resolve) => {
            const check = () => {
                if (document.querySelector('.lora-card')) {
                    resolve(true);
                } else {
                    setTimeout(check, 100);
                }
            };
            check();
        });
    }
""")
```
## Taking Snapshots
```python
# Full page snapshot
snapshot = take_snapshot()
# Verbose snapshot (more details)
snapshot = take_snapshot(verbose=True)
# Save to file
take_snapshot(filePath="test-snapshots/page-load.json")
```
## Element Interaction
```python
# Click element
click(uid="element-uid-from-snapshot")
# Double click
click(uid="element-uid", dblClick=True)
# Fill input
fill(uid="search-input", value="test query")
# Fill multiple inputs
fill_form(elements=[
    {"uid": "input-1", "value": "value 1"},
    {"uid": "input-2", "value": "value 2"},
])
# Hover
hover(uid="lora-card-1")
# Upload file
upload_file(uid="file-input", filePath="/path/to/file.safetensors")
```
## Keyboard Input
```python
# Press key
press_key(key="Enter")
press_key(key="Escape")
press_key(key="Tab")
# Keyboard shortcuts
press_key(key="Control+A") # Select all
press_key(key="Control+F") # Find
```
## JavaScript Evaluation
```python
# Simple evaluation
result = evaluate_script(function="() => document.title")
# Async evaluation
result = evaluate_script(function="""
    async () => {
        const response = await fetch('/loras/api/list');
        return await response.json();
    }
""")
# Check element existence
exists = evaluate_script(function="""
() => document.querySelector('.lora-card') !== null
""")
# Get element count
count = evaluate_script(function="""
() => document.querySelectorAll('.lora-card').length
""")
```
## Network Monitoring
```python
# List all network requests
requests = list_network_requests()
# Filter by resource type
xhr_requests = list_network_requests(resourceTypes=["xhr", "fetch"])
# Get specific request details
details = get_network_request(reqid=123)
# Include preserved requests from previous navigations
all_requests = list_network_requests(includePreservedRequests=True)
```
## Console Monitoring
```python
# List all console messages
messages = list_console_messages()
# Filter by type
errors = list_console_messages(types=["error", "warn"])
# Include preserved messages
all_messages = list_console_messages(includePreservedMessages=True)
# Get specific message
details = get_console_message(msgid=1)
```
## Performance Testing
```python
# Start trace with page reload
performance_start_trace(reload=True, autoStop=False)
# Start trace without reload
performance_start_trace(reload=False, autoStop=True, filePath="trace.json.gz")
# Stop trace
results = performance_stop_trace()
# Stop and save
performance_stop_trace(filePath="trace-results.json.gz")
# Analyze specific insight
insight = performance_analyze_insight(
    insightSetId="results.insightSets[0].id",
    insightName="LCPBreakdown"
)
```
## Page Management
```python
# List open pages
pages = list_pages()
# Select a page
select_page(pageId=0, bringToFront=True)
# Create new page
new_page(url="http://127.0.0.1:8188/loras")
# Close page (keep at least one open!)
close_page(pageId=1)
# Resize page
resize_page(width=1920, height=1080)
```
## Screenshots
```python
# Full page screenshot
take_screenshot(fullPage=True)
# Viewport screenshot
take_screenshot()
# Element screenshot
take_screenshot(uid="lora-card-1")
# Save to file
take_screenshot(filePath="screenshots/page.png", format="png")
# JPEG with quality
take_screenshot(filePath="screenshots/page.jpg", format="jpeg", quality=90)
```
## Dialog Handling
```python
# Accept dialog
handle_dialog(action="accept")
# Accept with text input
handle_dialog(action="accept", promptText="user input")
# Dismiss dialog
handle_dialog(action="dismiss")
```
## Device Emulation
```python
# Mobile viewport
emulate(viewport={"width": 375, "height": 667, "isMobile": True, "hasTouch": True})
# Tablet viewport
emulate(viewport={"width": 768, "height": 1024, "isMobile": True, "hasTouch": True})
# Desktop viewport
emulate(viewport={"width": 1920, "height": 1080})
# Network throttling
emulate(networkConditions="Slow 3G")
emulate(networkConditions="Fast 4G")
# CPU throttling
emulate(cpuThrottlingRate=4) # 4x slowdown
# Geolocation
emulate(geolocation={"latitude": 37.7749, "longitude": -122.4194})
# User agent
emulate(userAgent="Mozilla/5.0 (Custom)")
# Reset emulation
emulate(viewport=None, networkConditions="No emulation", userAgent=None)
```
## Drag and Drop
```python
# Drag element to another
drag(from_uid="draggable-item", to_uid="drop-zone")
```
## Common LoRa Manager Test Patterns
### Verify LoRA Cards Loaded
```python
navigate_page(type="url", url="http://127.0.0.1:8188/loras")
wait_for(text="LoRAs", timeout=10000)
# Check if cards loaded
result = evaluate_script(function="""
    () => {
        const cards = document.querySelectorAll('.lora-card');
        return {
            count: cards.length,
            hasData: cards.length > 0
        };
    }
""")
```
### Search and Verify Results
```python
fill(uid="search-input", value="character")
press_key(key="Enter")
wait_for(timeout=2000) # Wait for debounce
# Check results
result = evaluate_script(function="""
    () => {
        const cards = document.querySelectorAll('.lora-card');
        const names = Array.from(cards).map(c => c.dataset.name || c.textContent);
        return { count: cards.length, names };
    }
""")
```
### Check API Response
```python
# Trigger API call
evaluate_script(function="""
() => window.loraApiCallPromise = fetch('/loras/api/list').then(r => r.json())
""")
# Wait and get result
import time
time.sleep(1)
result = evaluate_script(function="""
async () => await window.loraApiCallPromise
""")
```
### Monitor Console for Errors
```python
# Before test: clear console (navigate reloads)
navigate_page(type="reload")
# ... perform actions ...
# Check for errors
errors = list_console_messages(types=["error"])
assert len(errors) == 0, f"Console errors: {errors}"
```

View File

@@ -0,0 +1,272 @@
# LoRa Manager E2E Test Scenarios
This document provides detailed test scenarios for end-to-end validation of LoRa Manager features.
## Table of Contents
1. [LoRA List Page](#lora-list-page)
2. [Model Details](#model-details)
3. [Recipes](#recipes)
4. [Settings](#settings)
5. [Import/Export](#importexport)
---
## LoRA List Page
### Scenario: Page Load and Display
**Objective**: Verify the LoRA list page loads correctly and displays models.
**Steps**:
1. Navigate to `http://127.0.0.1:8188/loras`
2. Wait for page title "LoRAs" to appear
3. Take snapshot to verify:
- Header with "LoRAs" title is visible
- Search/filter controls are present
- Grid/list view toggle exists
- LoRA cards are displayed (if models exist)
- Pagination controls (if applicable)
**Expected Result**: Page loads without errors, UI elements are present.
### Scenario: Search Functionality
**Objective**: Verify search filters LoRA models correctly.
**Steps**:
1. Ensure at least one LoRA exists with known name (e.g., "test-character")
2. Navigate to LoRA list page
3. Enter search term in search box: "test"
4. Press Enter or click search button
5. Wait for results to update
**Expected Result**: Only LoRAs matching search term are displayed.
**Verification Script**:
```python
# After search, verify filtered results
evaluate_script(function="""
() => {
const cards = document.querySelectorAll('.lora-card');
const names = Array.from(cards).map(c => c.dataset.name);
return { count: cards.length, names };
}
""")
```
### Scenario: Filter by Tags
**Objective**: Verify tag filtering works correctly.
**Steps**:
1. Navigate to LoRA list page
2. Click on a tag (e.g., "character", "style")
3. Wait for filtered results
**Expected Result**: Only LoRAs with selected tag are displayed.
### Scenario: View Mode Toggle
**Objective**: Verify grid/list view toggle works.
**Steps**:
1. Navigate to LoRA list page
2. Click list view button
3. Verify list layout
4. Click grid view button
5. Verify grid layout
**Expected Result**: View mode changes correctly, layout updates.
---
## Model Details
### Scenario: Open Model Details
**Objective**: Verify clicking a LoRA opens its details.
**Steps**:
1. Navigate to LoRA list page
2. Click on a LoRA card
3. Wait for details panel/modal to open
**Expected Result**: Details panel shows:
- Model name
- Preview image
- Metadata (trigger words, tags, etc.)
- Action buttons (edit, delete, etc.)
### Scenario: Edit Model Metadata
**Objective**: Verify metadata editing works end-to-end.
**Steps**:
1. Open a LoRA's details
2. Click "Edit" button
3. Modify trigger words field
4. Add/remove tags
5. Save changes
6. Refresh page
7. Reopen the same LoRA
**Expected Result**: Changes persist after refresh.
### Scenario: Delete Model
**Objective**: Verify model deletion works.
**Steps**:
1. Open a LoRA's details
2. Click "Delete" button
3. Confirm deletion in dialog
4. Wait for removal
**Expected Result**: Model removed from list, success message shown.
---
## Recipes
### Scenario: Recipe List Display
**Objective**: Verify recipes page loads and displays recipes.
**Steps**:
1. Navigate to `http://127.0.0.1:8188/recipes`
2. Wait for "Recipes" title
3. Take snapshot
**Expected Result**: Recipe list displayed with cards/items.
### Scenario: Create New Recipe
**Objective**: Verify recipe creation workflow.
**Steps**:
1. Navigate to recipes page
2. Click "New Recipe" button
3. Fill recipe form:
- Name: "Test Recipe"
- Description: "E2E test recipe"
- Add LoRA models
4. Save recipe
5. Verify recipe appears in list
**Expected Result**: New recipe created and displayed.
### Scenario: Apply Recipe
**Objective**: Verify applying a recipe to ComfyUI.
**Steps**:
1. Open a recipe
2. Click "Apply" or "Load in ComfyUI"
3. Verify action completes
**Expected Result**: Recipe applied successfully.
---
## Settings
### Scenario: Settings Page Load
**Objective**: Verify settings page displays correctly.
**Steps**:
1. Navigate to `http://127.0.0.1:8188/settings`
2. Wait for "Settings" title
3. Take snapshot
**Expected Result**: Settings form with various options displayed.
### Scenario: Change Setting and Restart
**Objective**: Verify settings persist after restart.
**Steps**:
1. Navigate to settings page
2. Change a setting (e.g., default view mode)
3. Save settings
4. Restart server: `python scripts/start_server.py --restart --wait`
5. Refresh browser page
6. Navigate to settings
**Expected Result**: Changed setting value persists.
---
## Import/Export
### Scenario: Export Models List
**Objective**: Verify export functionality.
**Steps**:
1. Navigate to LoRA list
2. Click "Export" button
3. Select format (JSON/CSV)
4. Download file
**Expected Result**: File downloaded with correct data.
### Scenario: Import Models
**Objective**: Verify import functionality.
**Steps**:
1. Prepare import file
2. Navigate to import page
3. Upload file
4. Verify import results
**Expected Result**: Models imported successfully, confirmation shown.
---
## API Integration Tests
### Scenario: Verify API Endpoints
**Objective**: Verify backend API responds correctly.
**Test via browser console**:
```javascript
// List LoRAs
fetch('/loras/api/list').then(r => r.json()).then(console.log)
// Get LoRA details
fetch('/loras/api/detail/<id>').then(r => r.json()).then(console.log)
// Search LoRAs
fetch('/loras/api/search?q=test').then(r => r.json()).then(console.log)
```
**Expected Result**: APIs return valid JSON with expected structure.
---
## Console Error Monitoring
During all tests, monitor browser console for errors:
```python
# Check for JavaScript errors
messages = list_console_messages(types=["error"])
assert len(messages) == 0, f"Console errors found: {messages}"
```
## Network Request Verification
Verify key API calls are made:
```python
# List XHR requests
requests = list_network_requests(resourceTypes=["xhr", "fetch"])
# Look for specific endpoints
lora_list_requests = [r for r in requests if "/api/list" in r.get("url", "")]
assert len(lora_list_requests) > 0, "LoRA list API not called"
```

View File

@@ -0,0 +1,193 @@
#!/usr/bin/env python3
"""
Example E2E test demonstrating LoRa Manager testing workflow.
This script shows how to:
1. Start the standalone server
2. Use Chrome DevTools MCP to interact with the UI
3. Verify functionality end-to-end
Note: This is a template. Actual execution requires Chrome DevTools MCP.
"""
import subprocess
import sys
import time
def run_test():
    """Run example E2E test flow."""
    print("=" * 60)
    print("LoRa Manager E2E Test Example")
    print("=" * 60)

    # Step 1: Start server
    print("\n[1/5] Starting LoRa Manager standalone server...")
    result = subprocess.run(
        [sys.executable, "start_server.py", "--port", "8188", "--wait", "--timeout", "30"],
        capture_output=True,
        text=True
    )
    if result.returncode != 0:
        print(f"Failed to start server: {result.stderr}")
        return 1
    print("Server ready!")

    # Step 2: Open Chrome (manual step - show command)
    print("\n[2/5] Open Chrome with debug mode:")
    print("google-chrome --remote-debugging-port=9222 --user-data-dir=/tmp/chrome-lora-manager http://127.0.0.1:8188/loras")
    print("(In actual test, this would be automated via MCP)")

    # Step 3: Navigate and verify page load
    print("\n[3/5] Page Load Verification:")
    print("""
    MCP Commands to execute:
    1. navigate_page(type="url", url="http://127.0.0.1:8188/loras")
    2. wait_for(text="LoRAs", timeout=10000)
    3. snapshot = take_snapshot()
    """)

    # Step 4: Test search functionality
    # (inner evaluate_script snippets use ''' so they do not terminate
    # the surrounding triple-quoted string)
    print("\n[4/5] Search Functionality Test:")
    print("""
    MCP Commands to execute:
    1. fill(uid="search-input", value="test")
    2. press_key(key="Enter")
    3. wait_for(text="Results", timeout=5000)
    4. result = evaluate_script(function='''
           () => {
               const cards = document.querySelectorAll('.lora-card');
               return { count: cards.length };
           }
       ''')
    """)

    # Step 5: Verify API
    print("\n[5/5] API Verification:")
    print("""
    MCP Commands to execute:
    1. api_result = evaluate_script(function='''
           async () => {
               const response = await fetch('/loras/api/list');
               const data = await response.json();
               return { count: data.length, status: response.status };
           }
       ''')
    2. Verify api_result['status'] == 200
    """)

    print("\n" + "=" * 60)
    print("Test flow completed!")
    print("=" * 60)
    return 0
def example_restart_flow():
    """Example: Testing configuration change that requires restart."""
    print("\n" + "=" * 60)
    print("Example: Server Restart Flow")
    print("=" * 60)
    print("""
    Scenario: Change setting and verify after restart
    Steps:
    1. Navigate to settings page
       - navigate_page(type="url", url="http://127.0.0.1:8188/settings")
    2. Change a setting (e.g., theme)
       - fill(uid="theme-select", value="dark")
       - click(uid="save-settings-button")
    3. Restart server
       - subprocess.run([python, "start_server.py", "--restart", "--wait"])
    4. Refresh browser
       - navigate_page(type="reload", ignoreCache=True)
       - wait_for(text="LoRAs", timeout=15000)
    5. Verify setting persisted
       - navigate_page(type="url", url="http://127.0.0.1:8188/settings")
       - theme = evaluate_script(function="() => document.querySelector('#theme-select').value")
       - assert theme == "dark"
    """)
def example_modal_interaction():
    """Example: Testing modal dialog interaction."""
    print("\n" + "=" * 60)
    print("Example: Modal Dialog Interaction")
    print("=" * 60)
    print("""
    Scenario: Add new LoRA via modal
    Steps:
    1. Open modal
       - click(uid="add-lora-button")
       - wait_for(text="Add LoRA", timeout=3000)
    2. Fill form
       - fill_form(elements=[
             {"uid": "lora-name", "value": "Test Character"},
             {"uid": "lora-path", "value": "/models/test.safetensors"},
         ])
    3. Submit
       - click(uid="modal-submit-button")
    4. Verify success
       - wait_for(text="Successfully added", timeout=5000)
       - snapshot = take_snapshot()
    """)
def example_network_monitoring():
    """Example: Network request monitoring."""
    print("\n" + "=" * 60)
    print("Example: Network Request Monitoring")
    print("=" * 60)
    print("""
    Scenario: Verify API calls during user interaction
    Steps:
    1. Clear network log (implicit on navigation)
       - navigate_page(type="url", url="http://127.0.0.1:8188/loras")
    2. Perform action that triggers API call
       - fill(uid="search-input", value="character")
       - press_key(key="Enter")
    3. List network requests
       - requests = list_network_requests(resourceTypes=["xhr", "fetch"])
    4. Find search API call
       - search_requests = [r for r in requests if "/api/search" in r.get("url", "")]
       - assert len(search_requests) > 0, "Search API was not called"
    5. Get request details
       - if search_requests:
             details = get_network_request(reqid=search_requests[0]["reqid"])
       - Verify request method, response status, etc.
    """)
if __name__ == "__main__":
print("LoRA Manager E2E Test Examples\n")
print("This script demonstrates E2E testing patterns.\n")
print("Note: Actual execution requires Chrome DevTools MCP connection.\n")
run_test()
example_restart_flow()
example_modal_interaction()
example_network_monitoring()
print("\n" + "=" * 60)
print("All examples shown!")
print("=" * 60)
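The filtering step in the network-monitoring example above is ordinary list processing; with stubbed request records (the `reqid` values and URLs below are made up for illustration, not produced by the real `list_network_requests()`), it behaves like this:

```python
# Stubbed network-request records, shaped like the dicts the example filters.
requests = [
    {"reqid": 101, "url": "http://127.0.0.1:8188/api/search?q=character"},
    {"reqid": 102, "url": "http://127.0.0.1:8188/static/app.js"},
    {"reqid": 103, "url": "http://127.0.0.1:8188/api/search?q=style"},
]

# Step 4 of the monitoring example: keep only the search API calls.
search_requests = [r for r in requests if "/api/search" in r.get("url", "")]
print([r["reqid"] for r in search_requests])  # [101, 103]
```

Using `r.get("url", "")` rather than `r["url"]` keeps the filter safe against records that lack a `url` key.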

View File

@@ -0,0 +1,169 @@
#!/usr/bin/env python3
"""
Start or restart LoRA Manager standalone server for E2E testing.
"""
import argparse
import subprocess
import sys
import time
import socket
import signal
import os
def find_server_process(port: int) -> list[int]:
"""Find PIDs of processes listening on the given port."""
try:
result = subprocess.run(
["lsof", "-ti", f":{port}"],
capture_output=True,
text=True,
check=False
)
if result.returncode == 0 and result.stdout.strip():
return [int(pid) for pid in result.stdout.strip().split("\n") if pid]
except FileNotFoundError:
# lsof not available, try netstat
try:
result = subprocess.run(
["netstat", "-tlnp"],
capture_output=True,
text=True,
check=False
)
pids = []
for line in result.stdout.split("\n"):
if f":{port}" in line:
parts = line.split()
for part in parts:
if "/" in part:
try:
pid = int(part.split("/")[0])
pids.append(pid)
except ValueError:
pass
return pids
except FileNotFoundError:
pass
return []
def kill_server(port: int) -> None:
"""Kill processes using the specified port."""
pids = find_server_process(port)
for pid in pids:
try:
os.kill(pid, signal.SIGTERM)
print(f"Sent SIGTERM to process {pid}")
except ProcessLookupError:
pass
# Wait for processes to terminate
time.sleep(1)
# Force kill if still running
pids = find_server_process(port)
for pid in pids:
try:
os.kill(pid, signal.SIGKILL)
print(f"Sent SIGKILL to process {pid}")
except ProcessLookupError:
pass
def is_server_ready(port: int, timeout: float = 0.5) -> bool:
"""Check if server is accepting connections."""
try:
with socket.create_connection(("127.0.0.1", port), timeout=timeout):
return True
except (socket.timeout, ConnectionRefusedError, OSError):
return False
def wait_for_server(port: int, timeout: int = 30) -> bool:
"""Wait for server to become ready."""
start = time.time()
while time.time() - start < timeout:
if is_server_ready(port):
return True
time.sleep(0.5)
return False
def main() -> int:
parser = argparse.ArgumentParser(
description="Start LoRA Manager standalone server for E2E testing"
)
parser.add_argument(
"--port",
type=int,
default=8188,
help="Server port (default: 8188)"
)
parser.add_argument(
"--restart",
action="store_true",
help="Kill existing server before starting"
)
parser.add_argument(
"--wait",
action="store_true",
help="Wait for server to be ready before exiting"
)
parser.add_argument(
"--timeout",
type=int,
default=30,
help="Timeout for waiting (default: 30)"
)
args = parser.parse_args()
# Get project root (parent of .agents directory)
script_dir = os.path.dirname(os.path.abspath(__file__))
skill_dir = os.path.dirname(script_dir)
project_root = os.path.dirname(os.path.dirname(os.path.dirname(skill_dir)))
# Restart if requested
if args.restart:
print(f"Killing existing server on port {args.port}...")
kill_server(args.port)
time.sleep(1)
# Check if already running
if is_server_ready(args.port):
print(f"Server already running on port {args.port}")
return 0
# Start server
print(f"Starting LoRA Manager standalone server on port {args.port}...")
cmd = [sys.executable, "standalone.py", "--port", str(args.port)]
# Start in background
process = subprocess.Popen(
cmd,
cwd=project_root,
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
start_new_session=True
)
print(f"Server process started with PID {process.pid}")
# Wait for ready if requested
if args.wait:
print(f"Waiting for server to be ready (timeout: {args.timeout}s)...")
if wait_for_server(args.port, args.timeout):
print(f"Server ready at http://127.0.0.1:{args.port}/loras")
return 0
else:
print("Timeout waiting for server")
return 1
print(f"Server starting at http://127.0.0.1:{args.port}/loras")
return 0
if __name__ == "__main__":
sys.exit(main())
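`kill_server()` above escalates from SIGTERM to SIGKILL. The same escalation pattern can be sketched in isolation against a throwaway child process standing in for the server (the sleeping child here is purely illustrative):

```python
import os
import signal
import subprocess
import sys

# Spawn a sleeping child to stand in for the server process.
child = subprocess.Popen([sys.executable, "-c", "import time; time.sleep(60)"])

# Polite shutdown first...
os.kill(child.pid, signal.SIGTERM)
try:
    child.wait(timeout=5)
except subprocess.TimeoutExpired:
    # ...force kill only if the process ignored SIGTERM.
    os.kill(child.pid, signal.SIGKILL)
    child.wait()

# On POSIX, a negative return code is the number of the signal that ended the process.
print(child.returncode)
```

The script is slightly more defensive than this sketch: it re-queries the port for surviving PIDs after SIGTERM rather than waiting on a known child handle, which also covers processes it did not start.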

View File

@@ -0,0 +1,61 @@
#!/usr/bin/env python3
"""
Wait for LoRA Manager server to become ready.
"""
import argparse
import socket
import sys
import time
def is_server_ready(port: int, timeout: float = 0.5) -> bool:
"""Check if server is accepting connections."""
try:
with socket.create_connection(("127.0.0.1", port), timeout=timeout):
return True
except (socket.timeout, ConnectionRefusedError, OSError):
return False
def wait_for_server(port: int, timeout: int = 30) -> bool:
"""Wait for server to become ready."""
start = time.time()
while time.time() - start < timeout:
if is_server_ready(port):
return True
time.sleep(0.5)
return False
def main() -> int:
parser = argparse.ArgumentParser(
description="Wait for LoRA Manager server to become ready"
)
parser.add_argument(
"--port",
type=int,
default=8188,
help="Server port (default: 8188)"
)
parser.add_argument(
"--timeout",
type=int,
default=30,
help="Timeout in seconds (default: 30)"
)
args = parser.parse_args()
print(f"Waiting for server on port {args.port} (timeout: {args.timeout}s)...")
if wait_for_server(args.port, args.timeout):
print(f"Server ready at http://127.0.0.1:{args.port}/loras")
return 0
else:
print(f"Timeout: Server not ready after {args.timeout}s")
return 1
if __name__ == "__main__":
sys.exit(main())
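The probe in `is_server_ready()` is a plain TCP connect, so it can be demonstrated without LoRA Manager at all; a throwaway listener behaves the same way (the listener below is purely illustrative):

```python
import socket

def is_server_ready(port: int, timeout: float = 0.5) -> bool:
    try:
        with socket.create_connection(("127.0.0.1", port), timeout=timeout):
            return True
    except OSError:  # socket.timeout and ConnectionRefusedError are OSError subclasses
        return False

# Demonstrate against a throwaway listener instead of the real server.
srv = socket.socket()
srv.bind(("127.0.0.1", 0))   # port 0: let the OS pick a free port
srv.listen(1)
port = srv.getsockname()[1]
ready_while_listening = is_server_ready(port)
srv.close()
ready_after_close = is_server_ready(port)
print(ready_while_listening, ready_after_close)  # True False
```

Since Python 3.3, `socket.timeout` and `ConnectionRefusedError` are both `OSError` subclasses, so the script's three-way `except` tuple and the single `OSError` above catch the same failures.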

View File

@@ -34,6 +34,11 @@ Enhance your Civitai browsing experience with our companion browser extension! S
 ## Release Notes

+### v0.9.14
+* **LoRA Cycler Node** - Introduced a new LoRA Cycler node that iterates through specified LoRAs, with support for repeat counts and pausing iteration. Refer to the new "Lora Cycler" template workflow for a concrete example.
+* **Enhanced Prompt Node with Tag Autocomplete** - Enhanced the Prompt node with comprehensive tag autocomplete based on merged Danbooru + e621 tags. Implemented a command system with shortcuts like `/char` or `/artist` for category-specific tag searching, and added `/ac` and `/noac` commands to quickly enable or disable autocomplete. Refer to the "Lora Manager Basic" template workflow in ComfyUI -> Templates -> ComfyUI-Lora-Manager for detailed tips.
+* **Bug Fixes & Stability** - Addressed multiple bugs and improved overall stability.
+
 ### v0.9.12
 * **LoRA Randomizer System** - Introduced a comprehensive LoRA randomization system featuring LoRA Pool and LoRA Randomizer nodes for flexible and dynamic generation workflows.
 * **LoRA Randomizer Template** - Refer to the new "LoRA Randomizer" template workflow for detailed examples of flexible randomization modes, lock & reuse options, and other features.

File diff suppressed because one or more lines are too long

View File

@@ -9,9 +9,9 @@
 "back": "Zurück",
 "next": "Weiter",
 "backToTop": "Nach oben",
-"add": "Hinzufügen",
 "settings": "Einstellungen",
-"help": "Hilfe"
+"help": "Hilfe",
+"add": "Hinzufügen"
 },
 "status": {
 "loading": "Wird geladen...",
@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "Cache-Korruption erkannt"
+},
+"degraded": {
+"title": "Cache-Probleme erkannt"
+},
+"content": "{invalid} von {total} Cache-Einträgen sind ungültig ({rate}). Dies kann zu fehlenden Modellen oder Fehlern führen. Ein Neuaufbau des Caches wird empfohlen.",
+"rebuildCache": "Cache neu aufbauen",
+"dismiss": "Verwerfen",
+"rebuilding": "Cache wird neu aufgebaut...",
+"rebuildFailed": "Fehler beim Neuaufbau des Caches: {error}",
+"retry": "Wiederholen"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "Cache Corruption Detected"
+},
+"degraded": {
+"title": "Cache Issues Detected"
+},
+"content": "{invalid} of {total} cache entries are invalid ({rate}). This may cause missing models or errors. Rebuilding the cache is recommended.",
+"rebuildCache": "Rebuild Cache",
+"dismiss": "Dismiss",
+"rebuilding": "Rebuilding cache...",
+"rebuildFailed": "Failed to rebuild cache: {error}",
+"retry": "Retry"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "Corrupción de caché detectada"
+},
+"degraded": {
+"title": "Problemas de caché detectados"
+},
+"content": "{invalid} de {total} entradas de caché son inválidas ({rate}). Esto puede causar modelos faltantes o errores. Se recomienda reconstruir la caché.",
+"rebuildCache": "Reconstruir caché",
+"dismiss": "Descartar",
+"rebuilding": "Reconstruyendo caché...",
+"rebuildFailed": "Error al reconstruir la caché: {error}",
+"retry": "Reintentar"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "Corruption du cache détectée"
+},
+"degraded": {
+"title": "Problèmes de cache détectés"
+},
+"content": "{invalid} des {total} entrées de cache sont invalides ({rate}). Cela peut provoquer des modèles manquants ou des erreurs. Il est recommandé de reconstruire le cache.",
+"rebuildCache": "Reconstruire le cache",
+"dismiss": "Ignorer",
+"rebuilding": "Reconstruction du cache...",
+"rebuildFailed": "Échec de la reconstruction du cache : {error}",
+"retry": "Réessayer"
 }
 }
 }

View File

@@ -9,9 +9,9 @@
 "back": "חזור",
 "next": "הבא",
 "backToTop": "חזור למעלה",
-"add": "הוסף",
 "settings": "הגדרות",
-"help": "עזרה"
+"help": "עזרה",
+"add": "הוסף"
 },
 "status": {
 "loading": "טוען...",
@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "זוהתה שחיתות במטמון"
+},
+"degraded": {
+"title": "זוהו בעיות במטמון"
+},
+"content": "{invalid} מתוך {total} רשומות מטמון אינן תקינות ({rate}). זה עלול לגרום לדגמים חסרים או לשגיאות. מומלץ לבנות מחדש את המטמון.",
+"rebuildCache": "בניית מטמון מחדש",
+"dismiss": "ביטול",
+"rebuilding": "בונה מחדש את המטמון...",
+"rebuildFailed": "נכשלה בניית המטמון מחדש: {error}",
+"retry": "נסה שוב"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "キャッシュの破損が検出されました"
+},
+"degraded": {
+"title": "キャッシュの問題が検出されました"
+},
+"content": "{total}個のキャッシュエントリのうち{invalid}個が無効です({rate})。モデルが見つからない原因になったり、エラーが発生する可能性があります。キャッシュの再構築を推奨します。",
+"rebuildCache": "キャッシュを再構築",
+"dismiss": "閉じる",
+"rebuilding": "キャッシュを再構築中...",
+"rebuildFailed": "キャッシュの再構築に失敗しました: {error}",
+"retry": "再試行"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "캐시 손상이 감지되었습니다"
+},
+"degraded": {
+"title": "캐시 문제가 감지되었습니다"
+},
+"content": "{total}개의 캐시 항목 중 {invalid}개가 유효하지 않습니다 ({rate}). 모델 누락이나 오류가 발생할 수 있습니다. 캐시를 재구축하는 것이 좋습니다.",
+"rebuildCache": "캐시 재구축",
+"dismiss": "무시",
+"rebuilding": "캐시 재구축 중...",
+"rebuildFailed": "캐시 재구축 실패: {error}",
+"retry": "다시 시도"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "Обнаружено повреждение кэша"
+},
+"degraded": {
+"title": "Обнаружены проблемы с кэшем"
+},
+"content": "{invalid} из {total} записей кэша недействительны ({rate}). Это может привести к отсутствию моделей или ошибкам. Рекомендуется перестроить кэш.",
+"rebuildCache": "Перестроить кэш",
+"dismiss": "Отклонить",
+"rebuilding": "Перестроение кэша...",
+"rebuildFailed": "Не удалось перестроить кэш: {error}",
+"retry": "Повторить"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "来爱发电为Lora Manager项目发电支持项目持续开发的同时获取浏览器插件验证码按季支付更优惠支付宝/微信方便支付。感谢支持!🚀",
 "supportCta": "为LM发电",
 "learnMore": "浏览器插件教程"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "检测到缓存损坏"
+},
+"degraded": {
+"title": "检测到缓存问题"
+},
+"content": "{total} 个缓存条目中有 {invalid} 个无效({rate})。这可能导致模型丢失或错误。建议重建缓存。",
+"rebuildCache": "重建缓存",
+"dismiss": "忽略",
+"rebuilding": "正在重建缓存...",
+"rebuildFailed": "重建缓存失败:{error}",
+"retry": "重试"
 }
 }
 }

View File

@@ -1572,6 +1572,20 @@
 "content": "LoRA Manager is a passion project maintained full-time by a solo developer. Your support on Ko-fi helps cover development costs, keeps new updates coming, and unlocks a license key for the LM Civitai Extension as a thank-you gift. Every contribution truly makes a difference.",
 "supportCta": "Support on Ko-fi",
 "learnMore": "LM Civitai Extension Tutorial"
+},
+"cacheHealth": {
+"corrupted": {
+"title": "檢測到快取損壞"
+},
+"degraded": {
+"title": "檢測到快取問題"
+},
+"content": "{total} 個快取項目中有 {invalid} 個無效({rate})。這可能會導致模型遺失或錯誤。建議重建快取。",
+"rebuildCache": "重建快取",
+"dismiss": "關閉",
+"rebuilding": "重建快取中...",
+"rebuildFailed": "重建快取失敗:{error}",
+"retry": "重試"
 }
 }
 }

View File

@@ -4,7 +4,9 @@
 "private": true,
 "type": "module",
 "scripts": {
-    "test": "vitest run",
+    "test": "npm run test:js && npm run test:vue",
+    "test:js": "vitest run",
+    "test:vue": "cd vue-widgets && npx vitest run",
     "test:watch": "vitest",
     "test:coverage": "node scripts/run_frontend_coverage.js"
 },

View File

@@ -441,82 +441,53 @@ class Config:
         logger.info("Failed to write symlink cache %s: %s", cache_path, exc)

     def _scan_symbolic_links(self):
-        """Scan all symbolic links in LoRA, Checkpoint, and Embedding root directories"""
+        """Scan symbolic links in LoRA, Checkpoint, and Embedding root directories.
+
+        Only scans the first level of each root directory to avoid performance
+        issues with large file systems. Detects symlinks and Windows junctions
+        at the root level only (not nested symlinks in subdirectories).
+        """
         start = time.perf_counter()
         # Reset mappings before rescanning to avoid stale entries
         self._path_mappings.clear()
         self._seed_root_symlink_mappings()
-        visited_dirs: Set[str] = set()
         for root in self._symlink_roots():
-            self._scan_directory_links(root, visited_dirs)
+            self._scan_first_level_symlinks(root)
         logger.debug(
             "Symlink scan finished in %.2f ms with %d mappings",
             (time.perf_counter() - start) * 1000,
             len(self._path_mappings),
         )

-    def _scan_directory_links(self, root: str, visited_dirs: Set[str]):
-        """Iteratively scan directory symlinks to avoid deep recursion."""
-        try:
-            # Note: We only use realpath for the initial root if it's not already resolved
-            # to ensure we have a valid entry point.
-            root_real = self._normalize_path(os.path.realpath(root))
-        except OSError:
-            root_real = self._normalize_path(root)
-        if root_real in visited_dirs:
-            return
-        visited_dirs.add(root_real)
-        # Stack entries: (display_path, real_resolved_path)
-        stack: List[Tuple[str, str]] = [(root, root_real)]
-        while stack:
-            current_display, current_real = stack.pop()
-            try:
-                with os.scandir(current_display) as it:
-                    for entry in it:
-                        try:
-                            # 1. Detect symlinks including Windows junctions
-                            is_link = self._entry_is_symlink(entry)
-                            if is_link:
-                                # Only resolve realpath when we actually find a link
-                                target_path = os.path.realpath(entry.path)
-                                if not os.path.isdir(target_path):
-                                    continue
-                                normalized_target = self._normalize_path(target_path)
-                                self.add_path_mapping(entry.path, target_path)
-                                if normalized_target in visited_dirs:
-                                    continue
-                                visited_dirs.add(normalized_target)
-                                stack.append((target_path, normalized_target))
-                                continue
-                            # 2. Process normal directories
-                            if not entry.is_dir(follow_symlinks=False):
-                                continue
-                            # For normal directories, we avoid realpath() call by
-                            # incrementally building the real path relative to current_real.
-                            # This is safe because 'entry' is NOT a symlink.
-                            entry_real = self._normalize_path(os.path.join(current_real, entry.name))
-                            if entry_real in visited_dirs:
-                                continue
-                            visited_dirs.add(entry_real)
-                            stack.append((entry.path, entry_real))
-                        except Exception as inner_exc:
-                            logger.debug(
-                                "Error processing directory entry %s: %s", entry.path, inner_exc
-                            )
-            except Exception as e:
-                logger.error(f"Error scanning links in {current_display}: {e}")
+    def _scan_first_level_symlinks(self, root: str):
+        """Scan only the first level of a directory for symlinks.
+
+        This avoids traversing the entire directory tree which can be extremely
+        slow for large model collections. Only symlinks directly under the root
+        are detected.
+        """
+        try:
+            with os.scandir(root) as it:
+                for entry in it:
+                    try:
+                        # Only detect symlinks including Windows junctions
+                        # Skip normal directories to avoid deep traversal
+                        if not self._entry_is_symlink(entry):
+                            continue
+                        # Resolve the symlink target
+                        target_path = os.path.realpath(entry.path)
+                        if not os.path.isdir(target_path):
+                            continue
+                        self.add_path_mapping(entry.path, target_path)
+                    except Exception as inner_exc:
+                        logger.debug(
+                            "Error processing directory entry %s: %s", entry.path, inner_exc
+                        )
+        except Exception as e:
+            logger.error(f"Error scanning links in {root}: {e}")
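The refactored first-level scan above can be exercised in isolation. This is a minimal, self-contained sketch of the same pattern, using plain `entry.is_symlink()` in place of the project's `_entry_is_symlink` helper (which additionally detects Windows junctions) and a plain dict in place of `add_path_mapping`:

```python
import os
import tempfile

def first_level_symlinks(root: str) -> dict:
    """Map symlinked directory -> resolved target, one level deep only."""
    mappings = {}
    with os.scandir(root) as it:
        for entry in it:
            # POSIX symlink check; the real code also handles Windows junctions.
            if not entry.is_symlink():
                continue
            target = os.path.realpath(entry.path)
            if os.path.isdir(target):
                mappings[entry.path] = target
    return mappings

# Build a tiny directory tree: one plain dir, one real dir, one symlink to it.
root = tempfile.mkdtemp()
os.mkdir(os.path.join(root, "plain_dir"))          # ignored: not a link
real = os.path.join(root, "real_models")
os.mkdir(real)
os.symlink(real, os.path.join(root, "linked_models"))

mappings = first_level_symlinks(root)
print(len(mappings))  # 1
```

Because only the top level is scanned, a symlink nested inside `plain_dir` would not be found, which is exactly the trade-off the new docstring describes.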

View File

@@ -1,4 +1,7 @@
 import os
+import logging
+
+logger = logging.getLogger(__name__)

 # Check if running in standalone mode
 standalone_mode = os.environ.get("LORA_MANAGER_STANDALONE", "0") == "1" or os.environ.get("HF_HUB_DISABLE_TELEMETRY", "0") == "0"
@@ -14,7 +17,7 @@ if not standalone_mode:
     # Initialize registry
     registry = MetadataRegistry()
-    print("ComfyUI Metadata Collector initialized")
+    logger.info("ComfyUI Metadata Collector initialized")

     def get_metadata(prompt_id=None):
         """Helper function to get metadata from the registry"""
@@ -23,7 +26,7 @@
 else:
     # Standalone mode - provide dummy implementations
     def init():
-        print("ComfyUI Metadata Collector disabled in standalone mode")
+        logger.info("ComfyUI Metadata Collector disabled in standalone mode")

     def get_metadata(prompt_id=None):
         """Dummy implementation for standalone mode"""
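Unlike the bare `print()` calls they replace, module-level loggers inherit the host application's logging configuration (levels, formatting, destinations) and carry the module name. A minimal sketch of what the new calls produce once a handler is attached; the logger name and formatter here are illustrative, since in ComfyUI the host configures logging:

```python
import io
import logging

# Module-level logger, as in the patch above (name is illustrative).
logger = logging.getLogger("lora_manager.metadata_collector")
logger.setLevel(logging.INFO)

# Attach a handler so the messages go somewhere visible.
buf = io.StringIO()
handler = logging.StreamHandler(buf)
handler.setFormatter(logging.Formatter("%(levelname)s:%(name)s:%(message)s"))
logger.addHandler(handler)

logger.info("ComfyUI Metadata Collector initialized")
print(buf.getvalue().strip())
# INFO:lora_manager.metadata_collector:ComfyUI Metadata Collector initialized
```

This is also why the swap matters operationally: users can silence or redirect these messages per module, which `print()` never allowed.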

View File

@@ -1,7 +1,10 @@
 import sys
 import inspect
+import logging
 from .metadata_registry import MetadataRegistry
+
+logger = logging.getLogger(__name__)

 class MetadataHook:
     """Install hooks for metadata collection"""
@@ -23,7 +26,7 @@ class MetadataHook:
         # If we can't find the execution module, we can't install hooks
         if execution is None:
-            print("Could not locate ComfyUI execution module, metadata collection disabled")
+            logger.warning("Could not locate ComfyUI execution module, metadata collection disabled")
             return

         # Detect whether we're using the new async version of ComfyUI
@@ -37,16 +40,16 @@
             is_async = inspect.iscoroutinefunction(execution._map_node_over_list)

             if is_async:
-                print("Detected async ComfyUI execution, installing async metadata hooks")
+                logger.info("Detected async ComfyUI execution, installing async metadata hooks")
                 MetadataHook._install_async_hooks(execution, map_node_func_name)
             else:
-                print("Detected sync ComfyUI execution, installing sync metadata hooks")
+                logger.info("Detected sync ComfyUI execution, installing sync metadata hooks")
                 MetadataHook._install_sync_hooks(execution)
-            print("Metadata collection hooks installed for runtime values")
+            logger.info("Metadata collection hooks installed for runtime values")
         except Exception as e:
-            print(f"Error installing metadata hooks: {str(e)}")
+            logger.error(f"Error installing metadata hooks: {str(e)}")

     @staticmethod
     def _install_sync_hooks(execution):
@@ -82,7 +85,7 @@
             if node_id is not None:
                 registry.record_node_execution(node_id, class_type, input_data_all, None)
         except Exception as e:
-            print(f"Error collecting metadata (pre-execution): {str(e)}")
+            logger.error(f"Error collecting metadata (pre-execution): {str(e)}")

         # Execute the original function
         results = original_map_node_over_list(obj, input_data_all, func, allow_interrupt, execution_block_cb, pre_execute_cb)
@@ -113,7 +116,7 @@
             if node_id is not None:
                 registry.update_node_execution(node_id, class_type, results)
         except Exception as e:
-            print(f"Error collecting metadata (post-execution): {str(e)}")
+            logger.error(f"Error collecting metadata (post-execution): {str(e)}")
         return results
@@ -159,7 +162,7 @@
             if node_id is not None:
                 registry.record_node_execution(node_id, class_type, input_data_all, None)
         except Exception as e:
-            print(f"Error collecting metadata (pre-execution): {str(e)}")
+            logger.error(f"Error collecting metadata (pre-execution): {str(e)}")

         # Call original function with all args/kwargs
         results = await original_map_node_over_list(
@@ -176,7 +179,7 @@
             if node_id is not None:
                 registry.update_node_execution(node_id, class_type, results)
         except Exception as e:
-            print(f"Error collecting metadata (post-execution): {str(e)}")
+            logger.error(f"Error collecting metadata (post-execution): {str(e)}")
         return results

View File

@@ -126,9 +126,7 @@ class LoraCyclerLM:
     "current_index": [clamped_index],
     "next_index": [next_index],
     "total_count": [total_count],
-    "current_lora_name": [
-        current_lora.get("model_name", current_lora["file_name"])
-    ],
+    "current_lora_name": [current_lora["file_name"]],
     "current_lora_filename": [current_lora["file_name"]],
     "next_lora_name": [next_display_name],
     "next_lora_filename": [next_lora["file_name"]],
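The change above narrows `current_lora_name` from a display-name fallback to the bare file name. With a hypothetical metadata record (the field values below are made up), the two expressions differ like this:

```python
# Hypothetical metadata record for a LoRA file.
current_lora = {"file_name": "styleA_v2", "model_name": "Style A (v2)"}

# Before the patch: prefer the human-readable model_name when present.
old_value = current_lora.get("model_name", current_lora["file_name"])
# After the patch: always report the file name.
new_value = current_lora["file_name"]

print(old_value, "->", new_value)  # Style A (v2) -> styleA_v2
```

After the patch, `current_lora_name` and `current_lora_filename` therefore carry the same value.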

View File

@@ -8,6 +8,9 @@ from ..metadata_collector.metadata_processor import MetadataProcessor
 from ..metadata_collector import get_metadata
 from PIL import Image, PngImagePlugin
 import piexif
+import logging
+
+logger = logging.getLogger(__name__)

 class SaveImageLM:
     NAME = "Save Image (LoraManager)"
@@ -385,7 +388,7 @@
                 exif_bytes = piexif.dump(exif_dict)
                 save_kwargs["exif"] = exif_bytes
             except Exception as e:
-                print(f"Error adding EXIF data: {e}")
+                logger.error(f"Error adding EXIF data: {e}")
             img.save(file_path, format="JPEG", **save_kwargs)
         elif file_format == "webp":
             try:
@@ -403,7 +406,7 @@
                 exif_bytes = piexif.dump(exif_dict)
                 save_kwargs["exif"] = exif_bytes
             except Exception as e:
-                print(f"Error adding EXIF data: {e}")
+                logger.error(f"Error adding EXIF data: {e}")
             img.save(file_path, format="WEBP", **save_kwargs)
@@ -414,7 +417,7 @@
             })
     except Exception as e:
-        print(f"Error saving image: {e}")
+        logger.error(f"Error saving image: {e}")
     return results

View File

@@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
     RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
     RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
     RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
+    RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"),
 )

View File

@@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler:
         except ExampleImagesDownloadError as exc:
             return web.json_response({'success': False, 'error': str(exc)}, status=500)
 
+    async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse:
+        """Lightweight check to see if any models need example images downloaded."""
+        try:
+            payload = await request.json()
+            model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding'])
+            result = await self._download_manager.check_pending_models(model_types)
+            return web.json_response(result)
+        except Exception as exc:
+            return web.json_response(
+                {'success': False, 'error': str(exc)},
+                status=500
+            )
+
 
 class ExampleImagesManagementHandler:
     """HTTP adapters for import/delete endpoints."""
@@ -161,6 +174,7 @@ class ExampleImagesHandlerSet:
             "resume_example_images": self.download.resume_example_images,
             "stop_example_images": self.download.stop_example_images,
             "force_download_example_images": self.download.force_download_example_images,
+            "check_example_images_needed": self.download.check_example_images_needed,
             "import_example_images": self.management.import_example_images,
             "delete_example_image": self.management.delete_example_image,
             "set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,

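The handler above delegates to `DownloadManager.check_pending_models`, whose docstring describes counting pending versus processed/failed models. The accounting it describes can be sketched roughly as follows — the flat model list, `sha256` keys, and the `count_pending` helper are simplified stand-ins, not the project's actual data structures:

```python
def count_pending(models: list[dict], processed: set[str], failed: set[str]) -> dict:
    """Rough sketch of the pending-model accounting from the
    check_pending_models docstring; field names are assumptions."""
    total = len(models)
    pending = 0
    for model in models:
        sha = model.get("sha256", "")
        # A model is pending if it has a hash and is neither done nor failed
        if sha and sha not in processed and sha not in failed:
            pending += 1
    return {
        "total_models": total,
        "pending_count": pending,
        "processed_count": len(processed),
        "failed_count": len(failed),
        "needs_download": pending > 0,
    }

result = count_pending(
    [{"sha256": "aaa"}, {"sha256": "bbb"}, {"sha256": "ccc"}],
    processed={"aaa"},
    failed={"bbb"},
)
```

Because this check only reads cached state, the frontend can poll it cheaply before deciding whether to kick off a full download task.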
View File

@@ -0,0 +1,259 @@
"""
Cache Entry Validator

Validates and repairs cache entries to prevent runtime errors from
missing or invalid critical fields.
"""
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Tuple
import logging
import os

logger = logging.getLogger(__name__)


@dataclass
class ValidationResult:
    """Result of validating a single cache entry."""
    is_valid: bool
    repaired: bool
    errors: List[str] = field(default_factory=list)
    entry: Optional[Dict[str, Any]] = None


class CacheEntryValidator:
    """
    Validates and repairs cache entry core fields.

    Critical fields that cause runtime errors when missing:
    - file_path: KeyError in multiple locations
    - sha256: KeyError/AttributeError in hash operations

    Medium severity fields that may cause sorting/display issues:
    - size: KeyError during sorting
    - modified: KeyError during sorting
    - model_name: AttributeError on .lower() calls

    Low severity fields:
    - tags: KeyError/TypeError in recipe operations
    """

    # Field definitions: (default_value, is_required)
    CORE_FIELDS: Dict[str, Tuple[Any, bool]] = {
        'file_path': ('', True),
        'sha256': ('', True),
        'file_name': ('', False),
        'model_name': ('', False),
        'folder': ('', False),
        'size': (0, False),
        'modified': (0.0, False),
        'tags': ([], False),
        'preview_url': ('', False),
        'base_model': ('', False),
        'from_civitai': (True, False),
        'favorite': (False, False),
        'exclude': (False, False),
        'db_checked': (False, False),
        'preview_nsfw_level': (0, False),
        'notes': ('', False),
        'usage_tips': ('', False),
    }

    @classmethod
    def validate(cls, entry: Dict[str, Any], *, auto_repair: bool = True) -> ValidationResult:
        """
        Validate a single cache entry.

        Args:
            entry: The cache entry dictionary to validate
            auto_repair: If True, attempt to repair missing/invalid fields

        Returns:
            ValidationResult with validation status and optionally repaired entry
        """
        if entry is None:
            return ValidationResult(
                is_valid=False,
                repaired=False,
                errors=['Entry is None'],
                entry=None
            )

        if not isinstance(entry, dict):
            return ValidationResult(
                is_valid=False,
                repaired=False,
                errors=[f'Entry is not a dict: {type(entry).__name__}'],
                entry=None
            )

        errors: List[str] = []
        repaired = False
        working_entry = dict(entry) if auto_repair else entry

        for field_name, (default_value, is_required) in cls.CORE_FIELDS.items():
            value = working_entry.get(field_name)

            # Check if field is missing or None
            if value is None:
                if is_required:
                    errors.append(f"Required field '{field_name}' is missing or None")
                if auto_repair:
                    working_entry[field_name] = cls._get_default_copy(default_value)
                    repaired = True
                continue

            # Validate field type and value
            field_error = cls._validate_field(field_name, value, default_value)
            if field_error:
                errors.append(field_error)
                if auto_repair:
                    working_entry[field_name] = cls._get_default_copy(default_value)
                    repaired = True

        # Special validation: file_path must not be empty for required field
        file_path = working_entry.get('file_path', '')
        if not file_path or (isinstance(file_path, str) and not file_path.strip()):
            errors.append("Required field 'file_path' is empty")
            # Cannot repair empty file_path - entry is invalid
            return ValidationResult(
                is_valid=False,
                repaired=repaired,
                errors=errors,
                entry=working_entry if auto_repair else None
            )

        # Special validation: sha256 must not be empty for required field
        sha256 = working_entry.get('sha256', '')
        if not sha256 or (isinstance(sha256, str) and not sha256.strip()):
            errors.append("Required field 'sha256' is empty")
            # Cannot repair empty sha256 - entry is invalid
            return ValidationResult(
                is_valid=False,
                repaired=repaired,
                errors=errors,
                entry=working_entry if auto_repair else None
            )

        # Normalize sha256 to lowercase if needed
        if isinstance(sha256, str):
            normalized_sha = sha256.lower().strip()
            if normalized_sha != sha256:
                working_entry['sha256'] = normalized_sha
                repaired = True

        # Determine if entry is valid
        # Entry is valid if no critical required field errors remain after repair
        # Critical fields are file_path and sha256
        CRITICAL_REQUIRED_FIELDS = {'file_path', 'sha256'}
        has_critical_errors = any(
            "Required field" in error and
            any(f"'{field}'" in error for field in CRITICAL_REQUIRED_FIELDS)
            for error in errors
        )
        is_valid = not has_critical_errors

        return ValidationResult(
            is_valid=is_valid,
            repaired=repaired,
            errors=errors,
            entry=working_entry if auto_repair else entry
        )

    @classmethod
    def validate_batch(
        cls,
        entries: List[Dict[str, Any]],
        *,
        auto_repair: bool = True
    ) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]]:
        """
        Validate a batch of cache entries.

        Args:
            entries: List of cache entry dictionaries to validate
            auto_repair: If True, attempt to repair missing/invalid fields

        Returns:
            Tuple of (valid_entries, invalid_entries)
        """
        if not entries:
            return [], []

        valid_entries: List[Dict[str, Any]] = []
        invalid_entries: List[Dict[str, Any]] = []

        for entry in entries:
            result = cls.validate(entry, auto_repair=auto_repair)
            if result.is_valid:
                # Use repaired entry if available, otherwise original
                valid_entries.append(result.entry if result.entry else entry)
            else:
                invalid_entries.append(entry)
                # Log invalid entries for debugging
                file_path = entry.get('file_path', '<unknown>') if isinstance(entry, dict) else '<not a dict>'
                logger.warning(
                    f"Invalid cache entry for '{file_path}': {', '.join(result.errors)}"
                )

        return valid_entries, invalid_entries

    @classmethod
    def _validate_field(cls, field_name: str, value: Any, default_value: Any) -> Optional[str]:
        """
        Validate a specific field value.

        Returns an error message if invalid, None if valid.
        """
        expected_type = type(default_value)

        # Special handling for numeric types
        if expected_type == int:
            if not isinstance(value, (int, float)):
                return f"Field '{field_name}' should be numeric, got {type(value).__name__}"
        elif expected_type == float:
            if not isinstance(value, (int, float)):
                return f"Field '{field_name}' should be numeric, got {type(value).__name__}"
        elif expected_type == bool:
            # Be lenient with boolean fields - accept truthy/falsy values
            pass
        elif expected_type == str:
            if not isinstance(value, str):
                return f"Field '{field_name}' should be string, got {type(value).__name__}"
        elif expected_type == list:
            if not isinstance(value, (list, tuple)):
                return f"Field '{field_name}' should be list, got {type(value).__name__}"

        return None

    @classmethod
    def _get_default_copy(cls, default_value: Any) -> Any:
        """Get a copy of the default value to avoid shared mutable state."""
        if isinstance(default_value, list):
            return list(default_value)
        if isinstance(default_value, dict):
            return dict(default_value)
        return default_value

    @classmethod
    def get_file_path_safe(cls, entry: Dict[str, Any], default: str = '') -> str:
        """Safely get file_path from an entry."""
        if not isinstance(entry, dict):
            return default
        value = entry.get('file_path')
        if isinstance(value, str):
            return value
        return default

    @classmethod
    def get_sha256_safe(cls, entry: Dict[str, Any], default: str = '') -> str:
        """Safely get sha256 from an entry."""
        if not isinstance(entry, dict):
            return default
        value = entry.get('sha256')
        if isinstance(value, str):
            return value.lower()
        return default

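The repair-then-reject flow of the new validator can be seen in a condensed, standalone stand-in (the real class lives inside the plugin; this sketch reimplements only a few default fields and the two critical checks):

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class ValidationResult:
    is_valid: bool
    repaired: bool
    errors: list = field(default_factory=list)
    entry: Optional[dict] = None

def validate(entry: dict, auto_repair: bool = True) -> ValidationResult:
    """Condensed stand-in for CacheEntryValidator.validate: repair optional
    fields with defaults, reject entries whose critical fields are empty."""
    defaults = {"size": 0, "modified": 0.0, "tags": []}
    if not isinstance(entry, dict):
        return ValidationResult(False, False, ["not a dict"], None)
    working = dict(entry)
    repaired = False
    for name, default in defaults.items():
        if working.get(name) is None:
            # Copy list defaults so entries never share mutable state
            working[name] = list(default) if isinstance(default, list) else default
            repaired = True
    # Critical fields cannot be repaired: an empty value invalidates the entry
    for critical in ("file_path", "sha256"):
        value = working.get(critical) or ""
        if not str(value).strip():
            return ValidationResult(False, repaired, [f"'{critical}' is empty"], working)
    # Normalize the hash, mirroring the validator's lowercase rule
    working["sha256"] = str(working["sha256"]).lower().strip()
    return ValidationResult(True, repaired, [], working)

good = validate({"file_path": "/loras/x.safetensors", "sha256": "ABCD"})
bad = validate({"file_path": "", "sha256": "abcd"})
```

The design point is the asymmetry: display and sort fields get safe defaults so one stale entry cannot crash a whole page, while `file_path` and `sha256` are identity fields with no sensible default, so their absence marks the entry itself as unusable.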
View File

@@ -0,0 +1,201 @@
"""
Cache Health Monitor

Monitors cache health status and determines when user intervention is needed.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any, Dict, List, Optional
import logging

from .cache_entry_validator import CacheEntryValidator, ValidationResult

logger = logging.getLogger(__name__)


class CacheHealthStatus(Enum):
    """Health status of the cache."""
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    CORRUPTED = "corrupted"


@dataclass
class HealthReport:
    """Report of cache health check."""
    status: CacheHealthStatus
    total_entries: int
    valid_entries: int
    invalid_entries: int
    repaired_entries: int
    invalid_paths: List[str] = field(default_factory=list)
    message: str = ""

    @property
    def corruption_rate(self) -> float:
        """Calculate the percentage of invalid entries."""
        if self.total_entries <= 0:
            return 0.0
        return self.invalid_entries / self.total_entries

    def to_dict(self) -> Dict[str, Any]:
        """Convert to dictionary for JSON serialization."""
        return {
            'status': self.status.value,
            'total_entries': self.total_entries,
            'valid_entries': self.valid_entries,
            'invalid_entries': self.invalid_entries,
            'repaired_entries': self.repaired_entries,
            'corruption_rate': f"{self.corruption_rate:.1%}",
            'invalid_paths': self.invalid_paths[:10],  # Limit to first 10
            'message': self.message,
        }


class CacheHealthMonitor:
    """
    Monitors cache health and determines appropriate status.

    Thresholds:
    - HEALTHY: 0% invalid entries
    - DEGRADED: 0-5% invalid entries (auto-repaired, user should rebuild)
    - CORRUPTED: >5% invalid entries (significant data loss likely)
    """

    # Threshold percentages
    DEGRADED_THRESHOLD = 0.01  # 1% - show warning
    CORRUPTED_THRESHOLD = 0.05  # 5% - critical warning

    def __init__(
        self,
        *,
        degraded_threshold: float = DEGRADED_THRESHOLD,
        corrupted_threshold: float = CORRUPTED_THRESHOLD
    ):
        """
        Initialize the health monitor.

        Args:
            degraded_threshold: Corruption rate threshold for DEGRADED status
            corrupted_threshold: Corruption rate threshold for CORRUPTED status
        """
        self.degraded_threshold = degraded_threshold
        self.corrupted_threshold = corrupted_threshold

    def check_health(
        self,
        entries: List[Dict[str, Any]],
        *,
        auto_repair: bool = True
    ) -> HealthReport:
        """
        Check the health of cache entries.

        Args:
            entries: List of cache entry dictionaries to check
            auto_repair: If True, attempt to repair entries during validation

        Returns:
            HealthReport with status and statistics
        """
        if not entries:
            return HealthReport(
                status=CacheHealthStatus.HEALTHY,
                total_entries=0,
                valid_entries=0,
                invalid_entries=0,
                repaired_entries=0,
                message="Cache is empty"
            )

        total_entries = len(entries)
        valid_entries: List[Dict[str, Any]] = []
        invalid_entries: List[Dict[str, Any]] = []
        repaired_count = 0
        invalid_paths: List[str] = []

        for entry in entries:
            result = CacheEntryValidator.validate(entry, auto_repair=auto_repair)
            if result.is_valid:
                valid_entries.append(result.entry if result.entry else entry)
                if result.repaired:
                    repaired_count += 1
            else:
                invalid_entries.append(entry)
                # Extract file path for reporting
                file_path = CacheEntryValidator.get_file_path_safe(entry, '<unknown>')
                invalid_paths.append(file_path)

        invalid_count = len(invalid_entries)
        valid_count = len(valid_entries)

        # Determine status based on corruption rate
        corruption_rate = invalid_count / total_entries if total_entries > 0 else 0.0

        if invalid_count == 0:
            status = CacheHealthStatus.HEALTHY
            message = "Cache is healthy"
        elif corruption_rate >= self.corrupted_threshold:
            status = CacheHealthStatus.CORRUPTED
            message = (
                f"Cache is corrupted: {invalid_count} invalid entries "
                f"({corruption_rate:.1%}). Rebuild recommended."
            )
        elif corruption_rate >= self.degraded_threshold or invalid_count > 0:
            status = CacheHealthStatus.DEGRADED
            message = (
                f"Cache has {invalid_count} invalid entries "
                f"({corruption_rate:.1%}). Consider rebuilding cache."
            )
        else:
            # This shouldn't happen, but handle gracefully
            status = CacheHealthStatus.HEALTHY
            message = "Cache is healthy"

        # Log the health check result
        if status != CacheHealthStatus.HEALTHY:
            logger.warning(
                f"Cache health check: {status.value} - "
                f"{invalid_count}/{total_entries} invalid, "
                f"{repaired_count} repaired"
            )
            if invalid_paths:
                logger.debug(f"Invalid entry paths: {invalid_paths[:5]}")

        return HealthReport(
            status=status,
            total_entries=total_entries,
            valid_entries=valid_count,
            invalid_entries=invalid_count,
            repaired_entries=repaired_count,
            invalid_paths=invalid_paths,
            message=message
        )

    def should_notify_user(self, report: HealthReport) -> bool:
        """
        Determine if the user should be notified about cache health.

        Args:
            report: The health report to evaluate

        Returns:
            True if user should be notified
        """
        return report.status != CacheHealthStatus.HEALTHY

    def get_notification_severity(self, report: HealthReport) -> str:
        """
        Get the severity level for user notification.

        Args:
            report: The health report to evaluate

        Returns:
            Severity string: 'warning' or 'error'
        """
        if report.status == CacheHealthStatus.CORRUPTED:
            return 'error'
        return 'warning'

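The status decision inside `check_health` reduces to a small classification rule: any invalid entry at all degrades the cache, and a rate at or above 5% marks it corrupted. A condensed restatement of just that rule (the `classify` helper is a sketch, not the monitor's API):

```python
from enum import Enum

class Status(Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    CORRUPTED = "corrupted"

def classify(invalid: int, total: int,
             degraded_at: float = 0.01, corrupted_at: float = 0.05) -> Status:
    """Mirror of the status decision in check_health: any invalid entry
    degrades the cache; a rate >= corrupted_at marks it corrupted."""
    if total <= 0 or invalid == 0:
        return Status.HEALTHY
    rate = invalid / total
    if rate >= corrupted_at:
        return Status.CORRUPTED
    # The `or invalid_count > 0` branch in check_health means even one
    # invalid entry below the degraded threshold still reports DEGRADED
    return Status.DEGRADED
```

Note that `degraded_at` effectively never gates anything in the original code either, because the `invalid_count > 0` disjunct already catches every nonzero count; the threshold mainly documents intent and leaves room for tuning.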
View File

@@ -30,36 +30,36 @@ class LoraScanner(ModelScanner):
     async def diagnose_hash_index(self):
         """Diagnostic method to verify hash index functionality"""
-        print("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n", file=sys.stderr)
+        logger.debug("\n\n*** DIAGNOSING LORA HASH INDEX ***\n\n")
 
         # First check if the hash index has any entries
         if hasattr(self, '_hash_index'):
             index_entries = len(self._hash_index._hash_to_path)
-            print(f"Hash index has {index_entries} entries", file=sys.stderr)
+            logger.debug(f"Hash index has {index_entries} entries")
 
             # Print a few example entries if available
             if index_entries > 0:
-                print("\nSample hash index entries:", file=sys.stderr)
+                logger.debug("\nSample hash index entries:")
                 count = 0
                 for hash_val, path in self._hash_index._hash_to_path.items():
                     if count < 5:  # Just show the first 5
-                        print(f"Hash: {hash_val[:8]}... -> Path: {path}", file=sys.stderr)
+                        logger.debug(f"Hash: {hash_val[:8]}... -> Path: {path}")
                         count += 1
                     else:
                         break
         else:
-            print("Hash index not initialized", file=sys.stderr)
+            logger.debug("Hash index not initialized")
 
         # Try looking up by a known hash for testing
         if not hasattr(self, '_hash_index') or not self._hash_index._hash_to_path:
-            print("No hash entries to test lookup with", file=sys.stderr)
+            logger.debug("No hash entries to test lookup with")
             return
 
         test_hash = next(iter(self._hash_index._hash_to_path.keys()))
         test_path = self._hash_index.get_path(test_hash)
-        print(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}", file=sys.stderr)
+        logger.debug(f"\nTest lookup by hash: {test_hash[:8]}... -> {test_path}")
 
         # Also test reverse lookup
         test_hash_result = self._hash_index.get_hash(test_path)
-        print(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n", file=sys.stderr)
+        logger.debug(f"Test reverse lookup: {test_path} -> {test_hash_result[:8]}...\n\n")

View File

@@ -5,7 +5,6 @@ import logging
 logger = logging.getLogger(__name__)
 from typing import Any, Dict, List, Optional, Tuple
 from dataclasses import dataclass, field
-from operator import itemgetter
 from natsort import natsorted
 
 # Supported sort modes: (sort_key, order)
@@ -229,17 +228,17 @@ class ModelCache:
                 reverse=reverse
             )
         elif sort_key == 'date':
-            # Sort by modified timestamp
+            # Sort by modified timestamp (use .get() with default to handle missing fields)
             result = sorted(
                 data,
-                key=itemgetter('modified'),
+                key=lambda x: x.get('modified', 0.0),
                 reverse=reverse
             )
         elif sort_key == 'size':
-            # Sort by file size
+            # Sort by file size (use .get() with default to handle missing fields)
             result = sorted(
                 data,
-                key=itemgetter('size'),
+                key=lambda x: x.get('size', 0),
                 reverse=reverse
             )
         elif sort_key == 'usage':

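The switch from `itemgetter` to a lambda with `.get()` is what keeps sorting alive when a legacy cache entry lacks `size` or `modified`. The difference in one small demonstration (the entry dicts are illustrative):

```python
from operator import itemgetter

entries = [
    {"model_name": "a", "size": 10, "modified": 2.0},
    {"model_name": "b"},  # legacy entry missing size/modified
    {"model_name": "c", "size": 5, "modified": 1.0},
]

# itemgetter('size') raises KeyError on the incomplete entry,
# aborting the whole sort
try:
    sorted(entries, key=itemgetter("size"))
    raised = False
except KeyError:
    raised = True

# .get() with a default tolerates it, ranking missing values as 0
by_size = sorted(entries, key=lambda x: x.get("size", 0))
```

This pairs with the validator work above: even with repair in place, a defensive sort key means one unrepaired entry degrades ordering instead of crashing the listing.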
View File

@@ -20,6 +20,8 @@ from .service_registry import ServiceRegistry
 from .websocket_manager import ws_manager
 from .persistent_model_cache import get_persistent_cache
 from .settings_manager import get_settings_manager
+from .cache_entry_validator import CacheEntryValidator
+from .cache_health_monitor import CacheHealthMonitor, CacheHealthStatus
 
 logger = logging.getLogger(__name__)
@@ -468,6 +470,39 @@ class ModelScanner:
             for tag in adjusted_item.get('tags') or []:
                 tags_count[tag] = tags_count.get(tag, 0) + 1
 
+        # Validate cache entries and check health
+        valid_entries, invalid_entries = CacheEntryValidator.validate_batch(
+            adjusted_raw_data, auto_repair=True
+        )
+        if invalid_entries:
+            monitor = CacheHealthMonitor()
+            report = monitor.check_health(adjusted_raw_data, auto_repair=True)
+            if report.status != CacheHealthStatus.HEALTHY:
+                # Broadcast health warning to frontend
+                await ws_manager.broadcast_cache_health_warning(report, page_type)
+                logger.warning(
+                    f"{self.model_type.capitalize()} Scanner: Cache health issue detected - "
+                    f"{report.invalid_entries} invalid entries, {report.repaired_entries} repaired"
+                )
+            # Use only valid entries
+            adjusted_raw_data = valid_entries
+            # Rebuild tags count from valid entries only
+            tags_count = {}
+            for item in adjusted_raw_data:
+                for tag in item.get('tags') or []:
+                    tags_count[tag] = tags_count.get(tag, 0) + 1
+            # Remove invalid entries from hash index
+            for invalid_entry in invalid_entries:
+                file_path = CacheEntryValidator.get_file_path_safe(invalid_entry)
+                sha256 = CacheEntryValidator.get_sha256_safe(invalid_entry)
+                if file_path:
+                    hash_index.remove_by_path(file_path, sha256)
+
         scan_result = CacheBuildResult(
             raw_data=adjusted_raw_data,
             hash_index=hash_index,
@@ -651,7 +686,6 @@ class ModelScanner:
     async def _initialize_cache(self) -> None:
         """Initialize or refresh the cache"""
-        print("init start", flush=True)
         self._is_initializing = True  # Set flag
         try:
             start_time = time.time()
@@ -665,7 +699,6 @@ class ModelScanner:
             scan_result = await self._gather_model_data()
             await self._apply_scan_result(scan_result)
             await self._save_persistent_cache(scan_result)
-            print("init end", flush=True)
 
             logger.info(
                 f"{self.model_type.capitalize()} Scanner: Cache initialization completed in {time.time() - start_time:.2f} seconds, "
@@ -776,6 +809,18 @@ class ModelScanner:
             model_data = self.adjust_cached_entry(dict(model_data))
             if not model_data:
                 continue
 
+            # Validate the new entry before adding
+            validation_result = CacheEntryValidator.validate(
+                model_data, auto_repair=True
+            )
+            if not validation_result.is_valid:
+                logger.warning(
+                    f"Skipping invalid entry during reconcile: {path}"
+                )
+                continue
+            model_data = validation_result.entry
+
             self._ensure_license_flags(model_data)
             # Add to cache
             self._cache.raw_data.append(model_data)
@@ -1090,6 +1135,17 @@ class ModelScanner:
                     processed_files += 1
                     if result:
+                        # Validate the entry before adding
+                        validation_result = CacheEntryValidator.validate(
+                            result, auto_repair=True
+                        )
+                        if not validation_result.is_valid:
+                            logger.warning(
+                                f"Skipping invalid scan result: {file_path}"
+                            )
+                            continue
+                        result = validation_result.entry
+
                         self._ensure_license_flags(result)
                         raw_data.append(result)

View File

@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple
 from ..config import config
 from .recipe_cache import RecipeCache
 from .recipe_fts_index import RecipeFTSIndex
-from .persistent_recipe_cache import PersistentRecipeCache, get_persistent_recipe_cache
+from .persistent_recipe_cache import PersistentRecipeCache, get_persistent_recipe_cache, PersistedRecipeData
 from .service_registry import ServiceRegistry
 from .lora_scanner import LoraScanner
 from .metadata_service import get_default_metadata_provider
@@ -431,6 +431,16 @@ class RecipeScanner:
         4. Persist results for next startup
         """
         try:
+            # Ensure cache exists to avoid None reference errors
+            if self._cache is None:
+                self._cache = RecipeCache(
+                    raw_data=[],
+                    sorted_by_name=[],
+                    sorted_by_date=[],
+                    folders=[],
+                    folder_tree={},
+                )
+
             # Create a new event loop for this thread
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
@@ -492,7 +502,7 @@ class RecipeScanner:
     def _reconcile_recipe_cache(
         self,
-        persisted: "PersistedRecipeData",
+        persisted: PersistedRecipeData,
         recipes_dir: str,
     ) -> Tuple[List[Dict], bool, Dict[str, str]]:
         """Reconcile persisted cache with current filesystem state.
@@ -504,8 +514,6 @@ class RecipeScanner:
         Returns:
             Tuple of (recipes list, changed flag, json_paths dict).
         """
-        from .persistent_recipe_cache import PersistedRecipeData
-
         recipes: List[Dict] = []
         json_paths: Dict[str, str] = {}
         changed = False
@@ -522,32 +530,37 @@ class RecipeScanner:
             except OSError:
                 continue
 
-        # Build lookup of persisted recipes by json_path
-        persisted_by_path: Dict[str, Dict] = {}
-        for recipe in persisted.raw_data:
-            recipe_id = str(recipe.get('id', ''))
-            if recipe_id:
-                # Find the json_path from file_stats
-                for json_path, (mtime, size) in persisted.file_stats.items():
-                    if os.path.basename(json_path).startswith(recipe_id):
-                        persisted_by_path[json_path] = recipe
-                        break
-
-        # Also index by recipe ID for faster lookups
-        persisted_by_id: Dict[str, Dict] = {
+        # Build recipe_id -> recipe lookup (O(n) instead of O(n²))
+        recipe_by_id: Dict[str, Dict] = {
             str(r.get('id', '')): r for r in persisted.raw_data if r.get('id')
         }
+        # Build json_path -> recipe lookup from file_stats (O(m))
+        persisted_by_path: Dict[str, Dict] = {}
+        for json_path in persisted.file_stats.keys():
+            basename = os.path.basename(json_path)
+            if basename.lower().endswith('.recipe.json'):
+                recipe_id = basename[:-len('.recipe.json')]
+                if recipe_id in recipe_by_id:
+                    persisted_by_path[json_path] = recipe_by_id[recipe_id]
 
         # Process current files
         for file_path, (current_mtime, current_size) in current_files.items():
             cached_stats = persisted.file_stats.get(file_path)
+            # Extract recipe_id from current file for fallback lookup
+            basename = os.path.basename(file_path)
+            recipe_id_from_file = basename[:-len('.recipe.json')] if basename.lower().endswith('.recipe.json') else None
 
             if cached_stats:
                 cached_mtime, cached_size = cached_stats
                 # Check if file is unchanged
                 if abs(current_mtime - cached_mtime) < 1.0 and current_size == cached_size:
-                    # Use cached data
+                    # Try direct path lookup first
                     cached_recipe = persisted_by_path.get(file_path)
+                    # Fallback to recipe_id lookup if path lookup fails
+                    if not cached_recipe and recipe_id_from_file:
+                        cached_recipe = recipe_by_id.get(recipe_id_from_file)
                     if cached_recipe:
                         recipe_id = str(cached_recipe.get('id', ''))
                         # Track folder from file path

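The O(n²) → O(n + m) refactor above hinges on two single-pass dictionaries: recipes indexed by id, then file paths mapped to recipes via the `<id>.recipe.json` naming convention. A standalone sketch of just that lookup construction (the sample data is illustrative; the real method also handles mtime/size reconciliation):

```python
import os

def build_lookup(raw_data: list[dict], file_stats: dict[str, tuple]) -> dict[str, dict]:
    """Sketch of the optimized reconciliation lookup: one pass over recipes
    to index by id, one pass over file_stats to map json paths to recipes."""
    # O(n): index persisted recipes by their string id
    recipe_by_id = {str(r.get("id", "")): r for r in raw_data if r.get("id")}
    # O(m): derive each path's recipe id from the <id>.recipe.json convention
    persisted_by_path = {}
    for json_path in file_stats:
        basename = os.path.basename(json_path)
        if basename.lower().endswith(".recipe.json"):
            recipe_id = basename[: -len(".recipe.json")]
            if recipe_id in recipe_by_id:
                persisted_by_path[json_path] = recipe_by_id[recipe_id]
    return persisted_by_path

recipes = [{"id": "r1", "title": "mix"}, {"id": "r2", "title": "blend"}]
stats = {"/recipes/r1.recipe.json": (0.0, 10), "/recipes/other.json": (0.0, 5)}
lookup = build_lookup(recipes, stats)
```

The old code instead scanned all of `file_stats` once per recipe with a `startswith` match, which is quadratic and also prone to prefix collisions (`startswith("r1")` matches `r12.recipe.json`); exact suffix stripping avoids both problems.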
View File

@@ -63,7 +63,7 @@ DEFAULT_SETTINGS: Dict[str, Any] = {
     "compact_mode": False,
     "priority_tags": DEFAULT_PRIORITY_TAG_CONFIG.copy(),
     "model_name_display": "model_name",
-    "model_card_footer_action": "example_images",
+    "model_card_footer_action": "replace_preview",
     "update_flag_strategy": "same_base",
     "auto_organize_exclusions": [],
 }

View File

@@ -48,9 +48,14 @@ class BulkMetadataRefreshUseCase:
             for model in cache.raw_data
             if model.get("sha256")
             and (not model.get("civitai") or not model["civitai"].get("id"))
-            and (
-                (enable_metadata_archive_db and not model.get("db_checked", False))
-                or (not enable_metadata_archive_db and model.get("from_civitai") is True)
-            )
+            and not (
+                # Skip models confirmed not on CivitAI when no need to retry
+                model.get("from_civitai") is False
+                and model.get("civitai_deleted") is True
+                and (
+                    not enable_metadata_archive_db
+                    or model.get("db_checked", False)
+                )
+            )
         ]

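Per the commit message, the fix stops the bulk refresh from repeatedly re-fetching models already confirmed absent from CivitAI. The new filter can be read as a predicate (a sketch; the real logic is the list comprehension inside `BulkMetadataRefreshUseCase`, and `should_refresh` is a name introduced here for illustration):

```python
def should_refresh(model: dict, enable_metadata_archive_db: bool) -> bool:
    """Sketch of the corrected filter: a model is a refresh candidate only if
    it has a hash, lacks CivitAI metadata, and is not a confirmed miss."""
    civitai = model.get("civitai") or {}
    if not model.get("sha256") or civitai.get("id"):
        return False
    # A "confirmed miss" is skipped unless the archive DB is enabled and
    # hasn't been consulted for this model yet (db_checked is False)
    confirmed_missing = (
        model.get("from_civitai") is False
        and model.get("civitai_deleted") is True
        and (not enable_metadata_archive_db or model.get("db_checked", False))
    )
    return not confirmed_missing

unknown = {"sha256": "abc"}
confirmed_gone = {"sha256": "def", "from_civitai": False, "civitai_deleted": True}
```

So a confirmed-missing model is retried exactly once more when the archive DB is newly enabled, and never again after `db_checked` flips to True.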
View File

@@ -255,6 +255,42 @@ class WebSocketManager:
             self._download_progress.pop(download_id, None)
             logger.debug(f"Cleaned up old download progress for {download_id}")
 
+    async def broadcast_cache_health_warning(self, report: 'HealthReport', page_type: str = None):
+        """
+        Broadcast cache health warning to frontend.
+
+        Args:
+            report: HealthReport instance from CacheHealthMonitor
+            page_type: The page type (loras, checkpoints, embeddings)
+        """
+        from .cache_health_monitor import CacheHealthStatus
+
+        # Only broadcast if there are issues
+        if report.status == CacheHealthStatus.HEALTHY:
+            return
+
+        payload = {
+            'type': 'cache_health_warning',
+            'status': report.status.value,
+            'message': report.message,
+            'pageType': page_type,
+            'details': {
+                'total': report.total_entries,
+                'valid': report.valid_entries,
+                'invalid': report.invalid_entries,
+                'repaired': report.repaired_entries,
+                'corruption_rate': f"{report.corruption_rate:.1%}",
+                'invalid_paths': report.invalid_paths[:5],  # Limit to first 5
+            }
+        }
+
+        logger.info(
+            f"Broadcasting cache health warning: {report.status.value} "
+            f"({report.invalid_entries} invalid entries)"
+        )
+        await self.broadcast(payload)
+
     def get_connected_clients_count(self) -> int:
         """Get number of connected clients"""
         return len(self._websockets)

View File

@@ -216,6 +216,11 @@ class DownloadManager:
self._progress["failed_models"] = set() self._progress["failed_models"] = set()
self._is_downloading = True self._is_downloading = True
snapshot = self._progress.snapshot()
# Create the download task without awaiting it
# This ensures the HTTP response is returned immediately
# while the actual processing happens in the background
self._download_task = asyncio.create_task( self._download_task = asyncio.create_task(
self._download_all_example_images( self._download_all_example_images(
output_dir, output_dir,
@@ -227,7 +232,10 @@ class DownloadManager:
) )
) )
snapshot = self._progress.snapshot() # Add a callback to handle task completion/errors
self._download_task.add_done_callback(
lambda t: self._handle_download_task_done(t, output_dir)
)
except ExampleImagesDownloadError: except ExampleImagesDownloadError:
# Re-raise our own exception types without wrapping # Re-raise our own exception types without wrapping
self._is_downloading = False self._is_downloading = False
@@ -241,10 +249,25 @@ class DownloadManager:
) )
raise ExampleImagesDownloadError(str(e)) from e raise ExampleImagesDownloadError(str(e)) from e
await self._broadcast_progress(status="running") # Broadcast progress in the background without blocking the response
# This ensures the HTTP response is returned immediately
asyncio.create_task(self._broadcast_progress(status="running"))
return {"success": True, "message": "Download started", "status": snapshot} return {"success": True, "message": "Download started", "status": snapshot}
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
"""Handle download task completion, including saving progress on error."""
try:
# This will re-raise any exception from the task
task.result()
except Exception as e:
logger.error(f"Download task failed with error: {e}", exc_info=True)
# Ensure progress is saved even on failure
try:
self._save_progress(output_dir)
except Exception as save_error:
logger.error(f"Failed to save progress after task failure: {save_error}")
    async def get_status(self, request):
        """Get the current status of example images download."""
@@ -254,6 +277,130 @@ class DownloadManager:
            "status": self._progress.snapshot(),
        }
async def check_pending_models(self, model_types: list[str]) -> dict:
"""Quickly check how many models need example images downloaded.
This is a lightweight check that avoids the overhead of starting
a full download task when no work is needed.
Returns:
dict with keys:
- total_models: Total number of models across specified types
- pending_count: Number of models needing example images
- processed_count: Number of already processed models
- failed_count: Number of models marked as failed
- needs_download: True if there are pending models to process
"""
from ..services.service_registry import ServiceRegistry
if self._is_downloading:
return {
"success": True,
"is_downloading": True,
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
"message": "Download already in progress",
}
try:
# Get scanners
scanners = []
if "lora" in model_types:
lora_scanner = await ServiceRegistry.get_lora_scanner()
scanners.append(("lora", lora_scanner))
if "checkpoint" in model_types:
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
scanners.append(("checkpoint", checkpoint_scanner))
if "embedding" in model_types:
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
scanners.append(("embedding", embedding_scanner))
# Load progress file to check processed models
settings_manager = get_settings_manager()
active_library = settings_manager.get_active_library_name()
output_dir = self._resolve_output_dir(active_library)
processed_models: set[str] = set()
failed_models: set[str] = set()
if output_dir:
progress_file = os.path.join(output_dir, ".download_progress.json")
if os.path.exists(progress_file):
try:
with open(progress_file, "r", encoding="utf-8") as f:
saved_progress = json.load(f)
processed_models = set(saved_progress.get("processed_models", []))
failed_models = set(saved_progress.get("failed_models", []))
except Exception:
pass # Ignore progress file errors for quick check
# Count models
total_models = 0
models_with_hash = 0
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
total_models += 1
if model.get("sha256"):
models_with_hash += 1
# Calculate pending count: a model is pending if it has a hash, is not in
# processed_models, and its example folder does not already contain files
pending_hashes = set()
for scanner_type, scanner in scanners:
cache = await scanner.get_cached_data()
if cache and cache.raw_data:
for model in cache.raw_data:
raw_hash = model.get("sha256")
if not raw_hash:
continue
model_hash = raw_hash.lower()
if model_hash not in processed_models:
# Check if model folder exists with files
model_dir = ExampleImagePathResolver.get_model_folder(
model_hash, active_library
)
if not _model_directory_has_files(model_dir):
pending_hashes.add(model_hash)
pending_count = len(pending_hashes)
return {
"success": True,
"is_downloading": False,
"total_models": total_models,
"pending_count": pending_count,
"processed_count": len(processed_models),
"failed_count": len(failed_models),
"needs_download": pending_count > 0,
}
except Exception as e:
logger.error(f"Error checking pending models: {e}", exc_info=True)
return {
"success": False,
"error": str(e),
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
}
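A minimal sketch of the progress-file bookkeeping `check_pending_models` relies on, assuming the `.download_progress.json` layout read above (`processed_models` and `failed_models` as lists of hashes); the helper names are invented for illustration:

```python
import json
import os
import tempfile

def load_progress(progress_file: str) -> tuple[set[str], set[str]]:
    """Return (processed, failed) hash sets; missing or corrupt files yield empty sets."""
    if not os.path.exists(progress_file):
        return set(), set()
    try:
        with open(progress_file, "r", encoding="utf-8") as f:
            saved = json.load(f)
        return set(saved.get("processed_models", [])), set(saved.get("failed_models", []))
    except Exception:
        return set(), set()

def pending_hashes(all_hashes: list[str], processed: set[str]) -> set[str]:
    # A model is pending when it has a hash that was never processed
    return {h.lower() for h in all_hashes if h} - processed

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, ".download_progress.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"processed_models": ["aaa"], "failed_models": ["bbb"]}, f)
    processed, failed = load_progress(path)
    pending = pending_hashes(["AAA", "bbb", "ccc"], processed)
```

Swallowing progress-file errors, as the real method does, is deliberate: a corrupt progress file should degrade to "everything looks pending" rather than abort the quick check.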
    async def pause_download(self, request):
        """Pause the example images download."""

View File

@@ -43,8 +43,15 @@ class ExampleImagesProcessor:
        return media_url

    @staticmethod
-   def _get_file_extension_from_content_or_headers(content, headers, fallback_url=None):
-       """Determine file extension from content magic bytes or headers"""
    def _get_file_extension_from_content_or_headers(content, headers, fallback_url=None, media_type_hint=None):
        """Determine file extension from content magic bytes or headers

        Args:
            content: File content bytes
            headers: HTTP response headers
            fallback_url: Original URL for extension extraction
            media_type_hint: Optional media type hint from metadata (e.g., "video" or "image")
        """
        # Check magic bytes for common formats
        if content:
            if content.startswith(b'\xFF\xD8\xFF'):
@@ -82,6 +89,10 @@ class ExampleImagesProcessor:
            if ext in SUPPORTED_MEDIA_EXTENSIONS['images'] or ext in SUPPORTED_MEDIA_EXTENSIONS['videos']:
                return ext

        # Use media type hint from metadata if available
        if media_type_hint == "video":
            return '.mp4'

        # Default fallback
        return '.jpg'
@@ -136,7 +147,7 @@ class ExampleImagesProcessor:
        if success:
            # Determine file extension from content or headers
            media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
-               content, headers, original_url
                content, headers, original_url, image.get("type")
            )
            # Check if the detected file type is supported
@@ -219,7 +230,7 @@ class ExampleImagesProcessor:
        if success:
            # Determine file extension from content or headers
            media_ext = ExampleImagesProcessor._get_file_extension_from_content_or_headers(
-               content, headers, original_url
                content, headers, original_url, image.get("type")
            )
            # Check if the detected file type is supported

View File

@@ -17,7 +17,7 @@ async def extract_lora_metadata(file_path: str) -> Dict:
        base_model = determine_base_model(metadata.get("ss_base_model_version"))
        return {"base_model": base_model}
    except Exception as e:
-       print(f"Error reading metadata from {file_path}: {str(e)}")
        logger.error(f"Error reading metadata from {file_path}: {str(e)}")
        return {"base_model": "Unknown"}

async def extract_checkpoint_metadata(file_path: str) -> dict:

View File

@@ -223,7 +223,7 @@ class MetadataManager:
                preview_url=normalize_path(preview_url),
                tags=[],
                modelDescription="",
-               model_type="checkpoint",
                sub_type="checkpoint",
                from_civitai=True
            )
        elif model_class.__name__ == "EmbeddingMetadata":
@@ -238,6 +238,7 @@ class MetadataManager:
                preview_url=normalize_path(preview_url),
                tags=[],
                modelDescription="",
                sub_type="embedding",
                from_civitai=True
            )
        else:  # Default to LoraMetadata

View File

@@ -1,7 +1,7 @@
[project]
name = "comfyui-lora-manager"
description = "Revolutionize your workflow with the ultimate LoRA companion for ComfyUI!"
- version = "0.9.13"
version = "0.9.14"
license = {file = "LICENSE"}
dependencies = [
    "aiohttp",

scripts/sync_translation_keys.py Normal file → Executable file
View File

View File

@@ -113,6 +113,12 @@
    max-width: 110px;
}
/* Compact mode: hide sub-type to save space */
.compact-density .model-sub-type,
.compact-density .model-separator {
display: none;
}
.compact-density .card-actions i {
    font-size: 0.95em;
    padding: 3px;

View File

@@ -26,6 +26,7 @@ class RecipeCard {
        card.dataset.nsfwLevel = this.recipe.preview_nsfw_level || 0;
        card.dataset.created = this.recipe.created_date;
        card.dataset.id = this.recipe.id || '';
        card.dataset.folder = this.recipe.folder || '';

        // Get base model with fallback
        const baseModelLabel = (this.recipe.base_model || '').trim() || 'Unknown';

View File

@@ -199,6 +199,12 @@ class InitializationManager {
        if (!data) return;
        console.log('Received progress update:', data);

        // Handle cache health warning messages
        if (data.type === 'cache_health_warning') {
            this.handleCacheHealthWarning(data);
            return;
        }

        // Check if this update is for our page type
        if (data.pageType && data.pageType !== this.pageType) {
            console.log(`Ignoring update for ${data.pageType}, we're on ${this.pageType}`);
@@ -466,6 +472,29 @@ class InitializationManager {
        }
    }
/**
* Handle cache health warning messages from WebSocket
*/
handleCacheHealthWarning(data) {
console.log('Cache health warning received:', data);
// Import bannerService dynamically to avoid circular dependencies
import('../managers/BannerService.js').then(({ bannerService }) => {
// Initialize banner service if not already done
if (!bannerService.initialized) {
bannerService.initialize().then(() => {
bannerService.registerCacheHealthBanner(data);
}).catch(err => {
console.error('Failed to initialize banner service:', err);
});
} else {
bannerService.registerCacheHealthBanner(data);
}
}).catch(err => {
console.error('Failed to load banner service:', err);
});
}
    /**
     * Clean up resources when the component is destroyed
     */

View File

@@ -4,9 +4,11 @@ import {
    removeStorageItem
} from '../utils/storageHelpers.js';
import { translate } from '../utils/i18nHelpers.js';
- import { state } from '../state/index.js'
import { state } from '../state/index.js';
import { getModelApiClient } from '../api/modelApiFactory.js';

const COMMUNITY_SUPPORT_BANNER_ID = 'community-support';
const CACHE_HEALTH_BANNER_ID = 'cache-health-warning';
const COMMUNITY_SUPPORT_BANNER_DELAY_MS = 5 * 24 * 60 * 60 * 1000; // 5 days
const COMMUNITY_SUPPORT_FIRST_SEEN_AT_KEY = 'community_support_banner_first_seen_at';
const COMMUNITY_SUPPORT_VERSION_KEY = 'community_support_banner_state_version';
@@ -293,6 +295,177 @@ class BannerService {
        location.reload();
    }
/**
* Register a cache health warning banner
* @param {Object} healthData - Health data from WebSocket
*/
registerCacheHealthBanner(healthData) {
if (!healthData || healthData.status === 'healthy') {
return;
}
// Remove existing cache health banner if any
this.removeBannerElement(CACHE_HEALTH_BANNER_ID);
const isCorrupted = healthData.status === 'corrupted';
const titleKey = isCorrupted
? 'banners.cacheHealth.corrupted.title'
: 'banners.cacheHealth.degraded.title';
const defaultTitle = isCorrupted
? 'Cache Corruption Detected'
: 'Cache Issues Detected';
const title = translate(titleKey, {}, defaultTitle);
const contentKey = 'banners.cacheHealth.content';
const defaultContent = 'Found {invalid} of {total} cache entries are invalid ({rate}). This may cause missing models or errors. Rebuilding the cache is recommended.';
const content = translate(contentKey, {
invalid: healthData.details?.invalid || 0,
total: healthData.details?.total || 0,
rate: healthData.details?.corruption_rate || '0%'
}, defaultContent);
this.registerBanner(CACHE_HEALTH_BANNER_ID, {
id: CACHE_HEALTH_BANNER_ID,
title: title,
content: content,
pageType: healthData.pageType,
actions: [
{
text: translate('banners.cacheHealth.rebuildCache', {}, 'Rebuild Cache'),
icon: 'fas fa-sync-alt',
action: 'rebuild-cache',
type: 'primary'
},
{
text: translate('banners.cacheHealth.dismiss', {}, 'Dismiss'),
icon: 'fas fa-times',
action: 'dismiss',
type: 'secondary'
}
],
dismissible: true,
priority: 10, // High priority
onRegister: (bannerElement) => {
// Attach click handlers for actions
const rebuildBtn = bannerElement.querySelector('[data-action="rebuild-cache"]');
const dismissBtn = bannerElement.querySelector('[data-action="dismiss"]');
if (rebuildBtn) {
rebuildBtn.addEventListener('click', (e) => {
e.preventDefault();
this.handleRebuildCache(bannerElement, healthData.pageType);
});
}
if (dismissBtn) {
dismissBtn.addEventListener('click', (e) => {
e.preventDefault();
this.dismissBanner(CACHE_HEALTH_BANNER_ID);
});
}
}
});
}
/**
* Handle rebuild cache action from banner
* @param {HTMLElement} bannerElement - The banner element
* @param {string} pageType - The page type (loras, checkpoints, embeddings)
*/
async handleRebuildCache(bannerElement, pageType) {
const currentPageType = pageType || this.getCurrentPageType();
try {
const apiClient = getModelApiClient(currentPageType);
// Update banner to show rebuilding status
const actionsContainer = bannerElement.querySelector('.banner-actions');
if (actionsContainer) {
actionsContainer.innerHTML = `
<span class="banner-loading">
<i class="fas fa-spinner fa-spin"></i>
<span>${translate('banners.cacheHealth.rebuilding', {}, 'Rebuilding cache...')}</span>
</span>
`;
}
await apiClient.refreshModels(true);
// Remove banner on success without marking as dismissed
this.removeBannerElement(CACHE_HEALTH_BANNER_ID);
} catch (error) {
console.error('Cache rebuild failed:', error);
const actionsContainer = bannerElement.querySelector('.banner-actions');
if (actionsContainer) {
actionsContainer.innerHTML = `
<span class="banner-error">
<i class="fas fa-exclamation-triangle"></i>
<span>${translate('banners.cacheHealth.rebuildFailed', {}, 'Rebuild failed. Please try again.')}</span>
</span>
<a href="#" class="banner-action banner-action-primary" data-action="rebuild-cache">
<i class="fas fa-sync-alt"></i>
<span>${translate('banners.cacheHealth.retry', {}, 'Retry')}</span>
</a>
`;
// Re-attach click handler
const retryBtn = actionsContainer.querySelector('[data-action="rebuild-cache"]');
if (retryBtn) {
retryBtn.addEventListener('click', (e) => {
e.preventDefault();
this.handleRebuildCache(bannerElement, pageType);
});
}
}
}
}
/**
* Get the current page type from the URL
* @returns {string} Page type (loras, checkpoints, embeddings, recipes)
*/
getCurrentPageType() {
const path = window.location.pathname;
if (path.includes('/checkpoints')) return 'checkpoints';
if (path.includes('/embeddings')) return 'embeddings';
if (path.includes('/recipes')) return 'recipes';
return 'loras';
}
/**
* Get the rebuild cache endpoint for the given page type
* @param {string} pageType - The page type
* @returns {string} The API endpoint URL
*/
getRebuildEndpoint(pageType) {
const endpoints = {
'loras': '/api/lm/loras/reload?rebuild=true',
'checkpoints': '/api/lm/checkpoints/reload?rebuild=true',
'embeddings': '/api/lm/embeddings/reload?rebuild=true'
};
return endpoints[pageType] || endpoints['loras'];
}
/**
* Remove a banner element from DOM without marking as dismissed
* @param {string} bannerId - Banner ID to remove
*/
removeBannerElement(bannerId) {
const bannerElement = document.querySelector(`[data-banner-id="${bannerId}"]`);
if (bannerElement) {
bannerElement.style.animation = 'banner-slide-up 0.3s ease-in-out forwards';
setTimeout(() => {
bannerElement.remove();
this.updateContainerVisibility();
}, 300);
}
// Also remove from banners map
this.banners.delete(bannerId);
}
    prepareCommunitySupportBanner() {
        if (this.isBannerDismissed(COMMUNITY_SUPPORT_BANNER_ID)) {
            return;

View File

@@ -21,7 +21,7 @@ export class ExampleImagesManager {
        // Auto download properties
        this.autoDownloadInterval = null;
        this.lastAutoDownloadCheck = 0;
-       this.autoDownloadCheckInterval = 10 * 60 * 1000; // 10 minutes in milliseconds
        this.autoDownloadCheckInterval = 30 * 60 * 1000; // 30 minutes in milliseconds
        this.pageInitTime = Date.now(); // Track when page was initialized

        // Initialize download path field and check download status
@@ -808,19 +808,58 @@ export class ExampleImagesManager {
            return;
        }

-       this.lastAutoDownloadCheck = now;

        if (!this.canAutoDownload()) {
            console.log('Auto download conditions not met, skipping check');
            return;
        }

        try {
-           console.log('Performing auto download check...');
            console.log('Performing auto download pre-check...');
// Step 1: Lightweight pre-check to see if any work is needed
const checkResponse = await fetch('/api/lm/check-example-images-needed', {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
model_types: ['lora', 'checkpoint', 'embedding']
})
});
if (!checkResponse.ok) {
console.warn('Auto download pre-check HTTP error:', checkResponse.status);
return;
}
const checkData = await checkResponse.json();
if (!checkData.success) {
console.warn('Auto download pre-check failed:', checkData.error);
return;
}
// Update the check timestamp only after successful pre-check
this.lastAutoDownloadCheck = now;
// If download already in progress, skip
if (checkData.is_downloading) {
console.log('Download already in progress, skipping auto check');
return;
}
// If no models need downloading, skip
if (!checkData.needs_download || checkData.pending_count === 0) {
console.log(`Auto download pre-check complete: ${checkData.processed_count}/${checkData.total_models} models already processed, no work needed`);
return;
}
console.log(`Auto download pre-check: ${checkData.pending_count} models need processing, starting download...`);
// Step 2: Start the actual download (fire-and-forget)
            const optimize = state.global.settings.optimize_example_images;
-           const response = await fetch('/api/lm/download-example-images', {
            fetch('/api/lm/download-example-images', {
                method: 'POST',
                headers: {
                    'Content-Type': 'application/json'
@@ -830,18 +869,29 @@ export class ExampleImagesManager {
                    model_types: ['lora', 'checkpoint', 'embedding'],
                    auto_mode: true // Flag to indicate this is an automatic download
                })
}).then(response => {
if (!response.ok) {
console.warn('Auto download start HTTP error:', response.status);
return null;
}
return response.json();
}).then(data => {
if (data && !data.success) {
console.warn('Auto download start failed:', data.error);
// If already in progress, push back the next check to avoid hammering the API
if (data.error && data.error.includes('already in progress')) {
console.log('Download already in progress, backing off next check');
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
}
} else if (data && data.success) {
console.log('Auto download started:', data.message || 'Download started');
}
}).catch(error => {
console.error('Auto download start error:', error);
            });

-           const data = await response.json();
-           if (!data.success) {
-               console.warn('Auto download check failed:', data.error);
-               // If already in progress, push back the next check to avoid hammering the API
-               if (data.error && data.error.includes('already in progress')) {
-                   console.log('Download already in progress, backing off next check');
-                   this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
-               }
-           }
            // Immediately return without waiting for the download fetch to complete
            // This keeps the UI responsive
        } catch (error) {
            console.error('Auto download check error:', error);
        }

View File

@@ -27,6 +27,10 @@ export const BASE_MODELS = {
    FLUX_1_KREA: "Flux.1 Krea",
    FLUX_1_KONTEXT: "Flux.1 Kontext",
    FLUX_2_D: "Flux.2 D",
FLUX_2_KLEIN_9B: "Flux.2 Klein 9B",
FLUX_2_KLEIN_9B_BASE: "Flux.2 Klein 9B-base",
FLUX_2_KLEIN_4B: "Flux.2 Klein 4B",
FLUX_2_KLEIN_4B_BASE: "Flux.2 Klein 4B-base",
    AURAFLOW: "AuraFlow",
    CHROMA: "Chroma",
    PIXART_A: "PixArt a",
@@ -40,10 +44,12 @@ export const BASE_MODELS = {
    HIDREAM: "HiDream",
    QWEN: "Qwen",
    ZIMAGE_TURBO: "ZImageTurbo",
ZIMAGE_BASE: "ZImageBase",
    // Video models
    SVD: "SVD",
    LTXV: "LTXV",
LTXV2: "LTXV2",
    WAN_VIDEO: "Wan Video",
    WAN_VIDEO_1_3B_T2V: "Wan Video 1.3B t2v",
    WAN_VIDEO_14B_T2V: "Wan Video 14B t2v",
@@ -120,6 +126,10 @@ export const BASE_MODEL_ABBREVIATIONS = {
    [BASE_MODELS.FLUX_1_KREA]: 'F1KR',
    [BASE_MODELS.FLUX_1_KONTEXT]: 'F1KX',
    [BASE_MODELS.FLUX_2_D]: 'F2D',
[BASE_MODELS.FLUX_2_KLEIN_9B]: 'FK9',
[BASE_MODELS.FLUX_2_KLEIN_9B_BASE]: 'FK9B',
[BASE_MODELS.FLUX_2_KLEIN_4B]: 'FK4',
[BASE_MODELS.FLUX_2_KLEIN_4B_BASE]: 'FK4B',
    // Other diffusion models
    [BASE_MODELS.AURAFLOW]: 'AF',
@@ -135,10 +145,12 @@ export const BASE_MODEL_ABBREVIATIONS = {
    [BASE_MODELS.HIDREAM]: 'HID',
    [BASE_MODELS.QWEN]: 'QWEN',
    [BASE_MODELS.ZIMAGE_TURBO]: 'ZIT',
[BASE_MODELS.ZIMAGE_BASE]: 'ZIB',
    // Video models
    [BASE_MODELS.SVD]: 'SVD',
    [BASE_MODELS.LTXV]: 'LTXV',
[BASE_MODELS.LTXV2]: 'LTV2',
    [BASE_MODELS.WAN_VIDEO]: 'WAN',
    [BASE_MODELS.WAN_VIDEO_1_3B_T2V]: 'WAN',
    [BASE_MODELS.WAN_VIDEO_14B_T2V]: 'WAN',
@@ -328,16 +340,16 @@ export const BASE_MODEL_CATEGORIES = {
    'Stable Diffusion 3.x': [BASE_MODELS.SD_3, BASE_MODELS.SD_3_5, BASE_MODELS.SD_3_5_MEDIUM, BASE_MODELS.SD_3_5_LARGE, BASE_MODELS.SD_3_5_LARGE_TURBO],
    'SDXL': [BASE_MODELS.SDXL, BASE_MODELS.SDXL_LIGHTNING, BASE_MODELS.SDXL_HYPER],
    'Video Models': [
-       BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
        BASE_MODELS.SVD, BASE_MODELS.LTXV, BASE_MODELS.LTXV2, BASE_MODELS.HUNYUAN_VIDEO, BASE_MODELS.WAN_VIDEO,
        BASE_MODELS.WAN_VIDEO_1_3B_T2V, BASE_MODELS.WAN_VIDEO_14B_T2V,
        BASE_MODELS.WAN_VIDEO_14B_I2V_480P, BASE_MODELS.WAN_VIDEO_14B_I2V_720P,
        BASE_MODELS.WAN_VIDEO_2_2_TI2V_5B, BASE_MODELS.WAN_VIDEO_2_2_T2V_A14B,
        BASE_MODELS.WAN_VIDEO_2_2_I2V_A14B
    ],
-   'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D],
    'Flux Models': [BASE_MODELS.FLUX_1_D, BASE_MODELS.FLUX_1_S, BASE_MODELS.FLUX_1_KONTEXT, BASE_MODELS.FLUX_1_KREA, BASE_MODELS.FLUX_2_D, BASE_MODELS.FLUX_2_KLEIN_9B, BASE_MODELS.FLUX_2_KLEIN_9B_BASE, BASE_MODELS.FLUX_2_KLEIN_4B, BASE_MODELS.FLUX_2_KLEIN_4B_BASE],
    'Other Models': [
        BASE_MODELS.ILLUSTRIOUS, BASE_MODELS.PONY, BASE_MODELS.HIDREAM,
-       BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO,
        BASE_MODELS.QWEN, BASE_MODELS.AURAFLOW, BASE_MODELS.CHROMA, BASE_MODELS.ZIMAGE_TURBO, BASE_MODELS.ZIMAGE_BASE,
        BASE_MODELS.PIXART_A, BASE_MODELS.PIXART_E, BASE_MODELS.HUNYUAN_1,
        BASE_MODELS.LUMINA, BASE_MODELS.KOLORS, BASE_MODELS.NOOBAI,
        BASE_MODELS.UNKNOWN

View File

@@ -230,8 +230,58 @@ def test_new_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
    assert normalized_external in second_cfg._path_mappings

-def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
-    """Removing a deep symlink should trigger cache invalidation."""
def test_removed_first_level_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
    """Removing a first-level symlink should trigger cache invalidation."""
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
# Create first-level symlink (directly under loras root)
external_dir = tmp_path / "external"
external_dir.mkdir()
symlink = loras_dir / "external_models"
symlink.symlink_to(external_dir, target_is_directory=True)
# Initial scan finds the symlink
first_cfg = config_module.Config()
normalized_external = _normalize(str(external_dir))
assert normalized_external in first_cfg._path_mappings
# Remove the symlink
symlink.unlink()
# Second config should detect invalid cached mapping and rescan
second_cfg = config_module.Config()
assert normalized_external not in second_cfg._path_mappings
def test_retargeted_first_level_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
"""Changing a first-level symlink's target should trigger cache invalidation."""
loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
# Create first-level symlink
target_v1 = tmp_path / "external_v1"
target_v1.mkdir()
target_v2 = tmp_path / "external_v2"
target_v2.mkdir()
symlink = loras_dir / "external_models"
symlink.symlink_to(target_v1, target_is_directory=True)
# Initial scan
first_cfg = config_module.Config()
assert _normalize(str(target_v1)) in first_cfg._path_mappings
# Retarget the symlink
symlink.unlink()
symlink.symlink_to(target_v2, target_is_directory=True)
# Second config should detect changed target and rescan
second_cfg = config_module.Config()
assert _normalize(str(target_v2)) in second_cfg._path_mappings
assert _normalize(str(target_v1)) not in second_cfg._path_mappings
def test_deep_symlink_not_scanned(monkeypatch: pytest.MonkeyPatch, tmp_path):
"""Deep symlinks (below first level) are not scanned to avoid performance issues."""
    loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)

    # Create nested structure with deep symlink
@@ -242,46 +292,12 @@ def test_removed_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, t
    deep_symlink = subdir / "styles"
    deep_symlink.symlink_to(external_dir, target_is_directory=True)

-   # Initial scan finds the deep symlink
-   first_cfg = config_module.Config()
    # Config should not detect deep symlinks (only first-level)
    cfg = config_module.Config()
    normalized_external = _normalize(str(external_dir))
-   assert normalized_external in first_cfg._path_mappings
    assert normalized_external not in cfg._path_mappings

-   # Remove the deep symlink
-   deep_symlink.unlink()
-
-   # Second config should detect invalid cached mapping and rescan
-   second_cfg = config_module.Config()
-   assert normalized_external not in second_cfg._path_mappings
-
-def test_retargeted_deep_symlink_triggers_rescan(monkeypatch: pytest.MonkeyPatch, tmp_path):
-   """Changing a deep symlink's target should trigger cache invalidation."""
-   loras_dir, settings_dir = _setup_paths(monkeypatch, tmp_path)
-
-   # Create nested structure
-   subdir = loras_dir / "anime"
-   subdir.mkdir()
-   target_v1 = tmp_path / "external_v1"
-   target_v1.mkdir()
-   target_v2 = tmp_path / "external_v2"
-   target_v2.mkdir()
-   deep_symlink = subdir / "styles"
-   deep_symlink.symlink_to(target_v1, target_is_directory=True)
-
-   # Initial scan
-   first_cfg = config_module.Config()
-   assert _normalize(str(target_v1)) in first_cfg._path_mappings
-
-   # Retarget the symlink
-   deep_symlink.unlink()
-   deep_symlink.symlink_to(target_v2, target_is_directory=True)
-
-   # Second config should detect changed target and rescan
-   second_cfg = config_module.Config()
-   assert _normalize(str(target_v2)) in second_cfg._path_mappings
-   assert _normalize(str(target_v1)) not in second_cfg._path_mappings
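The first-level-only scan these tests pin down can be sketched as follows, assuming only direct children of the root are examined; `first_level_symlink_targets` is an illustrative helper, not the actual Config implementation:

```python
import os
import tempfile

def first_level_symlink_targets(root: str) -> set[str]:
    """Resolve symlinked directories sitting directly under root (deeper ones are skipped)."""
    targets = set()
    for name in os.listdir(root):
        path = os.path.join(root, name)
        if os.path.islink(path) and os.path.isdir(path):
            targets.add(os.path.realpath(path))
    return targets

with tempfile.TemporaryDirectory() as tmp:
    root = os.path.join(tmp, "loras")
    os.makedirs(root)
    external = os.path.join(tmp, "external")
    os.makedirs(external)
    deep_parent = os.path.join(root, "anime")
    os.makedirs(deep_parent)
    # One first-level link and one deep link pointing at the same target
    os.symlink(external, os.path.join(root, "external_models"), target_is_directory=True)
    os.symlink(external, os.path.join(deep_parent, "styles"), target_is_directory=True)
    found = first_level_symlink_targets(root)
    has_first_level = os.path.realpath(external) in found
    found_count = len(found)
```

Restricting the scan to first-level entries is the performance trade-off the renamed tests encode: a full recursive walk would re-stat every subdirectory of every library on startup.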
def test_legacy_symlink_cache_automatic_cleanup(monkeypatch: pytest.MonkeyPatch, tmp_path):
    """Test that legacy symlink cache is automatically cleaned up after migration."""
    settings_dir = tmp_path / "settings"

View File

@@ -47,6 +47,8 @@ class StubDownloadManager:
        self.resume_error: Exception | None = None
        self.stop_error: Exception | None = None
        self.force_error: Exception | None = None
self.check_pending_result: dict[str, Any] | None = None
self.check_pending_calls: list[list[str]] = []
    async def get_status(self, request: web.Request) -> dict[str, Any]:
        return {"success": True, "status": "idle"}
@@ -75,6 +77,20 @@ class StubDownloadManager:
            raise self.force_error
        return {"success": True, "payload": payload}
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
self.check_pending_calls.append(model_types)
if self.check_pending_result is not None:
return self.check_pending_result
return {
"success": True,
"is_downloading": False,
"total_models": 100,
"pending_count": 10,
"processed_count": 90,
"failed_count": 0,
"needs_download": True,
}
class StubImportUseCase:
    def __init__(self) -> None:
@@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors():
assert response.status == 400
body = await _json(response)
assert body == {"success": False, "error": "bad payload"}
async def test_check_example_images_needed_returns_pending_counts():
"""Test that check_example_images_needed endpoint returns pending model counts."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": False,
"total_models": 5500,
"pending_count": 12,
"processed_count": 5488,
"failed_count": 45,
"needs_download": True,
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora", "checkpoint"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["total_models"] == 5500
assert body["pending_count"] == 12
assert body["processed_count"] == 5488
assert body["failed_count"] == 45
assert body["needs_download"] is True
assert body["is_downloading"] is False
# Verify the manager was called with correct model types
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
async def test_check_example_images_needed_handles_download_in_progress():
"""Test that check_example_images_needed returns correct status when download is running."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": True,
"total_models": 0,
"pending_count": 0,
"processed_count": 0,
"failed_count": 0,
"needs_download": False,
"message": "Download already in progress",
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["is_downloading"] is True
assert body["needs_download"] is False
async def test_check_example_images_needed_handles_no_pending_models():
"""Test that check_example_images_needed returns correct status when no work is needed."""
async with registrar_app() as harness:
harness.download_manager.check_pending_result = {
"success": True,
"is_downloading": False,
"total_models": 5500,
"pending_count": 0,
"processed_count": 5500,
"failed_count": 0,
"needs_download": False,
}
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora", "checkpoint", "embedding"]},
)
assert response.status == 200
body = await _json(response)
assert body["success"] is True
assert body["pending_count"] == 0
assert body["needs_download"] is False
assert body["processed_count"] == 5500
async def test_check_example_images_needed_uses_default_model_types():
"""Test that check_example_images_needed uses default model types when not specified."""
async with registrar_app() as harness:
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={}, # No model_types specified
)
assert response.status == 200
# Should use default model types
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
async def test_check_example_images_needed_returns_error_on_exception():
"""Test that check_example_images_needed returns 500 on internal error."""
async with registrar_app() as harness:
# Replace the stub's check method so the handler raises an exception
async def failing_check(_model_types):
raise RuntimeError("Database connection failed")
harness.download_manager.check_pending_models = failing_check
response = await harness.client.post(
"/api/lm/check-example-images-needed",
json={"model_types": ["lora"]},
)
assert response.status == 500
body = await _json(response)
assert body["success"] is False
assert "Database connection failed" in body["error"]


@@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
"resume_example_images",
"stop_example_images",
"force_download_example_images",
"check_example_images_needed",
"import_example_images",
"delete_example_image",
"set_example_image_nsfw_level",


@@ -0,0 +1,283 @@
"""
Unit tests for CacheEntryValidator
"""
import pytest
from py.services.cache_entry_validator import (
CacheEntryValidator,
ValidationResult,
)
class TestCacheEntryValidator:
"""Tests for CacheEntryValidator class"""
def test_validate_valid_entry(self):
"""Test validation of a valid cache entry"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'abc123def456',
'file_name': 'test.safetensors',
'model_name': 'Test Model',
'size': 1024,
'modified': 1234567890.0,
'tags': ['tag1', 'tag2'],
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is True
assert result.repaired is False
assert len(result.errors) == 0
assert result.entry == entry
def test_validate_missing_required_field_sha256(self):
"""Test validation fails when required sha256 field is missing"""
entry = {
'file_path': '/models/test.safetensors',
# sha256 missing
'file_name': 'test.safetensors',
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('sha256' in error for error in result.errors)
def test_validate_missing_required_field_file_path(self):
"""Test validation fails when required file_path field is missing"""
entry = {
# file_path missing
'sha256': 'abc123def456',
'file_name': 'test.safetensors',
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('file_path' in error for error in result.errors)
def test_validate_empty_required_field_sha256(self):
"""Test validation fails when sha256 is empty string"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': '', # Empty string
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('sha256' in error for error in result.errors)
def test_validate_empty_required_field_file_path(self):
"""Test validation fails when file_path is empty string"""
entry = {
'file_path': '', # Empty string
'sha256': 'abc123def456',
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('file_path' in error for error in result.errors)
def test_validate_none_required_field(self):
"""Test validation fails when required field is None"""
entry = {
'file_path': None,
'sha256': 'abc123def456',
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('file_path' in error for error in result.errors)
def test_validate_none_entry(self):
"""Test validation handles None entry"""
result = CacheEntryValidator.validate(None, auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('None' in error for error in result.errors)
assert result.entry is None
def test_validate_non_dict_entry(self):
"""Test validation handles non-dict entry"""
result = CacheEntryValidator.validate("not a dict", auto_repair=False)
assert result.is_valid is False
assert result.repaired is False
assert any('not a dict' in error for error in result.errors)
assert result.entry is None
def test_auto_repair_missing_non_required_field(self):
"""Test auto-repair adds missing non-required fields"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'abc123def456',
# file_name, model_name, tags missing
}
result = CacheEntryValidator.validate(entry, auto_repair=True)
assert result.is_valid is True
assert result.repaired is True
assert result.entry['file_name'] == ''
assert result.entry['model_name'] == ''
assert result.entry['tags'] == []
def test_auto_repair_wrong_type_field(self):
"""Test auto-repair fixes fields with wrong type"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'abc123def456',
'size': 'not a number', # Should be int
'tags': 'not a list', # Should be list
}
result = CacheEntryValidator.validate(entry, auto_repair=True)
assert result.is_valid is True
assert result.repaired is True
assert result.entry['size'] == 0 # Default value
assert result.entry['tags'] == [] # Default value
def test_normalize_sha256_lowercase(self):
"""Test sha256 is normalized to lowercase"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'ABC123DEF456', # Uppercase
}
result = CacheEntryValidator.validate(entry, auto_repair=True)
assert result.is_valid is True
assert result.entry['sha256'] == 'abc123def456'
def test_validate_batch_all_valid(self):
"""Test batch validation with all valid entries"""
entries = [
{
'file_path': '/models/test1.safetensors',
'sha256': 'abc123',
},
{
'file_path': '/models/test2.safetensors',
'sha256': 'def456',
},
]
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=False)
assert len(valid) == 2
assert len(invalid) == 0
def test_validate_batch_mixed_validity(self):
"""Test batch validation with mixed valid/invalid entries"""
entries = [
{
'file_path': '/models/test1.safetensors',
'sha256': 'abc123',
},
{
'file_path': '/models/test2.safetensors',
# sha256 missing - invalid
},
{
'file_path': '/models/test3.safetensors',
'sha256': 'def456',
},
]
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=False)
assert len(valid) == 2
assert len(invalid) == 1
# invalid list contains the actual invalid entries (not by index)
assert invalid[0]['file_path'] == '/models/test2.safetensors'
def test_validate_batch_empty_list(self):
"""Test batch validation with empty list"""
valid, invalid = CacheEntryValidator.validate_batch([], auto_repair=False)
assert len(valid) == 0
assert len(invalid) == 0
def test_get_file_path_safe(self):
"""Test safe file_path extraction"""
entry = {'file_path': '/models/test.safetensors', 'sha256': 'abc123'}
assert CacheEntryValidator.get_file_path_safe(entry) == '/models/test.safetensors'
def test_get_file_path_safe_missing(self):
"""Test safe file_path extraction when missing"""
entry = {'sha256': 'abc123'}
assert CacheEntryValidator.get_file_path_safe(entry) == ''
def test_get_file_path_safe_not_dict(self):
"""Test safe file_path extraction from non-dict"""
assert CacheEntryValidator.get_file_path_safe(None) == ''
assert CacheEntryValidator.get_file_path_safe('string') == ''
def test_get_sha256_safe(self):
"""Test safe sha256 extraction"""
entry = {'file_path': '/models/test.safetensors', 'sha256': 'ABC123'}
assert CacheEntryValidator.get_sha256_safe(entry) == 'abc123'
def test_get_sha256_safe_missing(self):
"""Test safe sha256 extraction when missing"""
entry = {'file_path': '/models/test.safetensors'}
assert CacheEntryValidator.get_sha256_safe(entry) == ''
def test_get_sha256_safe_not_dict(self):
"""Test safe sha256 extraction from non-dict"""
assert CacheEntryValidator.get_sha256_safe(None) == ''
assert CacheEntryValidator.get_sha256_safe('string') == ''
def test_validate_with_all_optional_fields(self):
"""Test validation with all optional fields present"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'abc123',
'file_name': 'test.safetensors',
'model_name': 'Test Model',
'folder': 'test_folder',
'size': 1024,
'modified': 1234567890.0,
'tags': ['tag1', 'tag2'],
'preview_url': 'http://example.com/preview.jpg',
'base_model': 'SD1.5',
'from_civitai': True,
'favorite': True,
'exclude': False,
'db_checked': True,
'preview_nsfw_level': 1,
'notes': 'Test notes',
'usage_tips': 'Test tips',
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is True
assert result.repaired is False
assert result.entry == entry
def test_validate_numeric_field_accepts_float_for_int(self):
"""Test that numeric fields accept float for int type"""
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'abc123',
'size': 1024.5, # Float for int field
'modified': 1234567890.0,
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is True
assert result.repaired is False
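Taken together, these assertions pin down the validator's contract: required fields, repair defaults, hash normalization, batch splitting, and safe accessors. The following minimal sketch satisfies that contract; the field names, defaults, and accepted types are assumptions inferred from the tests, not the plugin's actual schema.

```python
from dataclasses import dataclass, field
from typing import Any, Optional

# Assumed schema, reconstructed from the test assertions above.
_REQUIRED = ('file_path', 'sha256')
_OPTIONAL = {
    'file_name': ('', str), 'model_name': ('', str), 'folder': ('', str),
    'size': (0, (int, float)), 'modified': (0.0, (int, float)),
    'tags': ([], list), 'preview_url': ('', str), 'base_model': ('', str),
    'from_civitai': (True, bool), 'favorite': (False, bool),
    'exclude': (False, bool), 'db_checked': (False, bool),
    'preview_nsfw_level': (0, (int, float)), 'notes': ('', str),
    'usage_tips': ('', str),
}

@dataclass
class ValidationResult:
    is_valid: bool
    repaired: bool = False
    errors: list = field(default_factory=list)
    entry: Optional[dict] = None

class CacheEntryValidator:
    @classmethod
    def validate(cls, entry: Any, auto_repair: bool = False) -> ValidationResult:
        if entry is None:
            return ValidationResult(False, errors=['entry is None'])
        if not isinstance(entry, dict):
            return ValidationResult(False, errors=['entry is not a dict'])
        # Required fields must be non-empty strings.
        errors = [f'missing or empty required field: {name}'
                  for name in _REQUIRED
                  if not isinstance(entry.get(name), str) or not entry[name]]
        if errors:
            return ValidationResult(False, errors=errors, entry=entry)
        repaired = False
        if auto_repair:
            for name, (default, types) in _OPTIONAL.items():
                if name not in entry or not isinstance(entry[name], types):
                    # copy list defaults so entries never share one object
                    entry[name] = list(default) if isinstance(default, list) else default
                    repaired = True
            entry['sha256'] = entry['sha256'].lower()
        return ValidationResult(True, repaired=repaired, entry=entry)

    @classmethod
    def validate_batch(cls, entries, auto_repair: bool = False):
        valid, invalid = [], []
        for entry in entries:
            result = cls.validate(entry, auto_repair=auto_repair)
            (valid if result.is_valid else invalid).append(
                result.entry if result.entry is not None else entry)
        return valid, invalid

    @staticmethod
    def get_file_path_safe(entry: Any) -> str:
        return (entry.get('file_path') or '') if isinstance(entry, dict) else ''

    @staticmethod
    def get_sha256_safe(entry: Any) -> str:
        return (entry.get('sha256') or '').lower() if isinstance(entry, dict) else ''
```

Note the in-place mutation during repair: the tests compare `result.entry` against the input dict, so a repairing validator must hand back the same (patched) object rather than a copy.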


@@ -0,0 +1,364 @@
"""
Unit tests for CacheHealthMonitor
"""
import pytest
from py.services.cache_health_monitor import (
CacheHealthMonitor,
CacheHealthStatus,
HealthReport,
)
class TestCacheHealthMonitor:
"""Tests for CacheHealthMonitor class"""
def test_check_health_all_valid_entries(self):
"""Test health check with 100% valid entries"""
monitor = CacheHealthMonitor()
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(100)
]
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.HEALTHY
assert report.total_entries == 100
assert report.valid_entries == 100
assert report.invalid_entries == 0
assert report.repaired_entries == 0
assert report.corruption_rate == 0.0
assert report.message == "Cache is healthy"
def test_check_health_degraded_cache(self):
"""Test health check with 1-5% invalid entries (degraded)"""
monitor = CacheHealthMonitor()
# Create 100 entries, 2 invalid (2%)
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(98)
]
# Add 2 invalid entries
entries.append({'file_path': '/models/invalid1.safetensors'}) # Missing sha256
entries.append({'file_path': '/models/invalid2.safetensors'}) # Missing sha256
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.DEGRADED
assert report.total_entries == 100
assert report.valid_entries == 98
assert report.invalid_entries == 2
assert report.corruption_rate == 0.02
# Message describes the issue without necessarily containing the word "degraded"
assert 'invalid entries' in report.message.lower()
def test_check_health_corrupted_cache(self):
"""Test health check with >5% invalid entries (corrupted)"""
monitor = CacheHealthMonitor()
# Create 100 entries, 10 invalid (10%)
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(90)
]
# Add 10 invalid entries
for i in range(10):
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.CORRUPTED
assert report.total_entries == 100
assert report.valid_entries == 90
assert report.invalid_entries == 10
assert report.corruption_rate == 0.10
assert 'corrupted' in report.message.lower()
def test_check_health_empty_cache(self):
"""Test health check with empty cache"""
monitor = CacheHealthMonitor()
report = monitor.check_health([], auto_repair=False)
assert report.status == CacheHealthStatus.HEALTHY
assert report.total_entries == 0
assert report.valid_entries == 0
assert report.invalid_entries == 0
assert report.corruption_rate == 0.0
assert report.message == "Cache is empty"
def test_check_health_single_invalid_entry(self):
"""Test health check with 1 invalid entry out of 1 (100% corruption)"""
monitor = CacheHealthMonitor()
entries = [{'file_path': '/models/invalid.safetensors'}]
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.CORRUPTED
assert report.total_entries == 1
assert report.valid_entries == 0
assert report.invalid_entries == 1
assert report.corruption_rate == 1.0
def test_check_health_boundary_degraded_threshold(self):
"""Test health check at degraded threshold (1%)"""
monitor = CacheHealthMonitor(degraded_threshold=0.01)
# 100 entries, 1 invalid (exactly 1%)
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(99)
]
entries.append({'file_path': '/models/invalid.safetensors'})
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.DEGRADED
assert report.corruption_rate == 0.01
def test_check_health_boundary_corrupted_threshold(self):
"""Test health check at corrupted threshold (5%)"""
monitor = CacheHealthMonitor(corrupted_threshold=0.05)
# 100 entries, 5 invalid (exactly 5%)
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(95)
]
for i in range(5):
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.CORRUPTED
assert report.corruption_rate == 0.05
def test_check_health_below_degraded_threshold(self):
"""Test health check below degraded threshold (0%)"""
monitor = CacheHealthMonitor(degraded_threshold=0.01)
# All entries valid
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(100)
]
report = monitor.check_health(entries, auto_repair=False)
assert report.status == CacheHealthStatus.HEALTHY
assert report.corruption_rate == 0.0
def test_check_health_auto_repair(self):
"""Test health check with auto_repair enabled"""
monitor = CacheHealthMonitor()
# 1 entry with all fields (won't be repaired), 1 entry with missing non-required fields (will be repaired)
complete_entry = {
'file_path': '/models/test1.safetensors',
'sha256': 'hash1',
'file_name': 'test1.safetensors',
'model_name': 'Model 1',
'folder': '',
'size': 0,
'modified': 0.0,
'tags': ['tag1'],
'preview_url': '',
'base_model': '',
'from_civitai': True,
'favorite': False,
'exclude': False,
'db_checked': False,
'preview_nsfw_level': 0,
'notes': '',
'usage_tips': '',
}
incomplete_entry = {
'file_path': '/models/test2.safetensors',
'sha256': 'hash2',
# Missing many optional fields (will be repaired)
}
entries = [complete_entry, incomplete_entry]
report = monitor.check_health(entries, auto_repair=True)
assert report.status == CacheHealthStatus.HEALTHY
assert report.total_entries == 2
assert report.valid_entries == 2
assert report.invalid_entries == 0
assert report.repaired_entries == 1
def test_should_notify_user_healthy(self):
"""Test should_notify_user for healthy cache"""
monitor = CacheHealthMonitor()
report = HealthReport(
status=CacheHealthStatus.HEALTHY,
total_entries=100,
valid_entries=100,
invalid_entries=0,
repaired_entries=0,
message="Cache is healthy"
)
assert monitor.should_notify_user(report) is False
def test_should_notify_user_degraded(self):
"""Test should_notify_user for degraded cache"""
monitor = CacheHealthMonitor()
report = HealthReport(
status=CacheHealthStatus.DEGRADED,
total_entries=100,
valid_entries=98,
invalid_entries=2,
repaired_entries=0,
message="Cache is degraded"
)
assert monitor.should_notify_user(report) is True
def test_should_notify_user_corrupted(self):
"""Test should_notify_user for corrupted cache"""
monitor = CacheHealthMonitor()
report = HealthReport(
status=CacheHealthStatus.CORRUPTED,
total_entries=100,
valid_entries=90,
invalid_entries=10,
repaired_entries=0,
message="Cache is corrupted"
)
assert monitor.should_notify_user(report) is True
def test_get_notification_severity_degraded(self):
"""Test get_notification_severity for degraded cache"""
monitor = CacheHealthMonitor()
report = HealthReport(
status=CacheHealthStatus.DEGRADED,
total_entries=100,
valid_entries=98,
invalid_entries=2,
repaired_entries=0,
message="Cache is degraded"
)
assert monitor.get_notification_severity(report) == 'warning'
def test_get_notification_severity_corrupted(self):
"""Test get_notification_severity for corrupted cache"""
monitor = CacheHealthMonitor()
report = HealthReport(
status=CacheHealthStatus.CORRUPTED,
total_entries=100,
valid_entries=90,
invalid_entries=10,
repaired_entries=0,
message="Cache is corrupted"
)
assert monitor.get_notification_severity(report) == 'error'
def test_report_to_dict(self):
"""Test HealthReport to_dict conversion"""
report = HealthReport(
status=CacheHealthStatus.DEGRADED,
total_entries=100,
valid_entries=98,
invalid_entries=2,
repaired_entries=1,
invalid_paths=['/path1', '/path2'],
message="Cache issues detected"
)
result = report.to_dict()
assert result['status'] == 'degraded'
assert result['total_entries'] == 100
assert result['valid_entries'] == 98
assert result['invalid_entries'] == 2
assert result['repaired_entries'] == 1
assert result['corruption_rate'] == '2.0%'
assert len(result['invalid_paths']) == 2
assert result['message'] == "Cache issues detected"
def test_report_corruption_rate_zero_division(self):
"""Test corruption_rate calculation with zero entries"""
report = HealthReport(
status=CacheHealthStatus.HEALTHY,
total_entries=0,
valid_entries=0,
invalid_entries=0,
repaired_entries=0,
message="Cache is empty"
)
assert report.corruption_rate == 0.0
def test_check_health_collects_invalid_paths(self):
"""Test health check collects invalid entry paths"""
monitor = CacheHealthMonitor()
entries = [
{
'file_path': '/models/valid.safetensors',
'sha256': 'hash1',
},
{
'file_path': '/models/invalid1.safetensors',
},
{
'file_path': '/models/invalid2.safetensors',
},
]
report = monitor.check_health(entries, auto_repair=False)
assert len(report.invalid_paths) == 2
assert '/models/invalid1.safetensors' in report.invalid_paths
assert '/models/invalid2.safetensors' in report.invalid_paths
def test_report_to_dict_limits_invalid_paths(self):
"""Test that to_dict limits invalid_paths to first 10"""
report = HealthReport(
status=CacheHealthStatus.CORRUPTED,
total_entries=15,
valid_entries=0,
invalid_entries=15,
repaired_entries=0,
invalid_paths=[f'/path{i}' for i in range(15)],
message="Cache corrupted"
)
result = report.to_dict()
assert len(result['invalid_paths']) == 10
assert result['invalid_paths'][0] == '/path0'
assert result['invalid_paths'][-1] == '/path9'
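The monitor's behaviour these tests assert (threshold comparison, report shape, notification policy) can be sketched as below. Entry validation is inlined as a stand-in for `CacheEntryValidator`, and the `auto_repair` path is elided, so `repaired_entries` is always 0 in this sketch; the thresholds and messages are taken from the assertions above.

```python
import enum
from dataclasses import dataclass, field

class CacheHealthStatus(enum.Enum):
    HEALTHY = 'healthy'
    DEGRADED = 'degraded'
    CORRUPTED = 'corrupted'

@dataclass
class HealthReport:
    status: CacheHealthStatus
    total_entries: int
    valid_entries: int
    invalid_entries: int
    repaired_entries: int
    invalid_paths: list = field(default_factory=list)
    message: str = ''

    @property
    def corruption_rate(self) -> float:
        # guard against zero-division on an empty cache
        return self.invalid_entries / self.total_entries if self.total_entries else 0.0

    def to_dict(self) -> dict:
        return {
            'status': self.status.value,
            'total_entries': self.total_entries,
            'valid_entries': self.valid_entries,
            'invalid_entries': self.invalid_entries,
            'repaired_entries': self.repaired_entries,
            'corruption_rate': f'{self.corruption_rate * 100:.1f}%',
            'invalid_paths': self.invalid_paths[:10],  # cap to keep reports readable
            'message': self.message,
        }

class CacheHealthMonitor:
    def __init__(self, degraded_threshold: float = 0.01,
                 corrupted_threshold: float = 0.05) -> None:
        self.degraded_threshold = degraded_threshold
        self.corrupted_threshold = corrupted_threshold

    @staticmethod
    def _is_valid(entry) -> bool:
        # stand-in for CacheEntryValidator: required fields must be non-empty
        return (isinstance(entry, dict)
                and bool(entry.get('file_path')) and bool(entry.get('sha256')))

    def check_health(self, entries, auto_repair: bool = False) -> HealthReport:
        # auto_repair is delegated to the entry validator and elided here
        if not entries:
            return HealthReport(CacheHealthStatus.HEALTHY, 0, 0, 0, 0,
                                message='Cache is empty')
        valid = invalid = 0
        invalid_paths = []
        for entry in entries:
            if self._is_valid(entry):
                valid += 1
            else:
                invalid += 1
                if isinstance(entry, dict):
                    invalid_paths.append(entry.get('file_path', ''))
        rate = invalid / len(entries)
        if rate >= self.corrupted_threshold:
            status, msg = CacheHealthStatus.CORRUPTED, f'Cache is corrupted ({invalid} invalid entries)'
        elif rate >= self.degraded_threshold:
            status, msg = CacheHealthStatus.DEGRADED, f'{invalid} invalid entries detected'
        else:
            status, msg = CacheHealthStatus.HEALTHY, 'Cache is healthy'
        return HealthReport(status, len(entries), valid, invalid, 0,
                            invalid_paths, msg)

    def should_notify_user(self, report: HealthReport) -> bool:
        return report.status is not CacheHealthStatus.HEALTHY

    def get_notification_severity(self, report: HealthReport) -> str:
        return {CacheHealthStatus.DEGRADED: 'warning',
                CacheHealthStatus.CORRUPTED: 'error'}.get(report.status, 'info')
```

Both boundary tests use `>=` semantics: exactly 1% invalid is already DEGRADED and exactly 5% is already CORRUPTED, which is why the comparisons above are inclusive.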


@@ -0,0 +1,368 @@
"""Tests for the check_pending_models lightweight pre-check functionality."""
from __future__ import annotations
import json
from types import SimpleNamespace
import pytest
from py.services.settings_manager import get_settings_manager
from py.utils import example_images_download_manager as download_module
class StubScanner:
"""Scanner double returning predetermined cache contents."""
def __init__(self, models: list[dict]) -> None:
self._cache = SimpleNamespace(raw_data=models)
async def get_cached_data(self):
return self._cache
def _patch_scanners(
monkeypatch: pytest.MonkeyPatch,
lora_scanner: StubScanner | None = None,
checkpoint_scanner: StubScanner | None = None,
embedding_scanner: StubScanner | None = None,
) -> None:
"""Patch ServiceRegistry to return stub scanners."""
async def _get_lora_scanner(cls):
return lora_scanner or StubScanner([])
async def _get_checkpoint_scanner(cls):
return checkpoint_scanner or StubScanner([])
async def _get_embedding_scanner(cls):
return embedding_scanner or StubScanner([])
monkeypatch.setattr(
download_module.ServiceRegistry,
"get_lora_scanner",
classmethod(_get_lora_scanner),
)
monkeypatch.setattr(
download_module.ServiceRegistry,
"get_checkpoint_scanner",
classmethod(_get_checkpoint_scanner),
)
monkeypatch.setattr(
download_module.ServiceRegistry,
"get_embedding_scanner",
classmethod(_get_embedding_scanner),
)
class RecordingWebSocketManager:
"""Collects broadcast payloads for assertions."""
def __init__(self) -> None:
self.payloads: list[dict] = []
async def broadcast(self, payload: dict) -> None:
self.payloads.append(payload)
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_returns_zero_when_all_processed(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models returns 0 pending when all models are processed."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Create processed models
processed_hashes = ["a" * 64, "b" * 64, "c" * 64]
models = [
{"sha256": h, "model_name": f"Model {i}"}
for i, h in enumerate(processed_hashes)
]
# Create progress file with all models processed
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text(
json.dumps({"processed_models": processed_hashes, "failed_models": []}),
encoding="utf-8",
)
# Create model directories with files (simulating completed downloads)
for h in processed_hashes:
model_dir = tmp_path / h
model_dir.mkdir()
(model_dir / "image_0.png").write_text("data")
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["is_downloading"] is False
assert result["total_models"] == 3
assert result["pending_count"] == 0
assert result["processed_count"] == 3
assert result["needs_download"] is False
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_finds_unprocessed_models(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models correctly identifies unprocessed models."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Create models - some processed, some not
processed_hash = "a" * 64
unprocessed_hash = "b" * 64
models = [
{"sha256": processed_hash, "model_name": "Processed Model"},
{"sha256": unprocessed_hash, "model_name": "Unprocessed Model"},
]
# Create progress file with only one model processed
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text(
json.dumps({"processed_models": [processed_hash], "failed_models": []}),
encoding="utf-8",
)
# Create directory only for processed model
processed_dir = tmp_path / processed_hash
processed_dir.mkdir()
(processed_dir / "image_0.png").write_text("data")
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 2
assert result["pending_count"] == 1
assert result["processed_count"] == 1
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_skips_models_without_hash(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that models without sha256 are not counted as pending."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Models - one with hash, one without
models = [
{"sha256": "a" * 64, "model_name": "Hashed Model"},
{"sha256": None, "model_name": "No Hash Model"},
{"model_name": "Missing Hash Model"}, # No sha256 key at all
]
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 3
assert result["pending_count"] == 1 # Only the one with hash
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_multiple_model_types(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models aggregates counts across multiple model types."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
lora_models = [
{"sha256": "a" * 64, "model_name": "Lora 1"},
{"sha256": "b" * 64, "model_name": "Lora 2"},
]
checkpoint_models = [
{"sha256": "c" * 64, "model_name": "Checkpoint 1"},
]
embedding_models = [
{"sha256": "d" * 64, "model_name": "Embedding 1"},
{"sha256": "e" * 64, "model_name": "Embedding 2"},
{"sha256": "f" * 64, "model_name": "Embedding 3"},
]
_patch_scanners(
monkeypatch,
lora_scanner=StubScanner(lora_models),
checkpoint_scanner=StubScanner(checkpoint_models),
embedding_scanner=StubScanner(embedding_models),
)
result = await manager.check_pending_models(["lora", "checkpoint", "embedding"])
assert result["success"] is True
assert result["total_models"] == 6 # 2 + 1 + 3
assert result["pending_count"] == 6 # All unprocessed
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_returns_error_when_download_in_progress(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models returns special response when download is running."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
# Simulate download in progress
manager._is_downloading = True
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["is_downloading"] is True
assert result["needs_download"] is False
assert result["pending_count"] == 0
assert "already in progress" in result["message"].lower()
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_empty_library(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models handles empty model library."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
_patch_scanners(monkeypatch, lora_scanner=StubScanner([]))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 0
assert result["pending_count"] == 0
assert result["processed_count"] == 0
assert result["needs_download"] is False
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_reads_failed_models(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models correctly reports failed model count."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [{"sha256": "a" * 64, "model_name": "Model"}]
# Create progress file with failed models
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text(
json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}),
encoding="utf-8",
)
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["failed_count"] == 2
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_missing_progress_file(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models works correctly when no progress file exists."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [
{"sha256": "a" * 64, "model_name": "Model 1"},
{"sha256": "b" * 64, "model_name": "Model 2"},
]
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
# No progress file created
result = await manager.check_pending_models(["lora"])
assert result["success"] is True
assert result["total_models"] == 2
assert result["pending_count"] == 2 # All pending since no progress
assert result["processed_count"] == 0
assert result["failed_count"] == 0
assert result["needs_download"] is True
@pytest.mark.asyncio
@pytest.mark.usefixtures("tmp_path")
async def test_check_pending_models_handles_corrupted_progress_file(
monkeypatch: pytest.MonkeyPatch,
tmp_path,
settings_manager,
):
"""Test that check_pending_models handles corrupted progress file gracefully."""
ws_manager = RecordingWebSocketManager()
manager = download_module.DownloadManager(ws_manager=ws_manager)
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
models = [{"sha256": "a" * 64, "model_name": "Model"}]
# Create corrupted progress file
progress_file = tmp_path / ".download_progress.json"
progress_file.write_text("not valid json", encoding="utf-8")
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
result = await manager.check_pending_models(["lora"])
# Should still succeed, treating all as unprocessed
assert result["success"] is True
assert result["total_models"] == 1
assert result["pending_count"] == 1
@pytest.fixture
def settings_manager():
return get_settings_manager()
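The three tests above pin down how the progress file is handled: present with failures, absent, and corrupted all degrade gracefully. A minimal sketch of the loader behavior they imply — the file name and key names come from the tests, the helper name and everything else is assumed:

```python
import json
from pathlib import Path


def load_download_progress(example_images_path: str) -> dict:
    """Load .download_progress.json, treating a missing or corrupted file
    as 'no progress yet' so every model counts as pending.

    Hypothetical helper illustrating the behavior the tests above assert;
    the real DownloadManager may structure this differently.
    """
    progress_file = Path(example_images_path) / ".download_progress.json"
    try:
        data = json.loads(progress_file.read_text(encoding="utf-8"))
    except (FileNotFoundError, json.JSONDecodeError):
        # Missing or corrupted file: start from a clean slate
        return {"processed_models": [], "failed_models": []}
    return {
        "processed_models": list(data.get("processed_models", [])),
        "failed_models": list(data.get("failed_models", [])),
    }
```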

View File

@@ -0,0 +1,167 @@
"""
Integration tests for cache validation in ModelScanner
"""
import pytest
import asyncio
from py.services.model_scanner import ModelScanner
from py.services.cache_entry_validator import CacheEntryValidator
from py.services.cache_health_monitor import CacheHealthMonitor, CacheHealthStatus
@pytest.mark.asyncio
async def test_model_scanner_validates_cache_entries(tmp_path_factory):
"""Test that ModelScanner validates cache entries during initialization"""
# Create temporary test data
tmp_dir = tmp_path_factory.mktemp("test_loras")
# Create test files
test_file = tmp_dir / "test_model.safetensors"
test_file.write_bytes(b"fake model data" * 100)
# Mock model scanner (we can't easily instantiate a full scanner in tests)
# Instead, test the validation logic directly
entries = [
{
'file_path': str(test_file),
'sha256': 'abc123def456',
'file_name': 'test_model.safetensors',
},
{
'file_path': str(tmp_dir / 'invalid.safetensors'),
# Missing sha256 - invalid
},
]
valid, invalid = CacheEntryValidator.validate_batch(entries, auto_repair=True)
assert len(valid) == 1
assert len(invalid) == 1
assert valid[0]['sha256'] == 'abc123def456'
@pytest.mark.asyncio
async def test_model_scanner_detects_degraded_cache():
"""Test that ModelScanner detects degraded cache health"""
# Create 100 entries with 2% corruption
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(98)
]
# Add 2 invalid entries
entries.append({'file_path': '/models/invalid1.safetensors'})
entries.append({'file_path': '/models/invalid2.safetensors'})
monitor = CacheHealthMonitor()
report = monitor.check_health(entries, auto_repair=True)
assert report.status == CacheHealthStatus.DEGRADED
assert report.invalid_entries == 2
assert report.valid_entries == 98
@pytest.mark.asyncio
async def test_model_scanner_detects_corrupted_cache():
"""Test that ModelScanner detects corrupted cache health"""
# Create 100 entries with 10% corruption
entries = [
{
'file_path': f'/models/test{i}.safetensors',
'sha256': f'hash{i}',
}
for i in range(90)
]
# Add 10 invalid entries
for i in range(10):
entries.append({'file_path': f'/models/invalid{i}.safetensors'})
monitor = CacheHealthMonitor()
report = monitor.check_health(entries, auto_repair=True)
assert report.status == CacheHealthStatus.CORRUPTED
assert report.invalid_entries == 10
assert report.valid_entries == 90
@pytest.mark.asyncio
async def test_model_scanner_removes_invalid_from_hash_index():
"""Test that ModelScanner removes invalid entries from hash index"""
from py.services.model_hash_index import ModelHashIndex
# Create a hash index with some entries
hash_index = ModelHashIndex()
valid_entry = {
'file_path': '/models/valid.safetensors',
'sha256': 'abc123',
}
invalid_entry = {
'file_path': '/models/invalid.safetensors',
'sha256': '', # Empty sha256
}
# Add entries to hash index
hash_index.add_entry(valid_entry['sha256'], valid_entry['file_path'])
hash_index.add_entry(invalid_entry['sha256'], invalid_entry['file_path'])
# Verify both entries are in the index (using get_hash method)
assert hash_index.get_hash(valid_entry['file_path']) == valid_entry['sha256']
# Invalid entry won't be added due to empty sha256
assert hash_index.get_hash(invalid_entry['file_path']) is None
# Simulate removing invalid entry (it's not actually there, but let's test the method)
hash_index.remove_by_path(
CacheEntryValidator.get_file_path_safe(invalid_entry),
CacheEntryValidator.get_sha256_safe(invalid_entry)
)
# Verify valid entry remains
assert hash_index.get_hash(valid_entry['file_path']) == valid_entry['sha256']
def test_cache_entry_validator_handles_various_field_types():
"""Test that validator handles various field types correctly"""
# Test with different field types
entry = {
'file_path': '/models/test.safetensors',
'sha256': 'abc123',
'size': 1024, # int
'modified': 1234567890.0, # float
'favorite': True, # bool
'tags': ['tag1', 'tag2'], # list
'exclude': False, # bool
}
result = CacheEntryValidator.validate(entry, auto_repair=False)
assert result.is_valid is True
assert result.repaired is False
def test_cache_health_report_serialization():
"""Test that HealthReport can be serialized to dict"""
from py.services.cache_health_monitor import HealthReport
report = HealthReport(
status=CacheHealthStatus.DEGRADED,
total_entries=100,
valid_entries=98,
invalid_entries=2,
repaired_entries=1,
invalid_paths=['/path1', '/path2'],
message="Cache issues detected"
)
result = report.to_dict()
assert result['status'] == 'degraded'
assert result['total_entries'] == 100
assert result['valid_entries'] == 98
assert result['invalid_entries'] == 2
assert result['repaired_entries'] == 1
assert result['corruption_rate'] == '2.0%'
assert len(result['invalid_paths']) == 2
assert result['message'] == "Cache issues detected"
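The health tests above fix only two data points: 2 invalid entries out of 100 yields DEGRADED, 10 out of 100 yields CORRUPTED. A sketch of a corruption-rate classifier consistent with those points — the exact threshold values are illustrative assumptions, not the monitor's actual constants:

```python
from enum import Enum


class CacheHealthStatus(Enum):
    HEALTHY = "healthy"
    DEGRADED = "degraded"
    CORRUPTED = "corrupted"


def classify_cache_health(valid: int, invalid: int,
                          degraded_threshold: float = 0.01,
                          corrupted_threshold: float = 0.05) -> CacheHealthStatus:
    """Classify cache health by corruption rate (invalid / total).

    The tests only constrain two points (2% -> DEGRADED, 10% -> CORRUPTED);
    the default thresholds here are assumptions chosen to satisfy both.
    """
    total = valid + invalid
    if total == 0 or invalid == 0:
        return CacheHealthStatus.HEALTHY
    rate = invalid / total
    if rate >= corrupted_threshold:
        return CacheHealthStatus.CORRUPTED
    if rate >= degraded_threshold:
        return CacheHealthStatus.DEGRADED
    return CacheHealthStatus.HEALTHY
```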

View File

@@ -242,6 +242,148 @@ async def test_bulk_metadata_refresh_reports_errors() -> None:
assert progress.events[-1]["error"] == "boom"
async def test_bulk_metadata_refresh_skips_confirmed_not_found_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Models marked as from_civitai=False and civitai_deleted=True should be skipped."""
scanner = MockScanner()
scanner._cache.raw_data = [
{
"file_path": "model1.safetensors",
"sha256": "hash1",
"from_civitai": False,
"civitai_deleted": True,
"model_name": "NotOnCivitAI",
},
{
"file_path": "model2.safetensors",
"sha256": "hash2",
"from_civitai": True,
"model_name": "OnCivitAI",
},
]
service = MockModelService(scanner)
metadata_sync = StubMetadataSync()
settings = StubSettings(enable_metadata_archive_db=False)
progress = ProgressCollector()
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
# Preserve the original data (simulating no metadata file on disk)
return model_data
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=metadata_sync,
settings_service=settings,
logger=logging.getLogger("test"),
)
result = await use_case.execute_with_error_handling(progress_callback=progress)
assert result["success"] is True
# Only model2 should be processed (model1 is skipped)
assert result["processed"] == 1
assert result["updated"] == 1
assert len(metadata_sync.calls) == 1
assert metadata_sync.calls[0]["file_path"] == "model2.safetensors"
async def test_bulk_metadata_refresh_skips_when_archive_checked(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Models with db_checked=True should be skipped even if archive DB is enabled."""
scanner = MockScanner()
scanner._cache.raw_data = [
{
"file_path": "model1.safetensors",
"sha256": "hash1",
"from_civitai": False,
"civitai_deleted": True,
"db_checked": True,
"model_name": "ArchiveChecked",
},
{
"file_path": "model2.safetensors",
"sha256": "hash2",
"from_civitai": False,
"civitai_deleted": True,
"db_checked": False,
"model_name": "ArchiveNotChecked",
},
]
service = MockModelService(scanner)
metadata_sync = StubMetadataSync()
settings = StubSettings(enable_metadata_archive_db=True)
progress = ProgressCollector()
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
return model_data
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=metadata_sync,
settings_service=settings,
logger=logging.getLogger("test"),
)
result = await use_case.execute_with_error_handling(progress_callback=progress)
assert result["success"] is True
# Only model2 should be processed (model1 has db_checked=True)
assert result["processed"] == 1
assert result["updated"] == 1
assert len(metadata_sync.calls) == 1
assert metadata_sync.calls[0]["file_path"] == "model2.safetensors"
async def test_bulk_metadata_refresh_processes_never_fetched_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
"""Models that have never been fetched (from_civitai=None) should be processed."""
scanner = MockScanner()
scanner._cache.raw_data = [
{
"file_path": "model1.safetensors",
"sha256": "hash1",
"from_civitai": None,
"model_name": "NeverFetched",
},
{
"file_path": "model2.safetensors",
"sha256": "hash2",
"model_name": "NoFromCivitaiField",
},
]
service = MockModelService(scanner)
metadata_sync = StubMetadataSync()
settings = StubSettings(enable_metadata_archive_db=False)
progress = ProgressCollector()
async def fake_hydrate(model_data: Dict[str, Any]) -> Dict[str, Any]:
return model_data
monkeypatch.setattr(MetadataManager, "hydrate_model_data", staticmethod(fake_hydrate))
use_case = BulkMetadataRefreshUseCase(
service=service,
metadata_sync=metadata_sync,
settings_service=settings,
logger=logging.getLogger("test"),
)
result = await use_case.execute_with_error_handling(progress_callback=progress)
assert result["success"] is True
# Both models should be processed
assert result["processed"] == 2
assert result["updated"] == 2
assert len(metadata_sync.calls) == 2
async def test_download_model_use_case_raises_validation_error() -> None:
coordinator = StubDownloadCoordinator(error="validation")
use_case = DownloadModelUseCase(download_coordinator=coordinator)
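Taken together, the three refresh tests above describe a skip filter: confirmed not-found models are excluded, archive-checked models are excluded even with the archive DB enabled, and never-fetched models are always processed. A sketch of that predicate — field names come from the test fixtures; the function name and shape are assumptions about how BulkMetadataRefreshUseCase might express it:

```python
from typing import Any, Dict


def should_skip_metadata_fetch(model: Dict[str, Any], archive_db_enabled: bool) -> bool:
    """Return True when a bulk metadata refresh should skip this model.

    Behavior implied by the tests above:
    - from_civitai=False + civitai_deleted=True means the model is confirmed
      absent from CivitAI, so skip it -- unless the metadata archive DB is
      enabled and the model has not yet been checked against it (db_checked
      falsy), in which case the archive may still hold metadata.
    - from_civitai=None or missing means the model was never fetched, so
      it is always processed.
    """
    confirmed_not_found = (
        model.get("from_civitai") is False and model.get("civitai_deleted") is True
    )
    if not confirmed_not_found:
        return False
    if archive_db_enabled and not model.get("db_checked"):
        return False  # archive DB may still have metadata for it
    return True
```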

View File

@@ -75,6 +75,31 @@ def test_get_file_extension_defaults_to_jpg() -> None:
assert ext == ".jpg"
def test_get_file_extension_from_media_type_hint_video() -> None:
"""Test that media_type_hint='video' returns .mp4 when other methods fail"""
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
b"", {}, "https://c.genur.art/536be3c9-e506-4365-b078-bfbc5df9ceec", "video"
)
assert ext == ".mp4"
def test_get_file_extension_from_media_type_hint_image() -> None:
"""Test that media_type_hint='image' falls back to .jpg"""
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
b"", {}, "https://example.com/no-extension", "image"
)
assert ext == ".jpg"
def test_get_file_extension_media_type_hint_low_priority() -> None:
"""Test that media_type_hint is only used as last resort (after URL extension)"""
# URL has extension, should use that instead of media_type_hint
ext = processor_module.ExampleImagesProcessor._get_file_extension_from_content_or_headers(
b"", {}, "https://example.com/video.mp4", "image"
)
assert ext == ".mp4"
class StubScanner:
def __init__(self, models: list[Dict[str, Any]]) -> None:
self._cache = SimpleNamespace(raw_data=models)
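The three extension tests above establish a priority order: a recognized extension in the URL always wins, and `media_type_hint` is only a last resort (`'video'` maps to `.mp4`, anything else falls back to `.jpg`). A sketch of that resolution order — the real `_get_file_extension_from_content_or_headers` also sniffs content bytes and response headers, which is elided here, and the helper name and extension set are assumptions:

```python
from typing import Dict, Optional

# Extensions treated as trustworthy when found at the end of a URL path
# (illustrative set; the real implementation may recognize more types).
_KNOWN_EXTENSIONS = {".jpg", ".jpeg", ".png", ".gif", ".webp", ".mp4", ".webm"}


def resolve_example_extension(content: bytes, headers: Dict[str, str],
                              url: str, media_type_hint: Optional[str]) -> str:
    """Resolve a file extension using the priority the tests above assert:
    URL extension first, media_type_hint only as a last resort."""
    last_segment = url.split("?", 1)[0].rsplit("/", 1)[-1]
    if "." in last_segment:
        ext = "." + last_segment.rsplit(".", 1)[-1].lower()
        if ext in _KNOWN_EXTENSIONS:
            return ext
    # Last resort: fall back on the caller-supplied hint
    if media_type_hint == "video":
        return ".mp4"
    return ".jpg"
```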

View File

@@ -6,7 +6,8 @@ export default defineConfig({
globals: true,
setupFiles: ['tests/frontend/setup.js'],
include: [
-  'tests/frontend/**/*.test.js'
+  'tests/frontend/**/*.test.js',
+  'tests/frontend/**/*.test.ts'
],
coverage: {
  enabled: process.env.VITEST_COVERAGE === 'true',

File diff suppressed because it is too large

View File

@@ -12,9 +12,13 @@
"@comfyorg/comfyui-frontend-types": "^1.35.4", "@comfyorg/comfyui-frontend-types": "^1.35.4",
"@types/node": "^22.10.1", "@types/node": "^22.10.1",
"@vitejs/plugin-vue": "^5.2.3", "@vitejs/plugin-vue": "^5.2.3",
"@vitest/coverage-v8": "^3.2.4",
"@vue/test-utils": "^2.4.6",
"jsdom": "^26.0.0",
"typescript": "^5.7.2", "typescript": "^5.7.2",
"vite": "^6.3.5", "vite": "^6.3.5",
"vite-plugin-css-injected-by-js": "^3.5.2", "vite-plugin-css-injected-by-js": "^3.5.2",
"vitest": "^3.0.0",
"vue-tsc": "^2.1.10" "vue-tsc": "^2.1.10"
}, },
"scripts": { "scripts": {
@@ -24,6 +28,9 @@
"typecheck": "vue-tsc --noEmit", "typecheck": "vue-tsc --noEmit",
"clean": "rm -rf ../web/comfyui/vue-widgets", "clean": "rm -rf ../web/comfyui/vue-widgets",
"rebuild": "npm run clean && npm run build", "rebuild": "npm run clean && npm run build",
"prepare": "npm run build" "prepare": "npm run build",
"test": "vitest run",
"test:watch": "vitest",
"test:coverage": "vitest run --coverage"
} }
} }

View File

@@ -10,11 +10,28 @@
:use-custom-clip-range="state.useCustomClipRange.value"
:is-clip-strength-disabled="state.isClipStrengthDisabled.value"
:is-loading="state.isLoading.value"
+:repeat-count="state.repeatCount.value"
+:repeat-used="state.displayRepeatUsed.value"
+:is-paused="state.isPaused.value"
+:is-pause-disabled="hasQueuedPrompts"
+:is-workflow-executing="state.isWorkflowExecuting.value"
+:executing-repeat-step="state.executingRepeatStep.value"
@update:current-index="handleIndexUpdate"
@update:model-strength="state.modelStrength.value = $event"
@update:clip-strength="state.clipStrength.value = $event"
@update:use-custom-clip-range="handleUseCustomClipRangeChange"
-@refresh="handleRefresh"
+@update:repeat-count="handleRepeatCountChange"
+@toggle-pause="handleTogglePause"
+@reset-index="handleResetIndex"
+@open-lora-selector="isModalOpen = true"
/>
+<LoraListModal
+  :visible="isModalOpen"
+  :lora-list="cachedLoraList"
+  :current-index="state.currentIndex.value"
+  @close="isModalOpen = false"
+  @select="handleModalSelect"
+/>
</div>
</template>
@@ -22,8 +39,9 @@
<script setup lang="ts">
import { onMounted, ref } from 'vue'
import LoraCyclerSettingsView from './lora-cycler/LoraCyclerSettingsView.vue'
+import LoraListModal from './lora-cycler/LoraListModal.vue'
import { useLoraCyclerState } from '../composables/useLoraCyclerState'
-import type { ComponentWidget, CyclerConfig, LoraPoolConfig } from '../composables/types'
+import type { ComponentWidget, CyclerConfig, LoraPoolConfig, LoraItem } from '../composables/types'
type CyclerWidget = ComponentWidget<CyclerConfig>
@@ -31,6 +49,7 @@ type CyclerWidget = ComponentWidget<CyclerConfig>
const props = defineProps<{
  widget: CyclerWidget
  node: { id: number; inputs?: any[]; widgets?: any[]; graph?: any }
+  api?: any // ComfyUI API for execution events
}>()
// State management
@@ -39,12 +58,50 @@ const state = useLoraCyclerState(props.widget)
// Symbol to track if the widget has been executed at least once
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
+// Execution context queue for batch queue synchronization
+// In batch queue mode, all beforeQueued calls happen BEFORE any onExecuted calls,
+// so we need to snapshot the state at queue time and replay it during execution
+interface ExecutionContext {
+  isPaused: boolean
+  repeatUsed: number
+  repeatCount: number
+  shouldAdvanceDisplay: boolean
+  displayRepeatUsed: number // Value to show in UI after completion
+}
+const executionQueue: ExecutionContext[] = []
+// Reactive flag to track if there are queued prompts (for disabling pause button)
+const hasQueuedPrompts = ref(false)
+// Track pending executions for batch queue support (deferred UI updates)
+// Uses FIFO order since executions are processed in the order they were queued
+interface PendingExecution {
+  repeatUsed: number
+  repeatCount: number
+  shouldAdvanceDisplay: boolean
+  displayRepeatUsed: number // Value to show in UI after completion
+  output?: {
+    nextIndex: number
+    nextLoraName: string
+    nextLoraFilename: string
+    currentLoraName: string
+    currentLoraFilename: string
+  }
+}
+const pendingExecutions: PendingExecution[] = []
// Track last known pool config hash
const lastPoolConfigHash = ref('')
// Track if component is mounted
const isMounted = ref(false)
+// Modal state
+const isModalOpen = ref(false)
+// Cache for LoRA list (used by modal)
+const cachedLoraList = ref<LoraItem[]>([])
// Get pool config from connected node
const getPoolConfig = (): LoraPoolConfig | null => {
  // Check if getPoolConfig method exists on node (added by main.ts)
@@ -54,27 +111,47 @@ const getPoolConfig = (): LoraPoolConfig | null => {
  return null
}
+// Update display from LoRA list and index
+const updateDisplayFromLoraList = (loraList: LoraItem[], index: number) => {
+  if (loraList.length > 0 && index > 0 && index <= loraList.length) {
+    const currentLora = loraList[index - 1]
+    if (currentLora) {
+      state.currentLoraName.value = currentLora.file_name
+      state.currentLoraFilename.value = currentLora.file_name
+    }
+  }
+}
// Handle index update from user
const handleIndexUpdate = async (newIndex: number) => {
+  // Reset execution state when user manually changes index
+  // This ensures the next execution starts from the user-set index
+  ;(props.widget as any)[HAS_EXECUTED] = false
+  state.executionIndex.value = null
+  state.nextIndex.value = null
+  // Clear execution queue since user is manually changing state
+  executionQueue.length = 0
+  hasQueuedPrompts.value = false
  state.setIndex(newIndex)
  // Refresh list to update current LoRA display
  try {
    const poolConfig = getPoolConfig()
    const loraList = await state.fetchCyclerList(poolConfig)
-    if (loraList.length > 0 && newIndex > 0 && newIndex <= loraList.length) {
-      const currentLora = loraList[newIndex - 1]
-      if (currentLora) {
-        state.currentLoraName.value = currentLora.file_name
-        state.currentLoraFilename.value = currentLora.file_name
-      }
-    }
+    cachedLoraList.value = loraList
+    updateDisplayFromLoraList(loraList, newIndex)
  } catch (error) {
    console.error('[LoraCyclerWidget] Error updating index:', error)
  }
}
+// Handle LoRA selection from modal
+const handleModalSelect = (index: number) => {
+  handleIndexUpdate(index)
+}
// Handle use custom clip range toggle
const handleUseCustomClipRangeChange = (newValue: boolean) => {
  state.useCustomClipRange.value = newValue
@@ -84,13 +161,41 @@ const handleUseCustomClipRangeChange = (newValue: boolean) => {
  }
}
-// Handle refresh button click
-const handleRefresh = async () => {
+// Handle repeat count change
+const handleRepeatCountChange = (newValue: number) => {
+  state.repeatCount.value = newValue
+  // Reset repeatUsed when changing repeat count
+  state.repeatUsed.value = 0
+  state.displayRepeatUsed.value = 0
+}
+// Handle pause toggle
+const handleTogglePause = () => {
+  state.togglePause()
+}
+// Handle reset index
+const handleResetIndex = async () => {
+  // Reset execution state
+  ;(props.widget as any)[HAS_EXECUTED] = false
+  state.executionIndex.value = null
+  state.nextIndex.value = null
+  // Clear execution queue since user is resetting state
+  executionQueue.length = 0
+  hasQueuedPrompts.value = false
+  // Reset index and repeat state
+  state.resetIndex()
+  // Refresh list to update current LoRA display
  try {
    const poolConfig = getPoolConfig()
-    await state.refreshList(poolConfig)
+    const loraList = await state.fetchCyclerList(poolConfig)
+    cachedLoraList.value = loraList
+    updateDisplayFromLoraList(loraList, 1)
  } catch (error) {
-    console.error('[LoraCyclerWidget] Error refreshing:', error)
+    console.error('[LoraCyclerWidget] Error resetting index:', error)
  }
}
@@ -106,6 +211,9 @@ const checkPoolConfigChanges = async () => {
  lastPoolConfigHash.value = newHash
  try {
    await state.refreshList(poolConfig)
+    // Update cached list when pool config changes
+    const loraList = await state.fetchCyclerList(poolConfig)
+    cachedLoraList.value = loraList
  } catch (error) {
    console.error('[LoraCyclerWidget] Error on pool config change:', error)
  }
@@ -129,17 +237,68 @@ onMounted(async () => {
  // Add beforeQueued hook to handle index shifting for batch queue synchronization
  // This ensures each execution uses a different LoRA in the cycle
+  // Now with support for repeat count and pause features
+  //
+  // IMPORTANT: In batch queue mode, ALL beforeQueued calls happen BEFORE any execution.
+  // We push an "execution context" snapshot to a queue so that onExecuted can use the
+  // correct state values that were captured at queue time (not the live state).
  ;(props.widget as any).beforeQueued = () => {
+    if (state.isPaused.value) {
+      // When paused: use current index, don't advance, don't count toward repeat limit
+      // Push context indicating this execution should NOT advance display
+      executionQueue.push({
+        isPaused: true,
+        repeatUsed: state.repeatUsed.value,
+        repeatCount: state.repeatCount.value,
+        shouldAdvanceDisplay: false,
+        displayRepeatUsed: state.displayRepeatUsed.value // Keep current display value when paused
+      })
+      hasQueuedPrompts.value = true
+      // CRITICAL: Clear execution_index when paused to force backend to use current_index
+      // This ensures paused executions use the same LoRA regardless of any
+      // execution_index set by previous non-paused beforeQueued calls
+      const pausedConfig = state.buildConfig()
+      pausedConfig.execution_index = null
+      props.widget.value = pausedConfig
+      return
+    }
    if ((props.widget as any)[HAS_EXECUTED]) {
-      // After first execution: shift indices (previous next_index becomes execution_index)
-      state.generateNextIndex()
+      // After first execution: check repeat logic
+      if (state.repeatUsed.value < state.repeatCount.value) {
+        // Still repeating: increment repeatUsed, use same index
+        state.repeatUsed.value++
+      } else {
+        // Repeat complete: reset repeatUsed to 1, advance to next index
+        state.repeatUsed.value = 1
+        state.generateNextIndex()
+      }
    } else {
-      // First execution: just initialize next_index (execution_index stays null)
-      // This means first execution uses current_index from widget
+      // First execution: initialize
+      state.repeatUsed.value = 1
      state.initializeNextIndex()
      ;(props.widget as any)[HAS_EXECUTED] = true
    }
+    // Determine if this execution should advance the display
+    // (only when repeat cycle is complete for this queued item)
+    const shouldAdvanceDisplay = state.repeatUsed.value >= state.repeatCount.value
+    // Calculate the display value to show after this execution completes
+    // When advancing to a new LoRA: reset to 0 (fresh start for new LoRA)
+    // When repeating same LoRA: show current repeat step
+    const displayRepeatUsed = shouldAdvanceDisplay ? 0 : state.repeatUsed.value
+    // Push execution context snapshot to queue
+    executionQueue.push({
+      isPaused: false,
+      repeatUsed: state.repeatUsed.value,
+      repeatCount: state.repeatCount.value,
+      shouldAdvanceDisplay,
+      displayRepeatUsed
+    })
+    hasQueuedPrompts.value = true
    // Update the widget value so the indices are included in the serialized config
    props.widget.value = state.buildConfig()
  }
@@ -152,40 +311,71 @@ onMounted(async () => {
    const poolConfig = getPoolConfig()
    lastPoolConfigHash.value = state.hashPoolConfig(poolConfig)
    await state.refreshList(poolConfig)
+    // Cache the initial LoRA list for modal
+    const loraList = await state.fetchCyclerList(poolConfig)
+    cachedLoraList.value = loraList
  } catch (error) {
    console.error('[LoraCyclerWidget] Error on initial load:', error)
  }
  // Override onExecuted to handle backend UI updates
+  // This defers the UI update until workflow completes (via API events)
  const originalOnExecuted = (props.node as any).onExecuted?.bind(props.node)
  ;(props.node as any).onExecuted = function(output: any) {
    console.log("[LoraCyclerWidget] Node executed with output:", output)
-    // Update state from backend response (values are wrapped in arrays)
-    if (output?.next_index !== undefined) {
-      const val = Array.isArray(output.next_index) ? output.next_index[0] : output.next_index
-      state.currentIndex.value = val
-    }
+    // Pop execution context from queue (FIFO order)
+    const context = executionQueue.shift()
+    hasQueuedPrompts.value = executionQueue.length > 0
+    // Determine if we should advance the display index
+    const shouldAdvanceDisplay = context
+      ? context.shouldAdvanceDisplay
+      : (!state.isPaused.value && state.repeatUsed.value >= state.repeatCount.value)
+    // Extract output values
+    const nextIndex = output?.next_index !== undefined
+      ? (Array.isArray(output.next_index) ? output.next_index[0] : output.next_index)
+      : state.currentIndex.value
+    const nextLoraName = output?.next_lora_name !== undefined
+      ? (Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name)
+      : ''
+    const nextLoraFilename = output?.next_lora_filename !== undefined
+      ? (Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename)
+      : ''
+    const currentLoraName = output?.current_lora_name !== undefined
+      ? (Array.isArray(output.current_lora_name) ? output.current_lora_name[0] : output.current_lora_name)
+      : ''
+    const currentLoraFilename = output?.current_lora_filename !== undefined
+      ? (Array.isArray(output.current_lora_filename) ? output.current_lora_filename[0] : output.current_lora_filename)
+      : ''
+    // Update total count immediately (doesn't need to wait for workflow completion)
    if (output?.total_count !== undefined) {
      const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
      state.totalCount.value = val
    }
-    if (output?.current_lora_name !== undefined) {
-      const val = Array.isArray(output.current_lora_name) ? output.current_lora_name[0] : output.current_lora_name
-      state.currentLoraName.value = val
-    }
-    if (output?.current_lora_filename !== undefined) {
-      const val = Array.isArray(output.current_lora_filename) ? output.current_lora_filename[0] : output.current_lora_filename
-      state.currentLoraFilename.value = val
-    }
-    if (output?.next_lora_name !== undefined) {
-      const val = Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name
-      state.currentLoraName.value = val
-    }
-    if (output?.next_lora_filename !== undefined) {
-      const val = Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename
-      state.currentLoraFilename.value = val
+    // Store pending update (will be applied on workflow completion)
+    if (context) {
+      pendingExecutions.push({
+        repeatUsed: context.repeatUsed,
+        repeatCount: context.repeatCount,
+        shouldAdvanceDisplay,
+        displayRepeatUsed: context.displayRepeatUsed,
+        output: {
+          nextIndex,
+          nextLoraName,
+          nextLoraFilename,
+          currentLoraName,
+          currentLoraFilename
+        }
+      })
+      // Update visual feedback state (don't update displayRepeatUsed yet - wait for workflow completion)
+      state.executingRepeatStep.value = context.repeatUsed
+      state.isWorkflowExecuting.value = true
    }
    // Call original onExecuted if it exists
@@ -194,11 +384,69 @@ onMounted(async () => {
} }
} }
// Set up execution tracking via API events
if (props.api) {
// Handle workflow completion events using FIFO order
// Note: The 'executing' event doesn't contain prompt_id (only node ID as string),
// so we use FIFO order instead of prompt_id matching since executions are processed
// in the order they were queued
const handleExecutionComplete = () => {
// Process the first pending execution (FIFO order)
if (pendingExecutions.length === 0) {
return
}
const pending = pendingExecutions.shift()!
// Apply UI update now that workflow is complete
// Update repeat display (deferred like index updates)
state.displayRepeatUsed.value = pending.displayRepeatUsed
if (pending.output) {
if (pending.shouldAdvanceDisplay) {
state.currentIndex.value = pending.output.nextIndex
state.currentLoraName.value = pending.output.nextLoraName
state.currentLoraFilename.value = pending.output.nextLoraFilename
} else {
// When not advancing, show current LoRA info
state.currentLoraName.value = pending.output.currentLoraName
state.currentLoraFilename.value = pending.output.currentLoraFilename
}
}
// Reset visual feedback if no more pending
if (pendingExecutions.length === 0) {
state.isWorkflowExecuting.value = false
state.executingRepeatStep.value = 0
}
}
props.api.addEventListener('execution_success', handleExecutionComplete)
props.api.addEventListener('execution_error', handleExecutionComplete)
props.api.addEventListener('execution_interrupted', handleExecutionComplete)
// Store cleanup function for API listeners
const apiCleanup = () => {
props.api.removeEventListener('execution_success', handleExecutionComplete)
props.api.removeEventListener('execution_error', handleExecutionComplete)
props.api.removeEventListener('execution_interrupted', handleExecutionComplete)
}
// Extend existing cleanup
const existingCleanup = (props.widget as any).onRemoveCleanup
;(props.widget as any).onRemoveCleanup = () => {
existingCleanup?.()
apiCleanup()
}
}
// Watch for connection changes by polling (since ComfyUI doesn't provide connection events)
const checkInterval = setInterval(checkPoolConfigChanges, 1000)
// Cleanup on unmount (handled by Vue's effect scope)
const existingCleanupForInterval = (props.widget as any).onRemoveCleanup
;(props.widget as any).onRemoveCleanup = () => {
existingCleanupForInterval?.()
clearInterval(checkInterval)
}
})
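The handler above defers UI updates: each queued prompt records the state it should display, and exactly one entry is consumed per completion event (`execution_success`, `execution_error`, or `execution_interrupted`). A minimal standalone sketch of this queue pattern, with illustrative names (`queueExecution`, `display`) rather than the widget's actual API:

```typescript
// Each queued prompt records the UI state to show once its workflow
// completes; the completion handler applies one pending entry per event.
interface PendingUpdate {
  displayRepeatUsed: number
  nextIndex?: number
}

const pendingExecutions: PendingUpdate[] = []
const display = { repeatUsed: 0, currentIndex: 1 }

function queueExecution(update: PendingUpdate): void {
  pendingExecutions.push(update)
}

// Wired to execution_success / execution_error / execution_interrupted
function handleExecutionComplete(): void {
  const pending = pendingExecutions.shift()
  if (!pending) return // no queued prompt: nothing to apply
  display.repeatUsed = pending.displayRepeatUsed
  if (pending.nextIndex !== undefined) {
    display.currentIndex = pending.nextIndex
  }
}
```

Because updates are applied on completion rather than on queueing, the display keeps showing the LoRA of the workflow actually running, even when several prompts are queued at once.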


@@ -6,57 +6,111 @@
<!-- Progress Display -->
<div class="setting-section progress-section">
<div class="progress-display" :class="{ executing: isWorkflowExecuting }">
<div
class="progress-info"
:class="{ disabled: isPauseDisabled }"
@click="handleOpenSelector"
>
<span class="progress-label">{{ isWorkflowExecuting ? 'Using LoRA:' : 'Next LoRA:' }}</span>
<span class="progress-name clickable" :class="{ disabled: isPauseDisabled }" :title="currentLoraFilename">
{{ currentLoraName || 'None' }}
<svg class="selector-icon" viewBox="0 0 24 24" fill="currentColor">
<path d="M7 10l5 5 5-5z"/>
</svg>
</span>
</div>
<div class="progress-counter">
<span class="progress-index">{{ currentIndex }}</span>
<span class="progress-separator">/</span>
<span class="progress-total">{{ totalCount }}</span>
<!-- Repeat progress indicator (only shown when repeatCount > 1) -->
<div v-if="repeatCount > 1" class="repeat-progress">
<div class="repeat-progress-track">
<div
class="repeat-progress-fill"
:style="{ width: `${(repeatUsed / repeatCount) * 100}%` }"
:class="{ 'is-complete': repeatUsed >= repeatCount }"
></div>
</div>
<span class="repeat-progress-text">{{ repeatUsed }}/{{ repeatCount }}</span>
</div>
</div>
</div>
</div>
<!-- Starting Index with Advanced Controls -->
<div class="setting-section">
<div class="index-controls-row">
<!-- Left: Index group -->
<div class="control-group">
<label class="control-group-label">Starting Index</label>
<div class="control-group-content">
<input
type="number"
class="index-input"
:min="1"
:max="totalCount || 1"
:value="currentIndex"
:disabled="totalCount === 0"
@input="onIndexInput"
@blur="onIndexBlur"
@pointerdown.stop
@pointermove.stop
@pointerup.stop
/>
<span class="index-hint">/ {{ totalCount || 1 }}</span>
</div>
</div>
<!-- Right: Repeat group -->
<div class="control-group">
<label class="control-group-label">Repeat</label>
<div class="control-group-content">
<input
type="number"
class="repeat-input"
min="1"
max="99"
:value="repeatCount"
@input="onRepeatInput"
@blur="onRepeatBlur"
@pointerdown.stop
@pointermove.stop
@pointerup.stop
title="Each LoRA will be used this many times before moving to the next"
/>
<span class="repeat-suffix">×</span>
</div>
</div>
<!-- Action buttons -->
<div class="action-buttons">
<button
class="control-btn"
:class="{ active: isPaused }"
:disabled="isPauseDisabled"
@click="$emit('toggle-pause')"
:title="isPauseDisabled ? 'Cannot pause while prompts are queued' : (isPaused ? 'Continue iteration' : 'Pause iteration')"
>
<svg v-if="isPaused" viewBox="0 0 24 24" fill="currentColor" class="control-icon">
<path d="M8 5v14l11-7z"/>
</svg>
<svg v-else viewBox="0 0 24 24" fill="currentColor" class="control-icon">
<path d="M6 4h4v16H6zm8 0h4v16h-4z"/>
</svg>
</button>
<button
class="control-btn"
@click="$emit('reset-index')"
title="Reset to index 1"
>
<svg viewBox="0 0 24 24" fill="currentColor" class="control-icon">
<path d="M12 5V1L7 6l5 5V7c3.31 0 6 2.69 6 6s-2.69 6-6 6-6-2.69-6-6H4c0 4.42 3.58 8 8 8s8-3.58 8-8-3.58-8-8-8z"/>
</svg>
</button>
</div>
</div>
</div>
@@ -122,7 +176,12 @@ const props = defineProps<{
clipStrength: number
useCustomClipRange: boolean
isClipStrengthDisabled: boolean
repeatCount: number
repeatUsed: number
isPaused: boolean
isPauseDisabled: boolean
isWorkflowExecuting: boolean
executingRepeatStep: number
}>()
const emit = defineEmits<{
@@ -130,11 +189,22 @@ const emit = defineEmits<{
'update:modelStrength': [value: number]
'update:clipStrength': [value: number]
'update:useCustomClipRange': [value: boolean]
'update:repeatCount': [value: number]
'toggle-pause': []
'reset-index': []
'open-lora-selector': []
}>()
// Temporary value for input while typing
const tempIndex = ref<string>('')
const tempRepeat = ref<string>('')
const handleOpenSelector = () => {
if (props.isPauseDisabled) {
return
}
emit('open-lora-selector')
}
const onIndexInput = (event: Event) => {
const input = event.target as HTMLInputElement
@@ -154,6 +224,25 @@ const onIndexBlur = (event: Event) => {
}
tempIndex.value = ''
}
const onRepeatInput = (event: Event) => {
const input = event.target as HTMLInputElement
tempRepeat.value = input.value
}
const onRepeatBlur = (event: Event) => {
const input = event.target as HTMLInputElement
const value = parseInt(input.value, 10)
if (!isNaN(value)) {
const clampedValue = Math.max(1, Math.min(value, 99))
emit('update:repeatCount', clampedValue)
input.value = clampedValue.toString()
} else {
input.value = props.repeatCount.toString()
}
tempRepeat.value = ''
}
</script>
<style scoped>
@@ -203,6 +292,17 @@ const onIndexBlur = (event: Event) => {
display: flex;
justify-content: space-between;
align-items: center;
transition: border-color 0.3s ease;
}
.progress-display.executing {
border-color: rgba(66, 153, 225, 0.5);
animation: pulse 2s ease-in-out infinite;
}
@keyframes pulse {
0%, 100% { border-color: rgba(66, 153, 225, 0.3); }
50% { border-color: rgba(66, 153, 225, 0.7); }
}
.progress-info {
@@ -230,6 +330,42 @@ const onIndexBlur = (event: Event) => {
white-space: nowrap;
}
.progress-name.clickable {
cursor: pointer;
padding: 2px 6px;
margin: -2px -6px;
border-radius: 4px;
transition: all 0.2s;
display: inline-flex;
align-items: center;
gap: 4px;
}
.progress-name.clickable:hover:not(.disabled) {
background: rgba(66, 153, 225, 0.2);
color: rgba(191, 219, 254, 1);
}
.progress-name.clickable.disabled {
cursor: not-allowed;
opacity: 0.5;
}
.progress-info.disabled {
cursor: not-allowed;
}
.selector-icon {
width: 16px;
height: 16px;
opacity: 0.5;
flex-shrink: 0;
}
.progress-name.clickable:hover .selector-icon {
opacity: 0.8;
}
.progress-counter {
display: flex;
align-items: center;
@@ -243,6 +379,9 @@ const onIndexBlur = (event: Event) => {
font-weight: 600;
color: rgba(66, 153, 225, 1);
font-family: 'SF Mono', 'Roboto Mono', monospace;
min-width: 4ch;
text-align: right;
font-variant-numeric: tabular-nums;
}
.progress-separator {
@@ -256,69 +395,92 @@ const onIndexBlur = (event: Event) => {
font-weight: 500;
color: rgba(226, 232, 240, 0.6);
font-family: 'SF Mono', 'Roboto Mono', monospace;
min-width: 4ch;
text-align: left;
font-variant-numeric: tabular-nums;
}
/* Repeat Progress */
.repeat-progress {
display: flex;
align-items: center;
gap: 6px;
margin-left: 8px;
padding: 2px 6px;
background: rgba(26, 32, 44, 0.6);
border: 1px solid rgba(226, 232, 240, 0.1);
border-radius: 4px;
}
.repeat-progress-track {
width: 32px;
height: 4px;
background: rgba(226, 232, 240, 0.15);
border-radius: 2px;
overflow: hidden;
}
.repeat-progress-fill {
height: 100%;
background: linear-gradient(90deg, #f59e0b, #fbbf24);
border-radius: 2px;
transition: width 0.3s ease;
}
.repeat-progress-fill.is-complete {
background: linear-gradient(90deg, #10b981, #34d399);
}
.repeat-progress-text {
font-size: 10px;
font-family: 'SF Mono', 'Roboto Mono', monospace;
color: rgba(253, 230, 138, 0.9);
min-width: 3ch;
font-variant-numeric: tabular-nums;
}
/* Index Controls Row - Grouped Layout */
.index-controls-row {
display: flex;
align-items: flex-end;
gap: 16px;
}
/* Control Group */
.control-group {
display: flex;
flex-direction: column;
gap: 6px;
}
.control-group-label {
font-size: 11px;
font-weight: 500;
color: rgba(226, 232, 240, 0.5);
text-transform: uppercase;
letter-spacing: 0.03em;
line-height: 1;
}
.control-group-content {
display: flex;
align-items: baseline;
gap: 4px;
height: 32px;
}
.index-input {
width: 50px;
height: 32px;
padding: 0 8px;
background: rgba(26, 32, 44, 0.9);
border: 1px solid rgba(226, 232, 240, 0.2);
border-radius: 6px;
color: #e4e4e7;
font-size: 13px;
font-family: 'SF Mono', 'Roboto Mono', monospace;
line-height: 32px;
box-sizing: border-box;
}
.index-input:focus {
@@ -332,8 +494,89 @@ const onIndexBlur = (event: Event) => {
}
.index-hint {
font-size: 12px;
color: rgba(226, 232, 240, 0.4);
font-variant-numeric: tabular-nums;
line-height: 32px;
}
/* Repeat Controls */
.repeat-input {
width: 40px;
height: 32px;
padding: 0 6px;
background: rgba(26, 32, 44, 0.9);
border: 1px solid rgba(226, 232, 240, 0.2);
border-radius: 6px;
color: #e4e4e7;
font-size: 13px;
font-family: 'SF Mono', 'Roboto Mono', monospace;
text-align: center;
line-height: 32px;
box-sizing: border-box;
}
.repeat-input:focus {
outline: none;
border-color: rgba(66, 153, 225, 0.6);
}
.repeat-suffix {
font-size: 13px;
color: rgba(226, 232, 240, 0.4);
font-weight: 500;
line-height: 32px;
}
/* Action Buttons */
.action-buttons {
display: flex;
align-items: center;
gap: 6px;
margin-left: auto;
}
/* Control Buttons */
.control-btn {
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
padding: 0;
background: transparent;
border: 1px solid rgba(255, 255, 255, 0.1);
border-radius: 4px;
color: rgba(226, 232, 240, 0.6);
cursor: pointer;
transition: all 0.2s;
}
.control-btn:hover:not(:disabled) {
background: rgba(66, 153, 225, 0.2);
border-color: rgba(66, 153, 225, 0.4);
color: rgba(191, 219, 254, 1);
}
.control-btn:disabled {
opacity: 0.4;
cursor: not-allowed;
}
.control-btn.active {
background: rgba(245, 158, 11, 0.2);
border-color: rgba(245, 158, 11, 0.5);
color: rgba(253, 230, 138, 1);
}
.control-btn.active:hover {
background: rgba(245, 158, 11, 0.3);
border-color: rgba(245, 158, 11, 0.6);
}
.control-icon {
width: 14px;
height: 14px;
}
/* Slider Container */


@@ -0,0 +1,313 @@
<template>
<ModalWrapper
:visible="visible"
title="Select LoRA"
:subtitle="subtitleText"
@close="$emit('close')"
>
<template #search>
<div class="search-container">
<svg class="search-icon" viewBox="0 0 16 16" fill="currentColor">
<path d="M11.742 10.344a6.5 6.5 0 1 0-1.397 1.398h-.001c.03.04.062.078.098.115l3.85 3.85a1 1 0 0 0 1.415-1.414l-3.85-3.85a1.007 1.007 0 0 0-.115-.1zM12 6.5a5.5 5.5 0 1 1-11 0 5.5 5.5 0 0 1 11 0z"/>
</svg>
<input
ref="searchInputRef"
v-model="searchQuery"
type="text"
class="search-input"
placeholder="Search LoRAs..."
/>
<button
v-if="searchQuery"
type="button"
class="clear-button"
@click="clearSearch"
>
<svg viewBox="0 0 16 16" fill="currentColor">
<path d="M4.646 4.646a.5.5 0 0 1 .708 0L8 7.293l2.646-2.647a.5.5 0 0 1 .708.708L8.707 8l2.647 2.646a.5.5 0 0 1-.708.708L8 8.707l-2.646 2.647a.5.5 0 0 1-.708-.708L7.293 8 4.646 5.354a.5.5 0 0 1 0-.708z"/>
</svg>
</button>
</div>
</template>
<div class="lora-list">
<div
v-for="item in filteredList"
:key="item.index"
class="lora-item"
:class="{ active: currentIndex === item.index }"
@mouseenter="showPreview(item.lora.file_name, $event)"
@mouseleave="hidePreview"
@click="selectLora(item.index)"
>
<span class="lora-index">{{ item.index }}</span>
<span class="lora-name" :title="item.lora.file_name">{{ item.lora.file_name }}</span>
<span v-if="currentIndex === item.index" class="current-badge">Current</span>
</div>
<div v-if="filteredList.length === 0" class="no-results">
No LoRAs found
</div>
</div>
</ModalWrapper>
</template>
<script setup lang="ts">
import { ref, computed, watch, nextTick, onUnmounted } from 'vue'
import ModalWrapper from '../lora-pool/modals/ModalWrapper.vue'
import type { LoraItem } from '../../composables/types'
interface LoraListItem {
index: number
lora: LoraItem
}
const props = defineProps<{
visible: boolean
loraList: LoraItem[]
currentIndex: number
}>()
const emit = defineEmits<{
close: []
select: [index: number]
}>()
const searchQuery = ref('')
const searchInputRef = ref<HTMLInputElement | null>(null)
// Preview tooltip instance (lazy init)
let previewTooltip: any = null
const subtitleText = computed(() => {
const total = props.loraList.length
const filtered = filteredList.value.length
if (filtered === total) {
return `Total: ${total} LoRA${total !== 1 ? 's' : ''}`
}
return `Showing ${filtered} of ${total} LoRA${total !== 1 ? 's' : ''}`
})
const filteredList = computed<LoraListItem[]>(() => {
const list = props.loraList.map((lora, idx) => ({
index: idx + 1,
lora
}))
if (!searchQuery.value.trim()) {
return list
}
const query = searchQuery.value.toLowerCase()
return list.filter(item =>
item.lora.file_name.toLowerCase().includes(query)
)
})
const clearSearch = () => {
searchQuery.value = ''
searchInputRef.value?.focus()
}
const selectLora = (index: number) => {
emit('select', index)
emit('close')
}
// Custom preview URL resolver for Vue widgets environment
// The default preview_tooltip.js uses api.fetchApi which is mocked as native fetch
// in the Vue widgets build, so we need to use the full path with /api prefix
const customPreviewUrlResolver = async (modelName: string) => {
const response = await fetch(
`/api/lm/loras/preview-url?name=${encodeURIComponent(modelName)}&license_flags=true`
)
if (!response.ok) {
throw new Error('Failed to fetch preview URL')
}
const data = await response.json()
if (!data.success || !data.preview_url) {
throw new Error('No preview available')
}
return {
previewUrl: data.preview_url,
displayName: data.display_name ?? modelName,
licenseFlags: data.license_flags
}
}
// Lazy load PreviewTooltip to avoid loading it unnecessarily
const getPreviewTooltip = async () => {
if (!previewTooltip) {
const { PreviewTooltip } = await import(/* @vite-ignore */ `${'../preview_tooltip.js'}`)
previewTooltip = new PreviewTooltip({
modelType: 'loras',
displayNameFormatter: (name: string) => name,
previewUrlResolver: customPreviewUrlResolver
})
}
return previewTooltip
}
const showPreview = async (loraName: string, event: MouseEvent) => {
const tooltip = await getPreviewTooltip()
const rect = (event.target as HTMLElement).getBoundingClientRect()
// Position to the right of the item, centered vertically
tooltip.show(loraName, rect.right + 10, rect.top + rect.height / 2)
}
const hidePreview = async () => {
if (previewTooltip) {
previewTooltip.hide()
}
}
// Focus search input when modal opens
watch(() => props.visible, (isVisible) => {
if (isVisible) {
searchQuery.value = ''
nextTick(() => {
searchInputRef.value?.focus()
})
} else {
// Hide preview when modal closes
hidePreview()
}
})
// Cleanup on unmount
onUnmounted(() => {
if (previewTooltip) {
previewTooltip.cleanup()
previewTooltip = null
}
})
</script>
<style scoped>
.search-container {
position: relative;
}
.search-icon {
position: absolute;
left: 10px;
top: 50%;
transform: translateY(-50%);
width: 14px;
height: 14px;
color: var(--fg-color, #fff);
opacity: 0.5;
}
.search-input {
width: 100%;
padding: 8px 32px;
background: var(--comfy-input-bg, #333);
border: 1px solid var(--border-color, #444);
border-radius: 6px;
color: var(--fg-color, #fff);
font-size: 13px;
outline: none;
box-sizing: border-box;
}
.search-input:focus {
border-color: rgba(66, 153, 225, 0.6);
}
.search-input::placeholder {
color: var(--fg-color, #fff);
opacity: 0.4;
}
.clear-button {
position: absolute;
right: 8px;
top: 50%;
transform: translateY(-50%);
display: flex;
align-items: center;
justify-content: center;
width: 20px;
height: 20px;
background: transparent;
border: none;
cursor: pointer;
padding: 0;
opacity: 0.5;
transition: opacity 0.15s;
}
.clear-button:hover {
opacity: 0.8;
}
.clear-button svg {
width: 12px;
height: 12px;
color: var(--fg-color, #fff);
}
.lora-list {
display: flex;
flex-direction: column;
gap: 2px;
max-height: 400px;
overflow-y: auto;
}
.lora-item {
display: flex;
align-items: center;
gap: 12px;
padding: 10px 12px;
border-radius: 6px;
cursor: pointer;
transition: all 0.15s;
border-left: 3px solid transparent;
}
.lora-item:hover {
background: rgba(66, 153, 225, 0.15);
}
.lora-item.active {
background: rgba(66, 153, 225, 0.25);
border-left-color: rgba(66, 153, 225, 0.8);
}
.lora-index {
font-family: 'SF Mono', 'Roboto Mono', monospace;
font-size: 12px;
color: rgba(226, 232, 240, 0.5);
min-width: 3ch;
text-align: right;
font-variant-numeric: tabular-nums;
}
.lora-name {
flex: 1;
font-size: 13px;
color: var(--fg-color, #fff);
overflow: hidden;
text-overflow: ellipsis;
white-space: nowrap;
}
.current-badge {
font-size: 11px;
padding: 2px 8px;
background: rgba(66, 153, 225, 0.3);
border: 1px solid rgba(66, 153, 225, 0.5);
border-radius: 4px;
color: rgba(191, 219, 254, 1);
font-weight: 500;
}
.no-results {
padding: 32px 20px;
text-align: center;
color: var(--fg-color, #fff);
opacity: 0.5;
font-size: 13px;
}
</style>


@@ -81,7 +81,7 @@ watch(() => props.visible, (isVisible) => {
.lora-pool-modal-backdrop {
position: fixed;
inset: 0;
z-index: 9998;
background: rgba(0, 0, 0, 0.6);
display: flex;
align-items: center;


@@ -206,7 +206,9 @@ const stepToDecimals = (step: number): number => {
const snapToStep = (value: number, segmentMultiplier?: number): number => {
const effectiveStep = segmentMultiplier ? props.step * segmentMultiplier : props.step
const steps = Math.round((value - props.min) / effectiveStep)
const rawValue = Math.max(props.min, Math.min(props.max, props.min + steps * effectiveStep))
// Fix floating point precision issues, limit to 2 decimal places
return Math.round(rawValue * 100) / 100
}
const startDrag = (handle: 'min' | 'max', event: PointerEvent) => {
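The rounding added in this hunk works around binary floating point drift: in IEEE-754 arithmetic, `0 + 3 * 0.1` evaluates to `0.30000000000000004`, not `0.3`. A self-contained version of the fixed snapping logic, with the component props inlined as plain parameters for illustration:

```typescript
// Snap a value to the nearest step, clamp to [min, max], then round to
// 2 decimal places to cancel drift from `min + steps * step`.
function snapToStep(value: number, min: number, max: number, step: number): number {
  const steps = Math.round((value - min) / step)
  const rawValue = Math.max(min, Math.min(max, min + steps * step))
  return Math.round(rawValue * 100) / 100
}
```

Without the final rounding, `snapToStep(0.32, 0, 1, 0.1)` would return `0.30000000000000004`, and that noisy value would end up in the serialized widget config.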


@@ -82,7 +82,9 @@ const stepToDecimals = (step: number): number => {
const snapToStep = (value: number): number => {
const steps = Math.round((value - props.min) / props.step)
const rawValue = Math.max(props.min, Math.min(props.max, props.min + steps * props.step))
// Fix floating point precision issues, limit to 2 decimal places
return Math.round(rawValue * 100) / 100
}
const startDrag = (event: PointerEvent) => {


@@ -80,6 +80,10 @@ export interface CyclerConfig {
// Dual-index mechanism for batch queue synchronization
execution_index?: number | null // Index to use for current execution
next_index?: number | null // Index for display after execution
// Advanced index control features
repeat_count: number // How many times each LoRA should repeat (default: 1)
repeat_used: number // How many times current index has been used
is_paused: boolean // Whether iteration is paused
}
// Widget config union type


@@ -29,6 +29,16 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
const executionIndex = ref<number | null>(null)
const nextIndex = ref<number | null>(null)
// Advanced index control features
const repeatCount = ref(1) // How many times each LoRA should repeat
const repeatUsed = ref(0) // How many times current index has been used (internal tracking)
const displayRepeatUsed = ref(0) // For UI display, deferred updates like currentIndex
const isPaused = ref(false) // Whether iteration is paused
// Execution progress tracking (visual feedback)
const isWorkflowExecuting = ref(false) // Workflow is currently running
const executingRepeatStep = ref(0) // Which repeat step (1-based, 0 = not executing)
// Build config object from current state
const buildConfig = (): CyclerConfig => {
// Skip updating widget.value during restoration to prevent infinite loops
@@ -45,6 +55,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
current_lora_filename: currentLoraFilename.value,
execution_index: executionIndex.value,
next_index: nextIndex.value,
repeat_count: repeatCount.value,
repeat_used: repeatUsed.value,
is_paused: isPaused.value,
}
}
return {
@@ -59,6 +72,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
current_lora_filename: currentLoraFilename.value,
execution_index: executionIndex.value,
next_index: nextIndex.value,
repeat_count: repeatCount.value,
repeat_used: repeatUsed.value,
is_paused: isPaused.value,
}
}
@@ -77,6 +93,10 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
sortBy.value = config.sort_by || 'filename'
currentLoraName.value = config.current_lora_name || ''
currentLoraFilename.value = config.current_lora_filename || ''
// Advanced index control features
repeatCount.value = config.repeat_count ?? 1
repeatUsed.value = config.repeat_used ?? 0
isPaused.value = config.is_paused ?? false
// Note: execution_index and next_index are not restored from config
// as they are transient values used only during batch execution
} finally {
@@ -215,6 +235,19 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
}
}
// Reset index to 1 and clear repeat state
const resetIndex = () => {
currentIndex.value = 1
repeatUsed.value = 0
displayRepeatUsed.value = 0
// Note: isPaused is intentionally not reset - user may want to stay paused after reset
}
// Toggle pause state
const togglePause = () => {
isPaused.value = !isPaused.value
}
// Computed property to check if clip strength is disabled
const isClipStrengthDisabled = computed(() => !useCustomClipRange.value)
@@ -236,6 +269,9 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
sortBy,
currentLoraName,
currentLoraFilename,
repeatCount,
repeatUsed,
isPaused,
], () => {
widget.value = buildConfig()
}, { deep: true })
@@ -254,6 +290,12 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
isLoading,
executionIndex,
nextIndex,
repeatCount,
repeatUsed,
displayRepeatUsed,
isPaused,
isWorkflowExecuting,
executingRepeatStep,
// Computed
isClipStrengthDisabled,
@@ -267,5 +309,7 @@ export function useLoraCyclerState(widget: ComponentWidget<CyclerConfig>) {
setIndex,
generateNextIndex,
initializeNextIndex,
resetIndex,
togglePause,
}
}


@@ -27,6 +27,8 @@ const AUTOCOMPLETE_TEXT_WIDGET_MAX_HEIGHT = 100
// @ts-ignore - ComfyUI external module
import { app } from '../../../scripts/app.js'
// @ts-ignore - ComfyUI external module
import { api } from '../../../scripts/api.js'
// @ts-ignore
import { getPoolConfigFromConnectedNode, getActiveLorasFromNode, updateConnectedTriggerWords, updateDownstreamLoaders } from '../../web/comfyui/utils.js'
@@ -255,7 +257,8 @@ function createLoraCyclerWidget(node) {
const vueApp = createApp(LoraCyclerWidget, {
widget,
node,
api
})
vueApp.use(PrimeVue, {


@@ -0,0 +1,634 @@
/**
* Unit tests for useLoraCyclerState composable
*
* Tests pure state transitions and index calculations in isolation.
*/
import { describe, it, expect, beforeEach, vi } from 'vitest'
import { useLoraCyclerState } from '@/composables/useLoraCyclerState'
import {
createMockWidget,
createMockCyclerConfig,
createMockPoolConfig
} from '../fixtures/mockConfigs'
import { setupFetchMock, resetFetchMock } from '../setup'
describe('useLoraCyclerState', () => {
beforeEach(() => {
resetFetchMock()
})
describe('Initial State', () => {
it('should initialize with default values', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
expect(state.currentIndex.value).toBe(1)
expect(state.totalCount.value).toBe(0)
expect(state.poolConfigHash.value).toBe('')
expect(state.modelStrength.value).toBe(1.0)
expect(state.clipStrength.value).toBe(1.0)
expect(state.useCustomClipRange.value).toBe(false)
expect(state.sortBy.value).toBe('filename')
expect(state.executionIndex.value).toBeNull()
expect(state.nextIndex.value).toBeNull()
expect(state.repeatCount.value).toBe(1)
expect(state.repeatUsed.value).toBe(0)
expect(state.displayRepeatUsed.value).toBe(0)
expect(state.isPaused.value).toBe(false)
})
})
describe('restoreFromConfig', () => {
it('should restore state from config object', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const config = createMockCyclerConfig({
current_index: 3,
total_count: 10,
model_strength: 0.8,
clip_strength: 0.6,
use_same_clip_strength: false,
repeat_count: 2,
repeat_used: 1,
is_paused: true
})
state.restoreFromConfig(config)
expect(state.currentIndex.value).toBe(3)
expect(state.totalCount.value).toBe(10)
expect(state.modelStrength.value).toBe(0.8)
expect(state.clipStrength.value).toBe(0.6)
expect(state.useCustomClipRange.value).toBe(true) // inverted from use_same_clip_strength
expect(state.repeatCount.value).toBe(2)
expect(state.repeatUsed.value).toBe(1)
expect(state.isPaused.value).toBe(true)
})
it('should handle missing optional fields with defaults', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
// Minimal config
state.restoreFromConfig({
current_index: 5,
total_count: 10,
pool_config_hash: '',
model_strength: 1.0,
clip_strength: 1.0,
use_same_clip_strength: true,
sort_by: 'filename',
current_lora_name: '',
current_lora_filename: '',
repeat_count: 1,
repeat_used: 0,
is_paused: false
})
expect(state.currentIndex.value).toBe(5)
expect(state.repeatCount.value).toBe(1)
expect(state.isPaused.value).toBe(false)
})
it('should not restore execution_index and next_index (transient values)', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
// Set execution indices
state.executionIndex.value = 2
state.nextIndex.value = 3
// Restore from config (these fields in config should be ignored)
state.restoreFromConfig(createMockCyclerConfig({
execution_index: 5,
next_index: 6
}))
// Execution indices should remain unchanged
expect(state.executionIndex.value).toBe(2)
expect(state.nextIndex.value).toBe(3)
})
})
describe('buildConfig', () => {
it('should build config object from current state', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.currentIndex.value = 3
state.totalCount.value = 10
state.modelStrength.value = 0.8
state.repeatCount.value = 2
state.repeatUsed.value = 1
state.isPaused.value = true
const config = state.buildConfig()
expect(config.current_index).toBe(3)
expect(config.total_count).toBe(10)
expect(config.model_strength).toBe(0.8)
expect(config.repeat_count).toBe(2)
expect(config.repeat_used).toBe(1)
expect(config.is_paused).toBe(true)
})
})
describe('setIndex', () => {
it('should set index within valid range', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 10
state.setIndex(5)
expect(state.currentIndex.value).toBe(5)
state.setIndex(1)
expect(state.currentIndex.value).toBe(1)
state.setIndex(10)
expect(state.currentIndex.value).toBe(10)
})
it('should not set index outside valid range', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 10
state.currentIndex.value = 5
state.setIndex(0)
expect(state.currentIndex.value).toBe(5) // unchanged
state.setIndex(11)
expect(state.currentIndex.value).toBe(5) // unchanged
state.setIndex(-1)
expect(state.currentIndex.value).toBe(5) // unchanged
})
})
describe('resetIndex', () => {
it('should reset index to 1 and clear repeatUsed and displayRepeatUsed', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.currentIndex.value = 5
state.repeatUsed.value = 2
state.displayRepeatUsed.value = 2
state.isPaused.value = true
state.resetIndex()
expect(state.currentIndex.value).toBe(1)
expect(state.repeatUsed.value).toBe(0)
expect(state.displayRepeatUsed.value).toBe(0)
expect(state.isPaused.value).toBe(true) // isPaused should NOT be reset
})
})
describe('togglePause', () => {
it('should toggle pause state', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
expect(state.isPaused.value).toBe(false)
state.togglePause()
expect(state.isPaused.value).toBe(true)
state.togglePause()
expect(state.isPaused.value).toBe(false)
})
})
describe('generateNextIndex', () => {
it('should shift indices correctly', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 5
state.currentIndex.value = 1
state.nextIndex.value = 2
// First call: executionIndex becomes 2 (previous nextIndex), nextIndex becomes 3
state.generateNextIndex()
expect(state.executionIndex.value).toBe(2)
expect(state.nextIndex.value).toBe(3)
// Second call: executionIndex becomes 3, nextIndex becomes 4
state.generateNextIndex()
expect(state.executionIndex.value).toBe(3)
expect(state.nextIndex.value).toBe(4)
})
it('should wrap index from totalCount to 1', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 5
state.nextIndex.value = 5 // At the last index
state.generateNextIndex()
expect(state.executionIndex.value).toBe(5)
expect(state.nextIndex.value).toBe(1) // Wrapped to 1
})
it('should use currentIndex when nextIndex is null', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 5
state.currentIndex.value = 3
state.nextIndex.value = null
state.generateNextIndex()
// executionIndex becomes previous nextIndex (null)
expect(state.executionIndex.value).toBeNull()
// nextIndex is calculated from currentIndex (3) -> 4
expect(state.nextIndex.value).toBe(4)
})
})
describe('initializeNextIndex', () => {
it('should initialize nextIndex to currentIndex + 1 when null', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 5
state.currentIndex.value = 1
state.nextIndex.value = null
state.initializeNextIndex()
expect(state.nextIndex.value).toBe(2)
})
it('should wrap nextIndex when currentIndex is at totalCount', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 5
state.currentIndex.value = 5
state.nextIndex.value = null
state.initializeNextIndex()
expect(state.nextIndex.value).toBe(1) // Wrapped
})
it('should not change nextIndex if already set', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 5
state.currentIndex.value = 1
state.nextIndex.value = 4
state.initializeNextIndex()
expect(state.nextIndex.value).toBe(4) // Unchanged
})
})
describe('Index Wrapping Edge Cases', () => {
it('should handle single item pool', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 1
state.currentIndex.value = 1
state.nextIndex.value = null
state.initializeNextIndex()
expect(state.nextIndex.value).toBe(1) // Wraps back to 1
})
it('should handle zero total count gracefully', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.totalCount.value = 0
state.currentIndex.value = 1
state.nextIndex.value = null
state.initializeNextIndex()
// Should still calculate, even if totalCount is 0
expect(state.nextIndex.value).toBe(2) // No wrapping since totalCount <= 0
})
})
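The edge cases above (wrap at `totalCount`, single-item pool, `totalCount = 0`) all reduce to one wrapping rule. A standalone sketch distilled from these assertions — the function name is ours, not part of the codebase:

```typescript
// Hypothetical distillation of the 1-based index wrapping rule the tests pin down.
// Assumes: a non-positive totalCount disables wrapping entirely.
function wrapNext(current: number, totalCount: number): number {
  if (totalCount > 0 && current >= totalCount) {
    return 1 // at (or past) the last item: wrap back to the first
  }
  return current + 1 // otherwise simply advance
}
```

Under this rule `wrapNext(5, 5)` yields 1, `wrapNext(1, 1)` stays at 1, and `wrapNext(1, 0)` yields 2 with no wrapping, matching the expectations in the three tests above.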
describe('hashPoolConfig', () => {
it('should generate consistent hash for same config', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const config1 = createMockPoolConfig()
const config2 = createMockPoolConfig()
const hash1 = state.hashPoolConfig(config1)
const hash2 = state.hashPoolConfig(config2)
expect(hash1).toBe(hash2)
})
it('should generate different hash for different configs', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const config1 = createMockPoolConfig({
filters: {
baseModels: ['SD 1.5'],
tags: { include: [], exclude: [] },
folders: { include: [], exclude: [] },
license: { noCreditRequired: false, allowSelling: false }
}
})
const config2 = createMockPoolConfig({
filters: {
baseModels: ['SDXL'],
tags: { include: [], exclude: [] },
folders: { include: [], exclude: [] },
license: { noCreditRequired: false, allowSelling: false }
}
})
const hash1 = state.hashPoolConfig(config1)
const hash2 = state.hashPoolConfig(config2)
expect(hash1).not.toBe(hash2)
})
it('should return empty string for null config', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
expect(state.hashPoolConfig(null)).toBe('')
})
it('should return empty string for config without filters', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const config = { version: 1, preview: { matchCount: 0, lastUpdated: 0 } } as any
expect(state.hashPoolConfig(config)).toBe('')
})
})
describe('Clip Strength Synchronization', () => {
it('should sync clipStrength with modelStrength when useCustomClipRange is false', async () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.useCustomClipRange.value = false
state.modelStrength.value = 0.5
// Wait for Vue reactivity
await vi.waitFor(() => {
expect(state.clipStrength.value).toBe(0.5)
})
})
it('should not sync clipStrength when useCustomClipRange is true', async () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.useCustomClipRange.value = true
state.clipStrength.value = 0.7
state.modelStrength.value = 0.5
// clipStrength should remain unchanged
await vi.waitFor(() => {
expect(state.clipStrength.value).toBe(0.7)
})
})
})
describe('Widget Value Synchronization', () => {
it('should update widget.value when state changes', async () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.currentIndex.value = 3
state.repeatCount.value = 2
// Wait for Vue reactivity
await vi.waitFor(() => {
expect(widget.value?.current_index).toBe(3)
expect(widget.value?.repeat_count).toBe(2)
})
})
})
describe('Repeat Logic State', () => {
it('should track repeatUsed correctly', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.repeatCount.value = 3
expect(state.repeatUsed.value).toBe(0)
state.repeatUsed.value = 1
expect(state.repeatUsed.value).toBe(1)
state.repeatUsed.value = 3
expect(state.repeatUsed.value).toBe(3)
})
})
describe('fetchCyclerList', () => {
it('should call API and return lora list', async () => {
const mockLoras = [
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' }
]
setupFetchMock({ success: true, loras: mockLoras })
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const result = await state.fetchCyclerList(null)
expect(result).toEqual(mockLoras)
expect(state.isLoading.value).toBe(false)
})
it('should include pool config filters in request', async () => {
const mockFetch = setupFetchMock({ success: true, loras: [] })
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const poolConfig = createMockPoolConfig()
await state.fetchCyclerList(poolConfig)
expect(mockFetch).toHaveBeenCalledWith(
'/api/lm/loras/cycler-list',
expect.objectContaining({
method: 'POST',
body: expect.stringContaining('pool_config')
})
)
})
it('should set isLoading during fetch', async () => {
let resolvePromise: (value: unknown) => void
const pendingPromise = new Promise(resolve => {
resolvePromise = resolve
})
// Use mockFetch from setup instead of overriding global
const { mockFetch } = await import('../setup')
mockFetch.mockReset()
mockFetch.mockReturnValue(pendingPromise)
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
const fetchPromise = state.fetchCyclerList(null)
expect(state.isLoading.value).toBe(true)
// Resolve the fetch
resolvePromise!({
ok: true,
json: () => Promise.resolve({ success: true, loras: [] })
})
await fetchPromise
expect(state.isLoading.value).toBe(false)
})
})
describe('refreshList', () => {
it('should update totalCount from API response', async () => {
const mockLoras = [
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' },
{ file_name: 'lora3.safetensors', model_name: 'LoRA 3' }
]
// Reset and setup fresh mock
resetFetchMock()
setupFetchMock({ success: true, loras: mockLoras })
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
await state.refreshList(null)
expect(state.totalCount.value).toBe(3)
})
it('should reset index to 1 when pool config hash changes', async () => {
resetFetchMock()
setupFetchMock({ success: true, loras: [{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' }] })
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
// Set initial state
state.currentIndex.value = 5
state.poolConfigHash.value = 'old-hash'
// Refresh with new config (different hash)
const newConfig = createMockPoolConfig({
filters: {
baseModels: ['SDXL'],
tags: { include: [], exclude: [] },
folders: { include: [], exclude: [] },
license: { noCreditRequired: false, allowSelling: false }
}
})
await state.refreshList(newConfig)
expect(state.currentIndex.value).toBe(1)
})
it('should clamp index when totalCount decreases', async () => {
// Setup mock first, then create state
resetFetchMock()
setupFetchMock({
success: true,
loras: [
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' },
{ file_name: 'lora3.safetensors', model_name: 'LoRA 3' }
]
})
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
// Set initial state with high index
state.currentIndex.value = 10
state.totalCount.value = 10
await state.refreshList(null)
expect(state.totalCount.value).toBe(3)
expect(state.currentIndex.value).toBe(3) // Clamped to max
})
it('should update currentLoraName and currentLoraFilename', async () => {
resetFetchMock()
setupFetchMock({
success: true,
loras: [
{ file_name: 'lora1.safetensors', model_name: 'LoRA 1' },
{ file_name: 'lora2.safetensors', model_name: 'LoRA 2' }
]
})
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
// Set totalCount first, then point currentIndex at the second entry
state.totalCount.value = 2
state.currentIndex.value = 2
await state.refreshList(null)
expect(state.currentLoraFilename.value).toBe('lora2.safetensors')
})
it('should handle empty list gracefully', async () => {
resetFetchMock()
setupFetchMock({ success: true, loras: [] })
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.currentIndex.value = 5
state.totalCount.value = 5
await state.refreshList(null)
expect(state.totalCount.value).toBe(0)
// With an empty list the index is clamped to Math.max(1, totalCount), which is 1 when totalCount is 0
expect(state.currentIndex.value).toBe(1)
expect(state.currentLoraName.value).toBe('')
expect(state.currentLoraFilename.value).toBe('')
})
})
describe('isClipStrengthDisabled computed', () => {
it('should return true when useCustomClipRange is false', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.useCustomClipRange.value = false
expect(state.isClipStrengthDisabled.value).toBe(true)
})
it('should return false when useCustomClipRange is true', () => {
const widget = createMockWidget()
const state = useLoraCyclerState(widget)
state.useCustomClipRange.value = true
expect(state.isClipStrengthDisabled.value).toBe(false)
})
})
})


@@ -0,0 +1,175 @@
/**
* Test fixtures for LoRA Cycler testing
*/
import type { CyclerConfig, LoraPoolConfig } from '@/composables/types'
import type { CyclerLoraItem } from '@/composables/useLoraCyclerState'
/**
* Creates a default CyclerConfig for testing
*/
export function createMockCyclerConfig(overrides: Partial<CyclerConfig> = {}): CyclerConfig {
return {
current_index: 1,
total_count: 5,
pool_config_hash: '',
model_strength: 1.0,
clip_strength: 1.0,
use_same_clip_strength: true,
sort_by: 'filename',
current_lora_name: 'lora1.safetensors',
current_lora_filename: 'lora1.safetensors',
execution_index: null,
next_index: null,
repeat_count: 1,
repeat_used: 0,
is_paused: false,
...overrides
}
}
/**
* Creates a mock LoraPoolConfig for testing
*/
export function createMockPoolConfig(overrides: Partial<LoraPoolConfig> = {}): LoraPoolConfig {
return {
version: 1,
filters: {
baseModels: ['SD 1.5'],
tags: { include: [], exclude: [] },
folders: { include: [], exclude: [] },
license: {
noCreditRequired: false,
allowSelling: false
}
},
preview: { matchCount: 10, lastUpdated: Date.now() },
...overrides
}
}
/**
* Creates a list of mock LoRA items for testing
*/
export function createMockLoraList(count: number = 5): CyclerLoraItem[] {
return Array.from({ length: count }, (_, i) => ({
file_name: `lora${i + 1}.safetensors`,
model_name: `LoRA Model ${i + 1}`
}))
}
/**
* Creates a mock widget object for testing useLoraCyclerState
*/
export function createMockWidget(initialValue?: CyclerConfig) {
return {
value: initialValue,
callback: undefined as ((v: CyclerConfig) => void) | undefined
}
}
/**
* Creates a mock node object for testing component integration
*/
export function createMockNode(options: {
id?: number
poolConfig?: LoraPoolConfig | null
} = {}) {
const { id = 1, poolConfig = null } = options
return {
id,
inputs: [],
widgets: [],
graph: null,
getPoolConfig: () => poolConfig,
onExecuted: undefined as ((output: unknown) => void) | undefined
}
}
/**
* Creates mock execution output from the backend
*/
export function createMockExecutionOutput(options: {
nextIndex?: number
totalCount?: number
nextLoraName?: string
nextLoraFilename?: string
currentLoraName?: string
currentLoraFilename?: string
} = {}) {
const {
nextIndex = 2,
totalCount = 5,
nextLoraName = 'lora2.safetensors',
nextLoraFilename = 'lora2.safetensors',
currentLoraName = 'lora1.safetensors',
currentLoraFilename = 'lora1.safetensors'
} = options
return {
next_index: [nextIndex],
total_count: [totalCount],
next_lora_name: [nextLoraName],
next_lora_filename: [nextLoraFilename],
current_lora_name: [currentLoraName],
current_lora_filename: [currentLoraFilename]
}
}
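Note that every field in this mock output is a single-element array, mirroring how the backend emits node outputs; consumers unwrap each field with an `Array.isArray` check before use. A hypothetical standalone helper capturing that pattern (the name is ours):

```typescript
// Unwraps a backend output field that may arrive as a single-element array.
// Hypothetical helper; the real consumers inline this Array.isArray check.
function unwrapOutput<T>(value: T | T[]): T {
  return Array.isArray(value) ? value[0] : (value as T)
}
```

With it, `unwrapOutput(output.next_index)` reads the same whether the field came wrapped (`[2]`) or bare (`2`).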
/**
* Sample LoRA lists for specific test scenarios
*/
export const SAMPLE_LORA_LISTS = {
// 3 LoRAs for simple cycling tests
small: createMockLoraList(3),
// 5 LoRAs for standard tests
medium: createMockLoraList(5),
// 10 LoRAs for larger tests
large: createMockLoraList(10),
// Empty list for edge case testing
empty: [] as CyclerLoraItem[],
// Single LoRA for edge case testing
single: createMockLoraList(1)
}
/**
* Sample pool configs for testing
*/
export const SAMPLE_POOL_CONFIGS = {
// Default SD 1.5 filter
sd15: createMockPoolConfig({
filters: {
baseModels: ['SD 1.5'],
tags: { include: [], exclude: [] },
folders: { include: [], exclude: [] },
license: { noCreditRequired: false, allowSelling: false }
}
}),
// SDXL filter
sdxl: createMockPoolConfig({
filters: {
baseModels: ['SDXL'],
tags: { include: [], exclude: [] },
folders: { include: [], exclude: [] },
license: { noCreditRequired: false, allowSelling: false }
}
}),
// Filter with tags
withTags: createMockPoolConfig({
filters: {
baseModels: ['SD 1.5'],
tags: { include: ['anime', 'style'], exclude: ['realistic'] },
folders: { include: [], exclude: [] },
license: { noCreditRequired: false, allowSelling: false }
}
}),
// Empty/null config
empty: null as LoraPoolConfig | null
}


@@ -0,0 +1,885 @@
/**
* Integration tests for batch queue execution scenarios
*
* These tests simulate ComfyUI's execution modes to verify correct LoRA cycling behavior.
*/
import { describe, it, expect, beforeEach, vi } from 'vitest'
import { useLoraCyclerState } from '@/composables/useLoraCyclerState'
import type { CyclerConfig } from '@/composables/types'
import {
createMockWidget,
createMockCyclerConfig,
createMockLoraList,
createMockPoolConfig
} from '../fixtures/mockConfigs'
import { setupFetchMock, resetFetchMock } from '../setup'
import { BatchQueueSimulator, IndexTracker } from '../utils/BatchQueueSimulator'
/**
* Creates a test harness that mimics the LoraCyclerWidget's behavior
*/
function createTestHarness(options: {
totalCount?: number
initialIndex?: number
repeatCount?: number
isPaused?: boolean
} = {}) {
const {
totalCount = 5,
initialIndex = 1,
repeatCount = 1,
isPaused = false
} = options
const widget = createMockWidget() as any
const state = useLoraCyclerState(widget)
// Initialize state
state.totalCount.value = totalCount
state.currentIndex.value = initialIndex
state.repeatCount.value = repeatCount
state.isPaused.value = isPaused
// Track if first execution
const HAS_EXECUTED = Symbol('HAS_EXECUTED')
widget[HAS_EXECUTED] = false
// Execution queue for batch synchronization
interface ExecutionContext {
isPaused: boolean
repeatUsed: number
repeatCount: number
shouldAdvanceDisplay: boolean
displayRepeatUsed: number // Value to show in UI after completion
}
const executionQueue: ExecutionContext[] = []
// beforeQueued hook (mirrors LoraCyclerWidget.vue logic)
widget.beforeQueued = () => {
if (state.isPaused.value) {
executionQueue.push({
isPaused: true,
repeatUsed: state.repeatUsed.value,
repeatCount: state.repeatCount.value,
shouldAdvanceDisplay: false,
displayRepeatUsed: state.displayRepeatUsed.value // Keep current display value when paused
})
// CRITICAL: Clear execution_index when paused to force backend to use current_index
const pausedConfig = state.buildConfig()
pausedConfig.execution_index = null
widget.value = pausedConfig
return
}
if (widget[HAS_EXECUTED]) {
if (state.repeatUsed.value < state.repeatCount.value) {
state.repeatUsed.value++
} else {
state.repeatUsed.value = 1
state.generateNextIndex()
}
} else {
state.repeatUsed.value = 1
state.initializeNextIndex()
widget[HAS_EXECUTED] = true
}
const shouldAdvanceDisplay = state.repeatUsed.value >= state.repeatCount.value
// Calculate the display value to show after this execution completes
// When advancing to a new LoRA: reset to 0 (fresh start for new LoRA)
// When repeating same LoRA: show current repeat step
const displayRepeatUsed = shouldAdvanceDisplay ? 0 : state.repeatUsed.value
executionQueue.push({
isPaused: false,
repeatUsed: state.repeatUsed.value,
repeatCount: state.repeatCount.value,
shouldAdvanceDisplay,
displayRepeatUsed
})
widget.value = state.buildConfig()
}
// Mock node with onExecuted
const node = {
id: 1,
onExecuted: (output: any) => {
const context = executionQueue.shift()
const shouldAdvanceDisplay = context
? context.shouldAdvanceDisplay
: (!state.isPaused.value && state.repeatUsed.value >= state.repeatCount.value)
// Update displayRepeatUsed (deferred like index updates)
if (context) {
state.displayRepeatUsed.value = context.displayRepeatUsed
}
if (shouldAdvanceDisplay && output?.next_index !== undefined) {
const val = Array.isArray(output.next_index) ? output.next_index[0] : output.next_index
state.currentIndex.value = val
}
if (output?.total_count !== undefined) {
const val = Array.isArray(output.total_count) ? output.total_count[0] : output.total_count
state.totalCount.value = val
}
if (shouldAdvanceDisplay) {
if (output?.next_lora_name !== undefined) {
const val = Array.isArray(output.next_lora_name) ? output.next_lora_name[0] : output.next_lora_name
state.currentLoraName.value = val
}
if (output?.next_lora_filename !== undefined) {
const val = Array.isArray(output.next_lora_filename) ? output.next_lora_filename[0] : output.next_lora_filename
state.currentLoraFilename.value = val
}
}
}
}
// Reset execution state (mimics manual index change)
const resetExecutionState = () => {
widget[HAS_EXECUTED] = false
state.executionIndex.value = null
state.nextIndex.value = null
executionQueue.length = 0
}
return {
widget,
state,
node,
executionQueue,
resetExecutionState,
getConfig: () => state.buildConfig(),
HAS_EXECUTED
}
}
describe('Batch Queue Integration Tests', () => {
beforeEach(() => {
resetFetchMock()
})
describe('Basic Cycling', () => {
it('should cycle through N LoRAs in batch of N (batch queue mode)', async () => {
const harness = createTestHarness({ totalCount: 3 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
// Simulate batch queue of 3 prompts
await simulator.runBatchQueue(
3,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// After cycling through all 3, currentIndex should wrap back to 1
// First execution: index 1, next becomes 2
// Second execution: index 2, next becomes 3
// Third execution: index 3, next becomes 1
expect(harness.state.currentIndex.value).toBe(1)
})
it('should cycle through N LoRAs in batch of N (sequential mode)', async () => {
const harness = createTestHarness({ totalCount: 3 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
// Simulate sequential execution of 3 prompts
await simulator.runSequential(
3,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Same result as batch mode
expect(harness.state.currentIndex.value).toBe(1)
})
it('should handle partial cycle (batch of 2 in pool of 5)', async () => {
const harness = createTestHarness({ totalCount: 5, initialIndex: 1 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
await simulator.runBatchQueue(
2,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// After 2 executions starting from 1: 1 -> 2 -> 3
expect(harness.state.currentIndex.value).toBe(3)
})
})
describe('Repeat Functionality', () => {
it('should repeat each LoRA repeatCount times', async () => {
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
// With repeatCount=2, need 6 executions to cycle through 3 LoRAs
await simulator.runBatchQueue(
6,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Should have cycled back to beginning
expect(harness.state.currentIndex.value).toBe(1)
})
it('should track repeatUsed correctly during batch', async () => {
const harness = createTestHarness({ totalCount: 3, repeatCount: 3 })
// First beforeQueued: repeatUsed = 1
harness.widget.beforeQueued()
expect(harness.state.repeatUsed.value).toBe(1)
// Second beforeQueued: repeatUsed = 2
harness.widget.beforeQueued()
expect(harness.state.repeatUsed.value).toBe(2)
// Third beforeQueued: repeatUsed = 3 (will advance on next)
harness.widget.beforeQueued()
expect(harness.state.repeatUsed.value).toBe(3)
// Fourth beforeQueued: repeatUsed resets to 1, index advances
harness.widget.beforeQueued()
expect(harness.state.repeatUsed.value).toBe(1)
expect(harness.state.nextIndex.value).toBe(3) // Advanced from 2 to 3
})
it('should not advance display until repeat cycle completes', async () => {
const harness = createTestHarness({ totalCount: 5, repeatCount: 2 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
// First execution: repeatUsed=1 < repeatCount=2, shouldAdvanceDisplay=false
// Second execution: repeatUsed=2 >= repeatCount=2, shouldAdvanceDisplay=true
const indexHistory: number[] = []
// Override onExecuted to track index changes
const originalOnExecuted = harness.node.onExecuted
harness.node.onExecuted = (output: any) => {
originalOnExecuted(output)
indexHistory.push(harness.state.currentIndex.value)
}
await simulator.runBatchQueue(
4,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Index should only change on 2nd and 4th execution
// Starting at 1: stay 1, advance to 2, stay 2, advance to 3
expect(indexHistory).toEqual([1, 2, 2, 3])
})
it('should defer displayRepeatUsed updates until workflow completion', async () => {
const harness = createTestHarness({ totalCount: 3, repeatCount: 3 })
// Initial state
expect(harness.state.displayRepeatUsed.value).toBe(0)
// Queue 3 executions in batch mode (all beforeQueued before any onExecuted)
harness.widget.beforeQueued() // repeatUsed = 1
harness.widget.beforeQueued() // repeatUsed = 2
harness.widget.beforeQueued() // repeatUsed = 3
// displayRepeatUsed should NOT have changed yet (still 0)
// because no onExecuted has been called
expect(harness.state.displayRepeatUsed.value).toBe(0)
// Now simulate workflow completions
harness.node.onExecuted({ next_index: 1 })
expect(harness.state.displayRepeatUsed.value).toBe(1)
harness.node.onExecuted({ next_index: 1 })
expect(harness.state.displayRepeatUsed.value).toBe(2)
harness.node.onExecuted({ next_index: 2 })
// After completing repeat cycle, displayRepeatUsed resets to 0
expect(harness.state.displayRepeatUsed.value).toBe(0)
})
it('should reset displayRepeatUsed to 0 when advancing to new LoRA', async () => {
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
const displayHistory: number[] = []
const originalOnExecuted = harness.node.onExecuted
harness.node.onExecuted = (output: any) => {
originalOnExecuted(output)
displayHistory.push(harness.state.displayRepeatUsed.value)
}
// Run 4 executions: 2 repeats of LoRA 1, 2 repeats of LoRA 2
await simulator.runBatchQueue(
4,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// displayRepeatUsed should show:
// 1st exec: 1 (first repeat of LoRA 1)
// 2nd exec: 0 (complete, reset for next LoRA)
// 3rd exec: 1 (first repeat of LoRA 2)
// 4th exec: 0 (complete, reset for next LoRA)
expect(displayHistory).toEqual([1, 0, 1, 0])
})
it('should show current repeat step when not advancing', async () => {
const harness = createTestHarness({ totalCount: 3, repeatCount: 4 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
const displayHistory: number[] = []
const originalOnExecuted = harness.node.onExecuted
harness.node.onExecuted = (output: any) => {
originalOnExecuted(output)
displayHistory.push(harness.state.displayRepeatUsed.value)
}
// Run 4 executions: all 4 repeats of the same LoRA
await simulator.runBatchQueue(
4,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// displayRepeatUsed should show:
// 1st exec: 1 (repeat 1/4, not advancing)
// 2nd exec: 2 (repeat 2/4, not advancing)
// 3rd exec: 3 (repeat 3/4, not advancing)
// 4th exec: 0 (repeat 4/4, complete, reset for next LoRA)
expect(displayHistory).toEqual([1, 2, 3, 0])
})
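The display histories these tests assert (`[1, 0, 1, 0]`, `[1, 2, 3, 0]`) can be modeled as a small pure function. A sketch under the assumption that only the repeat bookkeeping matters — index advancement and pause handling are omitted, and the function name is ours:

```typescript
// Models the displayRepeatUsed sequence across a batch of queued executions.
// Assumes: repeatUsed starts at 0; each run increments it, and the run that
// completes the repeat cycle resets the displayed counter to 0.
function simulateDisplayHistory(runs: number, repeatCount: number): number[] {
  let repeatUsed = 0
  const history: number[] = []
  for (let i = 0; i < runs; i++) {
    // Advance within the cycle, or start a fresh cycle once it is exhausted
    repeatUsed = repeatUsed < repeatCount ? repeatUsed + 1 : 1
    const advancing = repeatUsed >= repeatCount
    // A completed cycle shows 0 (fresh start for the next LoRA);
    // an in-progress cycle shows the current repeat step
    history.push(advancing ? 0 : repeatUsed)
  }
  return history
}
```

For example, `simulateDisplayHistory(4, 2)` produces `[1, 0, 1, 0]` and `simulateDisplayHistory(4, 4)` produces `[1, 2, 3, 0]`, matching the two tests above.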
})
describe('Pause Functionality', () => {
it('should maintain index when paused', async () => {
const harness = createTestHarness({ totalCount: 5, isPaused: true })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
await simulator.runBatchQueue(
3,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Index should not advance when paused
expect(harness.state.currentIndex.value).toBe(1)
})
it('should not count paused executions toward repeat limit', async () => {
const harness = createTestHarness({ totalCount: 5, repeatCount: 2 })
// Run 2 executions while paused
harness.state.isPaused.value = true
harness.widget.beforeQueued()
harness.widget.beforeQueued()
// repeatUsed should still be 0 (paused executions don't count)
expect(harness.state.repeatUsed.value).toBe(0)
// Unpause and run
harness.state.isPaused.value = false
harness.widget.beforeQueued()
expect(harness.state.repeatUsed.value).toBe(1)
})
it('should preserve displayRepeatUsed when paused', async () => {
const harness = createTestHarness({ totalCount: 5, repeatCount: 3 })
// Run one execution to set displayRepeatUsed
harness.widget.beforeQueued()
harness.node.onExecuted({ next_index: 1 })
expect(harness.state.displayRepeatUsed.value).toBe(1)
// Pause
harness.state.isPaused.value = true
// Queue and execute while paused
harness.widget.beforeQueued()
harness.node.onExecuted({ next_index: 1 })
// displayRepeatUsed should remain at 1 (paused executions don't change it)
expect(harness.state.displayRepeatUsed.value).toBe(1)
// Queue another paused execution
harness.widget.beforeQueued()
harness.node.onExecuted({ next_index: 1 })
// Still should be 1
expect(harness.state.displayRepeatUsed.value).toBe(1)
})
it('should use same LoRA when pause is toggled mid-batch', async () => {
// This tests the critical bug scenario:
// 1. User queues multiple prompts (not paused)
// 2. All beforeQueued calls complete, each advancing execution_index
// 3. User clicks pause
// 4. onExecuted starts firing - paused executions should use current_index, not execution_index
const harness = createTestHarness({ totalCount: 5 })
// Queue first prompt (not paused) - this sets up execution_index
harness.widget.beforeQueued()
const config1 = harness.getConfig()
expect(config1.execution_index).toBeNull() // First execution uses current_index
// User clicks pause mid-batch
harness.state.isPaused.value = true
// Queue subsequent prompts while paused
harness.widget.beforeQueued()
const config2 = harness.getConfig()
// CRITICAL: execution_index should be null when paused to force backend to use current_index
expect(config2.execution_index).toBeNull()
harness.widget.beforeQueued()
const config3 = harness.getConfig()
expect(config3.execution_index).toBeNull()
// Verify execution queue has correct context
expect(harness.executionQueue.length).toBe(3)
expect(harness.executionQueue[0].isPaused).toBe(false)
expect(harness.executionQueue[1].isPaused).toBe(true)
expect(harness.executionQueue[2].isPaused).toBe(true)
})
it('should have null execution_index in widget.value when paused even after non-paused queues', async () => {
// More detailed test for the execution_index clearing behavior
// This tests that widget.value (what backend receives) has null execution_index
const harness = createTestHarness({ totalCount: 5 })
// Queue 3 prompts while not paused
harness.widget.beforeQueued()
harness.widget.beforeQueued()
harness.widget.beforeQueued()
// Verify execution_index was set by non-paused queues in widget.value
expect(harness.widget.value.execution_index).not.toBeNull()
// User pauses
harness.state.isPaused.value = true
// Queue while paused - should clear execution_index in widget.value
// This is the value that gets sent to the backend
harness.widget.beforeQueued()
expect(harness.widget.value.execution_index).toBeNull()
// State's executionIndex may still have the old value (that's fine)
// What matters is widget.value which is what the backend uses
})
it('should have hasQueuedPrompts true when execution queue has items', async () => {
// This tests the pause button disabled state
const harness = createTestHarness({ totalCount: 5 })
// Initially no queued prompts
expect(harness.executionQueue.length).toBe(0)
// Queue some prompts
harness.widget.beforeQueued()
harness.widget.beforeQueued()
harness.widget.beforeQueued()
// Execution queue should have items
expect(harness.executionQueue.length).toBe(3)
})
it('should have empty execution queue after all executions complete', async () => {
// This tests that pause button becomes enabled after executions complete
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
// Run batch queue execution
await simulator.runBatchQueue(
3,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// After all executions, queue should be empty
expect(harness.executionQueue.length).toBe(0)
})
it('should resume cycling after unpause', async () => {
const harness = createTestHarness({ totalCount: 3, initialIndex: 2 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
// Execute once while not paused
await simulator.runSingle(
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Pause
harness.state.isPaused.value = true
// Execute twice while paused
await simulator.runBatchQueue(
2,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Unpause and execute
harness.state.isPaused.value = false
await simulator.runSingle(
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Should continue from where it left off (index 3 -> 1)
expect(harness.state.currentIndex.value).toBe(1)
})
})
describe('Manual Index Change', () => {
it('should reset execution state on manual index change', async () => {
const harness = createTestHarness({ totalCount: 5 })
// Execute a few times
harness.widget.beforeQueued()
harness.widget.beforeQueued()
expect(harness.widget[harness.HAS_EXECUTED]).toBe(true)
expect(harness.executionQueue.length).toBe(2)
// User manually changes index (mimics handleIndexUpdate)
harness.resetExecutionState()
harness.state.setIndex(4)
expect(harness.widget[harness.HAS_EXECUTED]).toBe(false)
expect(harness.state.executionIndex.value).toBeNull()
expect(harness.state.nextIndex.value).toBeNull()
expect(harness.executionQueue.length).toBe(0)
})
it('should start fresh cycle from manual index', async () => {
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
// Execute twice starting from 1
await simulator.runBatchQueue(
2,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
expect(harness.state.currentIndex.value).toBe(3)
// User manually sets index to 1
harness.resetExecutionState()
harness.state.setIndex(1)
// Execute again - should start fresh from 1
await simulator.runBatchQueue(
2,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
expect(harness.state.currentIndex.value).toBe(3)
})
})
describe('Execution Queue Mismatch', () => {
it('should handle interrupted execution (queue > executed)', async () => {
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
// Queue 5 but only execute 2 (simulates cancel)
await simulator.runInterrupted(
5, // queued
2, // executed
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// 3 contexts remain in queue
expect(harness.executionQueue.length).toBe(3)
// Index should reflect only the 2 executions that completed
expect(harness.state.currentIndex.value).toBe(3)
})
it('should recover from mismatch on next manual index change', async () => {
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
// Create mismatch
await simulator.runInterrupted(
5,
2,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
expect(harness.executionQueue.length).toBe(3)
// Manual index change clears queue
harness.resetExecutionState()
harness.state.setIndex(1)
expect(harness.executionQueue.length).toBe(0)
// Can execute normally again
await simulator.runSingle(
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
expect(harness.state.currentIndex.value).toBe(2)
})
})
describe('Edge Cases', () => {
it('should handle single item pool', async () => {
const harness = createTestHarness({ totalCount: 1 })
const simulator = new BatchQueueSimulator({ totalCount: 1 })
await simulator.runBatchQueue(
3,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Should always stay at index 1
expect(harness.state.currentIndex.value).toBe(1)
})
it('should handle empty pool gracefully', async () => {
const harness = createTestHarness({ totalCount: 0 })
// beforeQueued should still work without errors
expect(() => harness.widget.beforeQueued()).not.toThrow()
})
it('should handle rapid sequential executions', async () => {
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
// Run 20 sequential executions
await simulator.runSequential(
20,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
// Starting from index 1, 20 executions complete exactly 4 full cycles of 5,
// so the current index wraps back to 1: ((1 - 1 + 20) % 5) + 1 = 1
expect(harness.state.currentIndex.value).toBe(1)
})
it('should preserve state consistency across many cycles', async () => {
const harness = createTestHarness({ totalCount: 3, repeatCount: 2 })
const simulator = new BatchQueueSimulator({ totalCount: 3 })
// Run 100 executions in batches
for (let batch = 0; batch < 10; batch++) {
await simulator.runBatchQueue(
10,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
}
// Verify state is still valid
expect(harness.state.currentIndex.value).toBeGreaterThanOrEqual(1)
expect(harness.state.currentIndex.value).toBeLessThanOrEqual(3)
expect(harness.state.repeatUsed.value).toBeGreaterThanOrEqual(1)
expect(harness.state.repeatUsed.value).toBeLessThanOrEqual(2)
expect(harness.executionQueue.length).toBe(0)
})
})
describe('Invariant Assertions', () => {
it('should always have valid index (1 <= currentIndex <= totalCount)', async () => {
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
const checkInvariant = () => {
const { currentIndex, totalCount } = harness.state
if (totalCount.value > 0) {
expect(currentIndex.value).toBeGreaterThanOrEqual(1)
expect(currentIndex.value).toBeLessThanOrEqual(totalCount.value)
}
}
// Override onExecuted to check invariant after each execution
const originalOnExecuted = harness.node.onExecuted
harness.node.onExecuted = (output: any) => {
originalOnExecuted(output)
checkInvariant()
}
await simulator.runBatchQueue(
20,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
})
it('should always have repeatUsed <= repeatCount', async () => {
const harness = createTestHarness({ totalCount: 5, repeatCount: 3 })
const checkInvariant = () => {
expect(harness.state.repeatUsed.value).toBeLessThanOrEqual(harness.state.repeatCount.value)
}
// Check after each beforeQueued
for (let i = 0; i < 20; i++) {
harness.widget.beforeQueued()
checkInvariant()
}
})
it('should consume all execution contexts (queue empty after matching executions)', async () => {
const harness = createTestHarness({ totalCount: 5 })
const simulator = new BatchQueueSimulator({ totalCount: 5 })
await simulator.runBatchQueue(
7,
{
beforeQueued: () => harness.widget.beforeQueued(),
onExecuted: (output) => harness.node.onExecuted(output)
},
() => harness.getConfig()
)
expect(harness.executionQueue.length).toBe(0)
})
})
describe('Batch vs Sequential Mode Equivalence', () => {
it('should produce same final state in both modes (basic cycle)', async () => {
// Create two identical harnesses
const batchHarness = createTestHarness({ totalCount: 5 })
const seqHarness = createTestHarness({ totalCount: 5 })
const batchSimulator = new BatchQueueSimulator({ totalCount: 5 })
const seqSimulator = new BatchQueueSimulator({ totalCount: 5 })
// Run same number of executions in different modes
await batchSimulator.runBatchQueue(
7,
{
beforeQueued: () => batchHarness.widget.beforeQueued(),
onExecuted: (output) => batchHarness.node.onExecuted(output)
},
() => batchHarness.getConfig()
)
await seqSimulator.runSequential(
7,
{
beforeQueued: () => seqHarness.widget.beforeQueued(),
onExecuted: (output) => seqHarness.node.onExecuted(output)
},
() => seqHarness.getConfig()
)
// Final state should be identical
expect(batchHarness.state.currentIndex.value).toBe(seqHarness.state.currentIndex.value)
expect(batchHarness.state.repeatUsed.value).toBe(seqHarness.state.repeatUsed.value)
expect(batchHarness.state.displayRepeatUsed.value).toBe(seqHarness.state.displayRepeatUsed.value)
})
it('should produce same final state in both modes (with repeat)', async () => {
const batchHarness = createTestHarness({ totalCount: 3, repeatCount: 2 })
const seqHarness = createTestHarness({ totalCount: 3, repeatCount: 2 })
const batchSimulator = new BatchQueueSimulator({ totalCount: 3 })
const seqSimulator = new BatchQueueSimulator({ totalCount: 3 })
await batchSimulator.runBatchQueue(
10,
{
beforeQueued: () => batchHarness.widget.beforeQueued(),
onExecuted: (output) => batchHarness.node.onExecuted(output)
},
() => batchHarness.getConfig()
)
await seqSimulator.runSequential(
10,
{
beforeQueued: () => seqHarness.widget.beforeQueued(),
onExecuted: (output) => seqHarness.node.onExecuted(output)
},
() => seqHarness.getConfig()
)
expect(batchHarness.state.currentIndex.value).toBe(seqHarness.state.currentIndex.value)
expect(batchHarness.state.repeatUsed.value).toBe(seqHarness.state.repeatUsed.value)
expect(batchHarness.state.displayRepeatUsed.value).toBe(seqHarness.state.displayRepeatUsed.value)
})
})
})
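The wrap arithmetic these tests assert repeatedly (execute twice from 1 lands on 3; 20 executions over a pool of 5 land back on 1) has a closed form. A minimal sketch, assuming 1-based indices, a start index `s`, and the default `repeatCount` of 1 — the `indexAfter` helper is hypothetical, for illustration only, not part of the widget:

```typescript
// 1-based index the cycler lands on after n completed executions,
// starting at index s with `total` items in the pool.
function indexAfter(s: number, n: number, total: number): number {
  return ((s - 1 + n) % total) + 1
}

console.log(indexAfter(1, 2, 5))  // 3  (the "execute twice" tests)
console.log(indexAfter(1, 20, 5)) // 1  (the 20-execution wrap test)
```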

View File

@@ -0,0 +1,75 @@
/**
* Vitest test setup file
* Configures global mocks for ComfyUI modules and browser APIs
*/
import { vi } from 'vitest'
// Mock ComfyUI app module
vi.mock('../../../scripts/app.js', () => ({
app: {
graph: {
_nodes: []
},
registerExtension: vi.fn()
}
}))
// Mock ComfyUI loras_widget module
vi.mock('../loras_widget.js', () => ({
addLoraCard: vi.fn(),
removeLoraCard: vi.fn()
}))
// Mock ComfyUI autocomplete module
vi.mock('../autocomplete.js', () => ({
setupAutocomplete: vi.fn()
}))
// Global fetch mock - exported so tests can access it directly
export const mockFetch = vi.fn()
vi.stubGlobal('fetch', mockFetch)
// Helper to reset fetch mock between tests
export function resetFetchMock() {
mockFetch.mockReset()
// Re-stub global to ensure it's the same mock
vi.stubGlobal('fetch', mockFetch)
}
// Helper to setup fetch mock with default success response
export function setupFetchMock(response: unknown = { success: true, loras: [] }) {
// Ensure we're using the same mock
mockFetch.mockReset()
mockFetch.mockResolvedValue({
ok: true,
json: () => Promise.resolve(response)
})
vi.stubGlobal('fetch', mockFetch)
return mockFetch
}
// Helper to setup fetch mock with error response
export function setupFetchErrorMock(error: string = 'Network error') {
mockFetch.mockReset()
mockFetch.mockRejectedValue(new Error(error))
vi.stubGlobal('fetch', mockFetch)
return mockFetch
}
// Mock btoa for hashing (jsdom should have this, but just in case)
if (typeof global.btoa === 'undefined') {
vi.stubGlobal('btoa', (str: string) => Buffer.from(str).toString('base64'))
}
// Mock console methods to reduce noise in tests
vi.spyOn(console, 'log').mockImplementation(() => {})
vi.spyOn(console, 'error').mockImplementation(() => {})
vi.spyOn(console, 'warn').mockImplementation(() => {})
// Re-enable console for debugging when needed
export function enableConsole() {
vi.spyOn(console, 'log').mockRestore()
vi.spyOn(console, 'error').mockRestore()
vi.spyOn(console, 'warn').mockRestore()
}

View File

@@ -0,0 +1,230 @@
/**
* BatchQueueSimulator - Simulates ComfyUI's two execution modes
*
* ComfyUI has two distinct execution patterns:
* 1. Batch Queue Mode: ALL beforeQueued calls happen BEFORE any onExecuted calls
* 2. Sequential Mode: beforeQueued and onExecuted interleave for each prompt
*
* This simulator helps test how the widget behaves in both modes.
*/
import type { CyclerConfig } from '@/composables/types'
export interface ExecutionHooks {
/** Called when a prompt is queued (before execution) */
beforeQueued: () => void
/** Called when execution completes with output */
onExecuted: (output: unknown) => void
}
export interface SimulatorOptions {
/** Total number of LoRAs in the pool */
totalCount: number
/** Function to generate output for each execution */
generateOutput?: (executionIndex: number, config: CyclerConfig) => unknown
}
/**
* Creates execution output based on the current state
*/
function defaultGenerateOutput(executionIndex: number, config: CyclerConfig) {
// Calculate what the next index would be after this execution
let nextIdx = (config.execution_index ?? config.current_index) + 1
if (nextIdx > config.total_count) {
nextIdx = 1
}
return {
next_index: [nextIdx],
total_count: [config.total_count],
next_lora_name: [`lora${nextIdx}.safetensors`],
next_lora_filename: [`lora${nextIdx}.safetensors`],
current_lora_name: [`lora${config.execution_index ?? config.current_index}.safetensors`],
current_lora_filename: [`lora${config.execution_index ?? config.current_index}.safetensors`]
}
}
export class BatchQueueSimulator {
private executionCount = 0
private options: Required<SimulatorOptions>
constructor(options: SimulatorOptions) {
this.options = {
totalCount: options.totalCount,
generateOutput: options.generateOutput ?? defaultGenerateOutput
}
}
/**
* Reset the simulator state
*/
reset() {
this.executionCount = 0
}
/**
* Simulates Batch Queue Mode execution
*
* In this mode, ComfyUI queues multiple prompts at once:
* - ALL beforeQueued() calls happen first (for all prompts in the batch)
* - THEN all onExecuted() calls happen (as each prompt completes)
*
* This is the mode used when queueing multiple prompts from the UI.
*
* @param count Number of prompts to simulate
* @param hooks The widget's execution hooks
* @param getConfig Function to get current widget config state
*/
async runBatchQueue(
count: number,
hooks: ExecutionHooks,
getConfig: () => CyclerConfig
): Promise<void> {
// Phase 1: All beforeQueued calls (snapshot configs)
const snapshotConfigs: CyclerConfig[] = []
for (let i = 0; i < count; i++) {
hooks.beforeQueued()
// Snapshot the config after beforeQueued updates it
snapshotConfigs.push({ ...getConfig() })
}
// Phase 2: All onExecuted calls (in order)
for (let i = 0; i < count; i++) {
const config = snapshotConfigs[i]
const output = this.options.generateOutput(this.executionCount, config)
hooks.onExecuted(output)
this.executionCount++
}
}
/**
* Simulates Sequential Mode execution
*
* In this mode, execution is one-at-a-time:
* - beforeQueued() is called
* - onExecuted() is called
* - Then the next prompt's beforeQueued() is called
* - And so on...
*
* This is the mode used in API-driven execution or single prompt queuing.
*
* @param count Number of prompts to simulate
* @param hooks The widget's execution hooks
* @param getConfig Function to get current widget config state
*/
async runSequential(
count: number,
hooks: ExecutionHooks,
getConfig: () => CyclerConfig
): Promise<void> {
for (let i = 0; i < count; i++) {
// Queue the prompt
hooks.beforeQueued()
const config = { ...getConfig() }
// Execute it immediately
const output = this.options.generateOutput(this.executionCount, config)
hooks.onExecuted(output)
this.executionCount++
}
}
/**
* Simulates a single execution (queue + execute)
*/
async runSingle(
hooks: ExecutionHooks,
getConfig: () => CyclerConfig
): Promise<void> {
return this.runSequential(1, hooks, getConfig)
}
/**
* Simulates interrupted execution (some beforeQueued calls without matching onExecuted)
*
* This can happen if the user cancels execution mid-batch.
*
* @param queuedCount Number of prompts queued (beforeQueued called)
* @param executedCount Number of prompts that actually executed
* @param hooks The widget's execution hooks
* @param getConfig Function to get current widget config state
*/
async runInterrupted(
queuedCount: number,
executedCount: number,
hooks: ExecutionHooks,
getConfig: () => CyclerConfig
): Promise<void> {
if (executedCount > queuedCount) {
throw new Error('executedCount cannot be greater than queuedCount')
}
// Phase 1: All beforeQueued calls
const snapshotConfigs: CyclerConfig[] = []
for (let i = 0; i < queuedCount; i++) {
hooks.beforeQueued()
snapshotConfigs.push({ ...getConfig() })
}
// Phase 2: Only some onExecuted calls
for (let i = 0; i < executedCount; i++) {
const config = snapshotConfigs[i]
const output = this.options.generateOutput(this.executionCount, config)
hooks.onExecuted(output)
this.executionCount++
}
}
}
/**
* Helper to create execution hooks from a widget-like object
*/
export function createHooksFromWidget(widget: {
beforeQueued?: () => void
}, node: {
onExecuted?: (output: unknown) => void
}): ExecutionHooks {
return {
beforeQueued: () => widget.beforeQueued?.(),
onExecuted: (output) => node.onExecuted?.(output)
}
}
/**
* Tracks index history during simulation for assertions
*/
export class IndexTracker {
public indexHistory: number[] = []
public repeatHistory: number[] = []
public pauseHistory: boolean[] = []
reset() {
this.indexHistory = []
this.repeatHistory = []
this.pauseHistory = []
}
record(config: CyclerConfig) {
this.indexHistory.push(config.current_index)
this.repeatHistory.push(config.repeat_used)
this.pauseHistory.push(config.is_paused)
}
/**
* Get the sequence of indices that were actually used for execution
*/
getExecutionIndices(): number[] {
return this.indexHistory
}
/**
* Verify that indices cycle correctly through totalCount
*/
verifyCyclePattern(expectedPattern: number[]): boolean {
if (this.indexHistory.length !== expectedPattern.length) {
return false
}
return this.indexHistory.every((idx, i) => idx === expectedPattern[i])
}
}
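The batch-vs-sequential equivalence the simulator is built to test can be demonstrated with a toy cycler, independent of the real widget. This is a sketch under stated assumptions: `makeCycler` mirrors the behavior the tests expect (queue reservations in `beforeQueued`, commit on `onExecuted`), not the actual widget implementation:

```typescript
// 1-based wrap helper
function wrap(i: number, total: number): number {
  return ((i - 1) % total) + 1
}

// Toy cycler: beforeQueued reserves the index a queued prompt will use;
// onExecuted consumes the oldest reservation and advances the current index.
function makeCycler(total: number) {
  let current = 1
  const queue: number[] = []
  return {
    beforeQueued() {
      const idx = queue.length === 0 ? current : wrap(queue[queue.length - 1] + 1, total)
      queue.push(idx)
    },
    onExecuted() {
      const executed = queue.shift()!
      current = wrap(executed + 1, total)
    },
    get index() { return current }
  }
}

// Batch mode: all beforeQueued calls happen before any onExecuted call
const batch = makeCycler(5)
for (let i = 0; i < 7; i++) batch.beforeQueued()
for (let i = 0; i < 7; i++) batch.onExecuted()

// Sequential mode: the two hooks interleave per prompt
const seq = makeCycler(5)
for (let i = 0; i < 7; i++) { seq.beforeQueued(); seq.onExecuted() }

console.log(batch.index, seq.index) // 3 3 — both modes end on the same index
```

Because `beforeQueued` plans indices ahead of execution, the final state depends only on how many prompts complete, not on how the two hooks interleave — which is exactly the invariant the "Batch vs Sequential Mode Equivalence" tests assert.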

View File

@@ -19,6 +19,6 @@
       "@/*": ["./src/*"]
     }
   },
-  "include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue"],
+  "include": ["src/**/*.ts", "src/**/*.tsx", "src/**/*.vue", "tests/**/*.ts"],
   "references": [{ "path": "./tsconfig.node.json" }]
 }

View File

@@ -22,8 +22,10 @@ export default defineConfig({
     rollupOptions: {
       external: [
         '../../../scripts/app.js',
+        '../../../scripts/api.js',
        '../loras_widget.js',
-        '../autocomplete.js'
+        '../autocomplete.js',
+        '../preview_tooltip.js'
       ],
       output: {
         dir: '../web/comfyui/vue-widgets',

View File

@@ -0,0 +1,25 @@
import { defineConfig } from 'vitest/config'
import vue from '@vitejs/plugin-vue'
import { resolve } from 'path'
export default defineConfig({
plugins: [vue()],
resolve: {
alias: {
'@': resolve(__dirname, './src')
}
},
test: {
globals: true,
environment: 'jsdom',
setupFiles: ['./tests/setup.ts'],
include: ['tests/**/*.test.ts'],
coverage: {
provider: 'v8',
reporter: ['text', 'html', 'json'],
reportsDirectory: './coverage',
include: ['src/**/*.ts', 'src/**/*.vue'],
exclude: ['src/main.ts', 'src/vite-env.d.ts']
}
}
})

View File

@@ -2,6 +2,7 @@ import { api } from "../../scripts/api.js";
 import { app } from "../../scripts/app.js";
 import { TextAreaCaretHelper } from "./textarea_caret_helper.js";
 import { getPromptTagAutocompletePreference, getTagSpaceReplacementPreference } from "./settings.js";
+import { showToast } from "./utils.js";
 
 // Command definitions for category filtering
 const TAG_COMMANDS = {
@@ -15,6 +16,21 @@ const TAG_COMMANDS = {
   '/lore': { categories: [15], label: 'Lore' },
   '/emb': { type: 'embedding', label: 'Embeddings' },
   '/embedding': { type: 'embedding', label: 'Embeddings' },
+  // Autocomplete toggle commands - only show one based on current state
+  '/ac': {
+    type: 'toggle_setting',
+    settingId: 'loramanager.prompt_tag_autocomplete',
+    value: true,
+    label: 'Autocomplete: ON',
+    condition: () => !getPromptTagAutocompletePreference()
+  },
+  '/noac': {
+    type: 'toggle_setting',
+    settingId: 'loramanager.prompt_tag_autocomplete',
+    value: false,
+    label: 'Autocomplete: OFF',
+    condition: () => getPromptTagAutocompletePreference()
+  },
 };
 
 // Category display information
@@ -488,6 +504,10 @@ class AutoComplete {
       this.searchType = 'commands';
       this._showCommandList(commandResult.commandFilter);
       return;
+    } else if (commandResult.command?.type === 'toggle_setting') {
+      // Handle toggle setting command (/ac, /noac)
+      this._handleToggleSettingCommand(commandResult.command);
+      return;
     } else if (commandResult.command) {
       // Command is active, use filtered search
       this.showingCommands = false;
@@ -509,7 +529,10 @@ class AutoComplete {
       this.showingCommands = false;
       this.activeCommand = null;
       endpoint = '/lm/custom-words/search?enriched=true';
-      searchTerm = rawSearchTerm;
+      // Extract last space-separated token for search
+      // Tag names don't contain spaces, so we only need the last token
+      // This allows "hello 1gi" to search for "1gi" and find "1girl"
+      searchTerm = this._getLastSpaceToken(rawSearchTerm);
       this.searchType = 'custom_words';
     } else {
       // No command and setting disabled - no autocomplete for direct typing
@@ -545,6 +568,17 @@ class AutoComplete {
     return lastSegment.trim();
   }
 
+  /**
+   * Extract the last space-separated token from a search term
+   * Tag names don't contain spaces, so for tag autocomplete we only need the last token
+   * @param {string} term - The full search term (e.g., "hello 1gi")
+   * @returns {string} - The last token (e.g., "1gi"), or the original term if no spaces
+   */
+  _getLastSpaceToken(term) {
+    const tokens = term.trim().split(/\s+/);
+    return tokens[tokens.length - 1] || term;
+  }
+
   async search(term = '', endpoint = null) {
     try {
       this.currentSearchTerm = term;
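The last-token extraction added in this hunk is small enough to exercise standalone. A sketch mirroring `_getLastSpaceToken`, ported to TypeScript for illustration:

```typescript
// Return the last whitespace-separated token, or the term itself
// when it is empty or contains no spaces.
function getLastSpaceToken(term: string): string {
  const tokens = term.trim().split(/\s+/)
  return tokens[tokens.length - 1] || term
}

console.log(getLastSpaceToken('hello 1gi')) // "1gi" — only the last token is searched
console.log(getLastSpaceToken('1girl'))     // "1girl" — no spaces, term unchanged
```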
@@ -606,9 +640,14 @@ class AutoComplete {
     // Check for exact command match
     if (TAG_COMMANDS[partialCommand]) {
+      const cmd = TAG_COMMANDS[partialCommand];
+      // Filter out toggle commands that don't meet their condition
+      if (cmd.type === 'toggle_setting' && cmd.condition && !cmd.condition()) {
+        return { showCommands: false, command: null, searchTerm: '' };
+      }
       return {
         showCommands: false,
-        command: TAG_COMMANDS[partialCommand],
+        command: cmd,
         searchTerm: '',
       };
     }
@@ -627,9 +666,14 @@ class AutoComplete {
     const searchPart = trimmed.slice(spaceIndex + 1).trim();
     if (TAG_COMMANDS[commandPart]) {
+      const cmd = TAG_COMMANDS[commandPart];
+      // Filter out toggle commands that don't meet their condition
+      if (cmd.type === 'toggle_setting' && cmd.condition && !cmd.condition()) {
+        return { showCommands: false, command: null, searchTerm: trimmed };
+      }
       return {
         showCommands: false,
-        command: TAG_COMMANDS[commandPart],
+        command: cmd,
         searchTerm: searchPart,
       };
     }
@@ -652,6 +696,11 @@ class AutoComplete {
     for (const [cmd, info] of Object.entries(TAG_COMMANDS)) {
       if (seenLabels.has(info.label)) continue;
+      // Filter out toggle commands that don't meet their condition
+      if (info.type === 'toggle_setting' && info.condition) {
+        if (!info.condition()) continue;
+      }
+
       if (!filter || cmd.slice(1).startsWith(filterLower)) {
         seenLabels.add(info.label);
         commands.push({ command: cmd, ...info });
@@ -1117,7 +1166,16 @@ class AutoComplete {
     // Use getSearchTerm to get the current search term before cursor
     const beforeCursor = currentValue.substring(0, caretPos);
-    const searchTerm = this.getSearchTerm(beforeCursor);
+    const fullSearchTerm = this.getSearchTerm(beforeCursor);
+
+    // For regular tag autocomplete (no command), only replace the last space-separated token
+    // This allows "hello 1gi" + selecting "1girl" to become "hello 1girl, "
+    // Command mode (e.g., "/char miku") should replace the entire command+search
+    let searchTerm = fullSearchTerm;
+    if (this.modelType === 'prompt' && this.searchType === 'custom_words' && !this.activeCommand) {
+      searchTerm = this._getLastSpaceToken(fullSearchTerm);
+    }
 
     const searchStartPos = caretPos - searchTerm.length;
     // Only replace the search term, not everything after the last comma
@@ -1175,6 +1233,119 @@ class AutoComplete {
     }
   }
 
+  /**
+   * Handle toggle setting command (/ac, /noac)
+   * @param {Object} command - The toggle command with settingId and value
+   */
+  async _handleToggleSettingCommand(command) {
+    const { settingId, value } = command;
+    try {
+      // Use ComfyUI's setting API to update global setting
+      const settingManager = app?.extensionManager?.setting;
+      if (settingManager && typeof settingManager.set === 'function') {
+        await settingManager.set(settingId, value);
+        this._showToggleFeedback(value);
+        this._clearCurrentToken();
+      } else {
+        // Fallback: use legacy settings API
+        const setting = app.ui.settings.settingsById?.[settingId];
+        if (setting) {
+          app.ui.settings.setSettingValue(settingId, value);
+          this._showToggleFeedback(value);
+          this._clearCurrentToken();
+        }
+      }
+    } catch (error) {
+      console.error('[Lora Manager] Failed to toggle setting:', error);
+      showToast({
+        severity: 'error',
+        summary: 'Error',
+        detail: 'Failed to toggle autocomplete setting',
+        life: 3000
+      });
+    }
+    this.hide();
+  }
+
+  /**
+   * Show visual feedback for toggle action using toast
+   * @param {boolean} enabled - New autocomplete state
+   */
+  _showToggleFeedback(enabled) {
+    showToast({
+      severity: enabled ? 'success' : 'secondary',
+      summary: enabled ? 'Autocomplete Enabled' : 'Autocomplete Disabled',
+      detail: enabled
+        ? 'Tag autocomplete is now ON. Type to see suggestions.'
+        : 'Tag autocomplete is now OFF. Use /ac to re-enable.',
+      life: 3000
+    });
+  }
+
+  /**
+   * Clear the current command token from input
+   * Preserves leading spaces after delimiters (e.g., "1girl, /ac" -> "1girl, ")
+   */
+  _clearCurrentToken() {
+    const currentValue = this.inputElement.value;
+    const caretPos = this.inputElement.selectionStart;
+
+    // Find the command text before cursor
+    const beforeCursor = currentValue.substring(0, caretPos);
+    const segments = beforeCursor.split(/[,\>]+/);
+    const lastSegment = segments[segments.length - 1] || '';
+
+    // Find the command start position, preserving leading spaces
+    // lastSegment includes leading spaces (e.g., " /ac"), find where command actually starts
+    const commandMatch = lastSegment.match(/^(\s*)(\/\w+)/);
+    if (commandMatch) {
+      // commandMatch[1] is leading spaces, commandMatch[2] is the command
+      const leadingSpaces = commandMatch[1].length;
+      // Keep the spaces by starting after them
+      const commandStartPos = caretPos - lastSegment.length + leadingSpaces;
+
+      // Skip trailing spaces when deleting
+      let endPos = caretPos;
+      while (endPos < currentValue.length && currentValue[endPos] === ' ') {
+        endPos++;
+      }
+
+      const newValue = currentValue.substring(0, commandStartPos) + currentValue.substring(endPos);
+      const newCaretPos = commandStartPos;
+      this.inputElement.value = newValue;
+
+      // Trigger input event to notify about the change
+      const event = new Event('input', { bubbles: true });
+      this.inputElement.dispatchEvent(event);
+
+      // Focus back to input and position cursor
+      this.inputElement.focus();
+      this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
+    } else {
+      // Fallback: delete the whole last segment (original behavior)
+      const commandStartPos = caretPos - lastSegment.length;
+      let endPos = caretPos;
+      while (endPos < currentValue.length && currentValue[endPos] === ' ') {
+        endPos++;
+      }
+      const newValue = currentValue.substring(0, commandStartPos) + currentValue.substring(endPos);
+      const newCaretPos = commandStartPos;
+      this.inputElement.value = newValue;
+      const event = new Event('input', { bubbles: true });
+      this.inputElement.dispatchEvent(event);
+      this.inputElement.focus();
+      this.inputElement.setSelectionRange(newCaretPos, newCaretPos);
+    }
+  }
+
   destroy() {
     if (this.debounceTimer) {
       clearTimeout(this.debounceTimer);
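The delimiter-and-space-preserving removal that `_clearCurrentToken` performs can be sketched as a pure function over the text before the caret. This is a hypothetical standalone mirror for illustration, ignoring caret repositioning and trailing-space trimming:

```typescript
// Remove a trailing "/command" token while keeping the delimiter and the
// spaces that follow it, e.g. "1girl, /ac" -> "1girl, ".
function clearCommandToken(beforeCaret: string): string {
  const segments = beforeCaret.split(/[,\>]+/)
  const lastSegment = segments[segments.length - 1] ?? ''
  const match = lastSegment.match(/^(\s*)(\/\w+)/)
  if (!match) return beforeCaret // no command token: leave text unchanged
  // Keep everything up to the command itself, including its leading spaces
  const keepLength = beforeCaret.length - lastSegment.length + match[1].length
  return beforeCaret.substring(0, keepLength)
}

console.log(JSON.stringify(clearCommandToken('1girl, /ac'))) // "1girl, "
```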

View File

@@ -1,7 +1,7 @@
 /* Shared styling for the LoRA Manager frontend widgets */
 .lm-tooltip {
   position: fixed;
-  z-index: 9999;
+  z-index: 10001;
   background: rgba(0, 0, 0, 0.85);
   border-radius: 6px;
   box-shadow: 0 4px 12px rgba(0, 0, 0, 0.3);

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long