mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
feat(example-images): add check pending models endpoint and improve async handling
- Add /api/example-images/check-pending endpoint to quickly check models needing downloads - Improve DownloadManager.start_download() to return immediately without blocking - Add _handle_download_task_done callback for proper error handling and progress saving - Add check_pending_models() method for lightweight pre-download validation - Update frontend ExampleImagesManager to use new check-pending endpoint - Add comprehensive tests for new functionality
This commit is contained in:
@@ -30,6 +30,7 @@ ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
RouteDefinition("POST", "/api/lm/force-download-example-images", "force_download_example_images"),
|
||||||
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
RouteDefinition("POST", "/api/lm/cleanup-example-image-folders", "cleanup_example_image_folders"),
|
||||||
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
|
RouteDefinition("POST", "/api/lm/example-images/set-nsfw-level", "set_example_image_nsfw_level"),
|
||||||
|
RouteDefinition("POST", "/api/lm/check-example-images-needed", "check_example_images_needed"),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -92,6 +92,19 @@ class ExampleImagesDownloadHandler:
|
|||||||
except ExampleImagesDownloadError as exc:
|
except ExampleImagesDownloadError as exc:
|
||||||
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
return web.json_response({'success': False, 'error': str(exc)}, status=500)
|
||||||
|
|
||||||
|
async def check_example_images_needed(self, request: web.Request) -> web.StreamResponse:
|
||||||
|
"""Lightweight check to see if any models need example images downloaded."""
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
model_types = payload.get('model_types', ['lora', 'checkpoint', 'embedding'])
|
||||||
|
result = await self._download_manager.check_pending_models(model_types)
|
||||||
|
return web.json_response(result)
|
||||||
|
except Exception as exc:
|
||||||
|
return web.json_response(
|
||||||
|
{'success': False, 'error': str(exc)},
|
||||||
|
status=500
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExampleImagesManagementHandler:
|
class ExampleImagesManagementHandler:
|
||||||
"""HTTP adapters for import/delete endpoints."""
|
"""HTTP adapters for import/delete endpoints."""
|
||||||
@@ -161,6 +174,7 @@ class ExampleImagesHandlerSet:
|
|||||||
"resume_example_images": self.download.resume_example_images,
|
"resume_example_images": self.download.resume_example_images,
|
||||||
"stop_example_images": self.download.stop_example_images,
|
"stop_example_images": self.download.stop_example_images,
|
||||||
"force_download_example_images": self.download.force_download_example_images,
|
"force_download_example_images": self.download.force_download_example_images,
|
||||||
|
"check_example_images_needed": self.download.check_example_images_needed,
|
||||||
"import_example_images": self.management.import_example_images,
|
"import_example_images": self.management.import_example_images,
|
||||||
"delete_example_image": self.management.delete_example_image,
|
"delete_example_image": self.management.delete_example_image,
|
||||||
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,
|
"set_example_image_nsfw_level": self.management.set_example_image_nsfw_level,
|
||||||
|
|||||||
@@ -216,6 +216,11 @@ class DownloadManager:
|
|||||||
self._progress["failed_models"] = set()
|
self._progress["failed_models"] = set()
|
||||||
|
|
||||||
self._is_downloading = True
|
self._is_downloading = True
|
||||||
|
snapshot = self._progress.snapshot()
|
||||||
|
|
||||||
|
# Create the download task without awaiting it
|
||||||
|
# This ensures the HTTP response is returned immediately
|
||||||
|
# while the actual processing happens in the background
|
||||||
self._download_task = asyncio.create_task(
|
self._download_task = asyncio.create_task(
|
||||||
self._download_all_example_images(
|
self._download_all_example_images(
|
||||||
output_dir,
|
output_dir,
|
||||||
@@ -227,7 +232,10 @@ class DownloadManager:
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
snapshot = self._progress.snapshot()
|
# Add a callback to handle task completion/errors
|
||||||
|
self._download_task.add_done_callback(
|
||||||
|
lambda t: self._handle_download_task_done(t, output_dir)
|
||||||
|
)
|
||||||
except ExampleImagesDownloadError:
|
except ExampleImagesDownloadError:
|
||||||
# Re-raise our own exception types without wrapping
|
# Re-raise our own exception types without wrapping
|
||||||
self._is_downloading = False
|
self._is_downloading = False
|
||||||
@@ -241,10 +249,25 @@ class DownloadManager:
|
|||||||
)
|
)
|
||||||
raise ExampleImagesDownloadError(str(e)) from e
|
raise ExampleImagesDownloadError(str(e)) from e
|
||||||
|
|
||||||
await self._broadcast_progress(status="running")
|
# Broadcast progress in the background without blocking the response
|
||||||
|
# This ensures the HTTP response is returned immediately
|
||||||
|
asyncio.create_task(self._broadcast_progress(status="running"))
|
||||||
|
|
||||||
return {"success": True, "message": "Download started", "status": snapshot}
|
return {"success": True, "message": "Download started", "status": snapshot}
|
||||||
|
|
||||||
|
def _handle_download_task_done(self, task: asyncio.Task, output_dir: str) -> None:
|
||||||
|
"""Handle download task completion, including saving progress on error."""
|
||||||
|
try:
|
||||||
|
# This will re-raise any exception from the task
|
||||||
|
task.result()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Download task failed with error: {e}", exc_info=True)
|
||||||
|
# Ensure progress is saved even on failure
|
||||||
|
try:
|
||||||
|
self._save_progress(output_dir)
|
||||||
|
except Exception as save_error:
|
||||||
|
logger.error(f"Failed to save progress after task failure: {save_error}")
|
||||||
|
|
||||||
async def get_status(self, request):
|
async def get_status(self, request):
|
||||||
"""Get the current status of example images download."""
|
"""Get the current status of example images download."""
|
||||||
|
|
||||||
@@ -254,6 +277,130 @@ class DownloadManager:
|
|||||||
"status": self._progress.snapshot(),
|
"status": self._progress.snapshot(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async def check_pending_models(self, model_types: list[str]) -> dict:
|
||||||
|
"""Quickly check how many models need example images downloaded.
|
||||||
|
|
||||||
|
This is a lightweight check that avoids the overhead of starting
|
||||||
|
a full download task when no work is needed.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict with keys:
|
||||||
|
- total_models: Total number of models across specified types
|
||||||
|
- pending_count: Number of models needing example images
|
||||||
|
- processed_count: Number of already processed models
|
||||||
|
- failed_count: Number of models marked as failed
|
||||||
|
- needs_download: True if there are pending models to process
|
||||||
|
"""
|
||||||
|
from ..services.service_registry import ServiceRegistry
|
||||||
|
|
||||||
|
if self._is_downloading:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": True,
|
||||||
|
"total_models": 0,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
"message": "Download already in progress",
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Get scanners
|
||||||
|
scanners = []
|
||||||
|
if "lora" in model_types:
|
||||||
|
lora_scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
scanners.append(("lora", lora_scanner))
|
||||||
|
|
||||||
|
if "checkpoint" in model_types:
|
||||||
|
checkpoint_scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
scanners.append(("checkpoint", checkpoint_scanner))
|
||||||
|
|
||||||
|
if "embedding" in model_types:
|
||||||
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
|
scanners.append(("embedding", embedding_scanner))
|
||||||
|
|
||||||
|
# Load progress file to check processed models
|
||||||
|
settings_manager = get_settings_manager()
|
||||||
|
active_library = settings_manager.get_active_library_name()
|
||||||
|
output_dir = self._resolve_output_dir(active_library)
|
||||||
|
|
||||||
|
processed_models: set[str] = set()
|
||||||
|
failed_models: set[str] = set()
|
||||||
|
|
||||||
|
if output_dir:
|
||||||
|
progress_file = os.path.join(output_dir, ".download_progress.json")
|
||||||
|
if os.path.exists(progress_file):
|
||||||
|
try:
|
||||||
|
with open(progress_file, "r", encoding="utf-8") as f:
|
||||||
|
saved_progress = json.load(f)
|
||||||
|
processed_models = set(saved_progress.get("processed_models", []))
|
||||||
|
failed_models = set(saved_progress.get("failed_models", []))
|
||||||
|
except Exception:
|
||||||
|
pass # Ignore progress file errors for quick check
|
||||||
|
|
||||||
|
# Count models
|
||||||
|
total_models = 0
|
||||||
|
models_with_hash = 0
|
||||||
|
|
||||||
|
for scanner_type, scanner in scanners:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
if cache and cache.raw_data:
|
||||||
|
for model in cache.raw_data:
|
||||||
|
total_models += 1
|
||||||
|
if model.get("sha256"):
|
||||||
|
models_with_hash += 1
|
||||||
|
|
||||||
|
# Calculate pending count
|
||||||
|
# A model is pending if it has a hash and is not in processed_models
|
||||||
|
# We also exclude failed_models unless force mode would be used
|
||||||
|
pending_count = models_with_hash - len(processed_models.intersection(
|
||||||
|
{m.get("sha256", "").lower() for scanner_type, scanner in scanners
|
||||||
|
for m in (await scanner.get_cached_data()).raw_data if m.get("sha256")}
|
||||||
|
))
|
||||||
|
|
||||||
|
# More accurate pending count: check which models actually need processing
|
||||||
|
pending_hashes = set()
|
||||||
|
for scanner_type, scanner in scanners:
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
if cache and cache.raw_data:
|
||||||
|
for model in cache.raw_data:
|
||||||
|
raw_hash = model.get("sha256")
|
||||||
|
if not raw_hash:
|
||||||
|
continue
|
||||||
|
model_hash = raw_hash.lower()
|
||||||
|
if model_hash not in processed_models:
|
||||||
|
# Check if model folder exists with files
|
||||||
|
model_dir = ExampleImagePathResolver.get_model_folder(
|
||||||
|
model_hash, active_library
|
||||||
|
)
|
||||||
|
if not _model_directory_has_files(model_dir):
|
||||||
|
pending_hashes.add(model_hash)
|
||||||
|
|
||||||
|
pending_count = len(pending_hashes)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": total_models,
|
||||||
|
"pending_count": pending_count,
|
||||||
|
"processed_count": len(processed_models),
|
||||||
|
"failed_count": len(failed_models),
|
||||||
|
"needs_download": pending_count > 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error checking pending models: {e}", exc_info=True)
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e),
|
||||||
|
"total_models": 0,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
}
|
||||||
|
|
||||||
async def pause_download(self, request):
|
async def pause_download(self, request):
|
||||||
"""Pause the example images download."""
|
"""Pause the example images download."""
|
||||||
|
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ export class ExampleImagesManager {
|
|||||||
// Auto download properties
|
// Auto download properties
|
||||||
this.autoDownloadInterval = null;
|
this.autoDownloadInterval = null;
|
||||||
this.lastAutoDownloadCheck = 0;
|
this.lastAutoDownloadCheck = 0;
|
||||||
this.autoDownloadCheckInterval = 10 * 60 * 1000; // 10 minutes in milliseconds
|
this.autoDownloadCheckInterval = 30 * 60 * 1000; // 30 minutes in milliseconds
|
||||||
this.pageInitTime = Date.now(); // Track when page was initialized
|
this.pageInitTime = Date.now(); // Track when page was initialized
|
||||||
|
|
||||||
// Initialize download path field and check download status
|
// Initialize download path field and check download status
|
||||||
@@ -808,19 +808,58 @@ export class ExampleImagesManager {
|
|||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
this.lastAutoDownloadCheck = now;
|
|
||||||
|
|
||||||
if (!this.canAutoDownload()) {
|
if (!this.canAutoDownload()) {
|
||||||
console.log('Auto download conditions not met, skipping check');
|
console.log('Auto download conditions not met, skipping check');
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
console.log('Performing auto download check...');
|
console.log('Performing auto download pre-check...');
|
||||||
|
|
||||||
|
// Step 1: Lightweight pre-check to see if any work is needed
|
||||||
|
const checkResponse = await fetch('/api/lm/check-example-images-needed', {
|
||||||
|
method: 'POST',
|
||||||
|
headers: {
|
||||||
|
'Content-Type': 'application/json'
|
||||||
|
},
|
||||||
|
body: JSON.stringify({
|
||||||
|
model_types: ['lora', 'checkpoint', 'embedding']
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!checkResponse.ok) {
|
||||||
|
console.warn('Auto download pre-check HTTP error:', checkResponse.status);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const checkData = await checkResponse.json();
|
||||||
|
|
||||||
|
if (!checkData.success) {
|
||||||
|
console.warn('Auto download pre-check failed:', checkData.error);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the check timestamp only after successful pre-check
|
||||||
|
this.lastAutoDownloadCheck = now;
|
||||||
|
|
||||||
|
// If download already in progress, skip
|
||||||
|
if (checkData.is_downloading) {
|
||||||
|
console.log('Download already in progress, skipping auto check');
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no models need downloading, skip
|
||||||
|
if (!checkData.needs_download || checkData.pending_count === 0) {
|
||||||
|
console.log(`Auto download pre-check complete: ${checkData.processed_count}/${checkData.total_models} models already processed, no work needed`);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Auto download pre-check: ${checkData.pending_count} models need processing, starting download...`);
|
||||||
|
|
||||||
|
// Step 2: Start the actual download (fire-and-forget)
|
||||||
const optimize = state.global.settings.optimize_example_images;
|
const optimize = state.global.settings.optimize_example_images;
|
||||||
|
|
||||||
const response = await fetch('/api/lm/download-example-images', {
|
fetch('/api/lm/download-example-images', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
headers: {
|
headers: {
|
||||||
'Content-Type': 'application/json'
|
'Content-Type': 'application/json'
|
||||||
@@ -830,18 +869,29 @@ export class ExampleImagesManager {
|
|||||||
model_types: ['lora', 'checkpoint', 'embedding'],
|
model_types: ['lora', 'checkpoint', 'embedding'],
|
||||||
auto_mode: true // Flag to indicate this is an automatic download
|
auto_mode: true // Flag to indicate this is an automatic download
|
||||||
})
|
})
|
||||||
|
}).then(response => {
|
||||||
|
if (!response.ok) {
|
||||||
|
console.warn('Auto download start HTTP error:', response.status);
|
||||||
|
return null;
|
||||||
|
}
|
||||||
|
return response.json();
|
||||||
|
}).then(data => {
|
||||||
|
if (data && !data.success) {
|
||||||
|
console.warn('Auto download start failed:', data.error);
|
||||||
|
// If already in progress, push back the next check to avoid hammering the API
|
||||||
|
if (data.error && data.error.includes('already in progress')) {
|
||||||
|
console.log('Download already in progress, backing off next check');
|
||||||
|
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
|
||||||
|
}
|
||||||
|
} else if (data && data.success) {
|
||||||
|
console.log('Auto download started:', data.message || 'Download started');
|
||||||
|
}
|
||||||
|
}).catch(error => {
|
||||||
|
console.error('Auto download start error:', error);
|
||||||
});
|
});
|
||||||
|
|
||||||
const data = await response.json();
|
// Immediately return without waiting for the download fetch to complete
|
||||||
|
// This keeps the UI responsive
|
||||||
if (!data.success) {
|
|
||||||
console.warn('Auto download check failed:', data.error);
|
|
||||||
// If already in progress, push back the next check to avoid hammering the API
|
|
||||||
if (data.error && data.error.includes('already in progress')) {
|
|
||||||
console.log('Download already in progress, backing off next check');
|
|
||||||
this.lastAutoDownloadCheck = now + (5 * 60 * 1000); // Back off for 5 extra minutes
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Auto download check error:', error);
|
console.error('Auto download check error:', error);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -47,6 +47,8 @@ class StubDownloadManager:
|
|||||||
self.resume_error: Exception | None = None
|
self.resume_error: Exception | None = None
|
||||||
self.stop_error: Exception | None = None
|
self.stop_error: Exception | None = None
|
||||||
self.force_error: Exception | None = None
|
self.force_error: Exception | None = None
|
||||||
|
self.check_pending_result: dict[str, Any] | None = None
|
||||||
|
self.check_pending_calls: list[list[str]] = []
|
||||||
|
|
||||||
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
async def get_status(self, request: web.Request) -> dict[str, Any]:
|
||||||
return {"success": True, "status": "idle"}
|
return {"success": True, "status": "idle"}
|
||||||
@@ -75,6 +77,20 @@ class StubDownloadManager:
|
|||||||
raise self.force_error
|
raise self.force_error
|
||||||
return {"success": True, "payload": payload}
|
return {"success": True, "payload": payload}
|
||||||
|
|
||||||
|
async def check_pending_models(self, model_types: list[str]) -> dict[str, Any]:
|
||||||
|
self.check_pending_calls.append(model_types)
|
||||||
|
if self.check_pending_result is not None:
|
||||||
|
return self.check_pending_result
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": 100,
|
||||||
|
"pending_count": 10,
|
||||||
|
"processed_count": 90,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
class StubImportUseCase:
|
class StubImportUseCase:
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
@@ -236,3 +252,123 @@ async def test_import_route_returns_validation_errors():
|
|||||||
assert response.status == 400
|
assert response.status == 400
|
||||||
body = await _json(response)
|
body = await _json(response)
|
||||||
assert body == {"success": False, "error": "bad payload"}
|
assert body == {"success": False, "error": "bad payload"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_returns_pending_counts():
|
||||||
|
"""Test that check_example_images_needed endpoint returns pending model counts."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
harness.download_manager.check_pending_result = {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": 5500,
|
||||||
|
"pending_count": 12,
|
||||||
|
"processed_count": 5488,
|
||||||
|
"failed_count": 45,
|
||||||
|
"needs_download": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora", "checkpoint"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["total_models"] == 5500
|
||||||
|
assert body["pending_count"] == 12
|
||||||
|
assert body["processed_count"] == 5488
|
||||||
|
assert body["failed_count"] == 45
|
||||||
|
assert body["needs_download"] is True
|
||||||
|
assert body["is_downloading"] is False
|
||||||
|
|
||||||
|
# Verify the manager was called with correct model types
|
||||||
|
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint"]]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_handles_download_in_progress():
|
||||||
|
"""Test that check_example_images_needed returns correct status when download is running."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
harness.download_manager.check_pending_result = {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": True,
|
||||||
|
"total_models": 0,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 0,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
"message": "Download already in progress",
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["is_downloading"] is True
|
||||||
|
assert body["needs_download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_handles_no_pending_models():
|
||||||
|
"""Test that check_example_images_needed returns correct status when no work is needed."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
harness.download_manager.check_pending_result = {
|
||||||
|
"success": True,
|
||||||
|
"is_downloading": False,
|
||||||
|
"total_models": 5500,
|
||||||
|
"pending_count": 0,
|
||||||
|
"processed_count": 5500,
|
||||||
|
"failed_count": 0,
|
||||||
|
"needs_download": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora", "checkpoint", "embedding"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is True
|
||||||
|
assert body["pending_count"] == 0
|
||||||
|
assert body["needs_download"] is False
|
||||||
|
assert body["processed_count"] == 5500
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_uses_default_model_types():
|
||||||
|
"""Test that check_example_images_needed uses default model types when not specified."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={}, # No model_types specified
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 200
|
||||||
|
# Should use default model types
|
||||||
|
assert harness.download_manager.check_pending_calls == [["lora", "checkpoint", "embedding"]]
|
||||||
|
|
||||||
|
|
||||||
|
async def test_check_example_images_needed_returns_error_on_exception():
|
||||||
|
"""Test that check_example_images_needed returns 500 on internal error."""
|
||||||
|
async with registrar_app() as harness:
|
||||||
|
# Simulate an error by setting result to an error state
|
||||||
|
# Actually, we need to make the method raise an exception
|
||||||
|
original_method = harness.download_manager.check_pending_models
|
||||||
|
|
||||||
|
async def failing_check(_model_types):
|
||||||
|
raise RuntimeError("Database connection failed")
|
||||||
|
|
||||||
|
harness.download_manager.check_pending_models = failing_check
|
||||||
|
|
||||||
|
response = await harness.client.post(
|
||||||
|
"/api/lm/check-example-images-needed",
|
||||||
|
json={"model_types": ["lora"]},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status == 500
|
||||||
|
body = await _json(response)
|
||||||
|
assert body["success"] is False
|
||||||
|
assert "Database connection failed" in body["error"]
|
||||||
|
|||||||
@@ -502,6 +502,7 @@ def test_handler_set_route_mapping_includes_all_handlers() -> None:
|
|||||||
"resume_example_images",
|
"resume_example_images",
|
||||||
"stop_example_images",
|
"stop_example_images",
|
||||||
"force_download_example_images",
|
"force_download_example_images",
|
||||||
|
"check_example_images_needed",
|
||||||
"import_example_images",
|
"import_example_images",
|
||||||
"delete_example_image",
|
"delete_example_image",
|
||||||
"set_example_image_nsfw_level",
|
"set_example_image_nsfw_level",
|
||||||
|
|||||||
368
tests/services/test_check_pending_models.py
Normal file
368
tests/services/test_check_pending_models.py
Normal file
@@ -0,0 +1,368 @@
|
|||||||
|
"""Tests for the check_pending_models lightweight pre-check functionality."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from py.services.settings_manager import get_settings_manager
|
||||||
|
from py.utils import example_images_download_manager as download_module
|
||||||
|
|
||||||
|
|
||||||
|
class StubScanner:
|
||||||
|
"""Scanner double returning predetermined cache contents."""
|
||||||
|
|
||||||
|
def __init__(self, models: list[dict]) -> None:
|
||||||
|
self._cache = SimpleNamespace(raw_data=models)
|
||||||
|
|
||||||
|
async def get_cached_data(self):
|
||||||
|
return self._cache
|
||||||
|
|
||||||
|
|
||||||
|
def _patch_scanners(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
lora_scanner: StubScanner | None = None,
|
||||||
|
checkpoint_scanner: StubScanner | None = None,
|
||||||
|
embedding_scanner: StubScanner | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Patch ServiceRegistry to return stub scanners."""
|
||||||
|
|
||||||
|
async def _get_lora_scanner(cls):
|
||||||
|
return lora_scanner or StubScanner([])
|
||||||
|
|
||||||
|
async def _get_checkpoint_scanner(cls):
|
||||||
|
return checkpoint_scanner or StubScanner([])
|
||||||
|
|
||||||
|
async def _get_embedding_scanner(cls):
|
||||||
|
return embedding_scanner or StubScanner([])
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ServiceRegistry,
|
||||||
|
"get_lora_scanner",
|
||||||
|
classmethod(_get_lora_scanner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ServiceRegistry,
|
||||||
|
"get_checkpoint_scanner",
|
||||||
|
classmethod(_get_checkpoint_scanner),
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
download_module.ServiceRegistry,
|
||||||
|
"get_embedding_scanner",
|
||||||
|
classmethod(_get_embedding_scanner),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RecordingWebSocketManager:
|
||||||
|
"""Collects broadcast payloads for assertions."""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self.payloads: list[dict] = []
|
||||||
|
|
||||||
|
async def broadcast(self, payload: dict) -> None:
|
||||||
|
self.payloads.append(payload)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_returns_zero_when_all_processed(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models returns 0 pending when all models are processed."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Create processed models
|
||||||
|
processed_hashes = ["a" * 64, "b" * 64, "c" * 64]
|
||||||
|
models = [
|
||||||
|
{"sha256": h, "model_name": f"Model {i}"}
|
||||||
|
for i, h in enumerate(processed_hashes)
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create progress file with all models processed
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text(
|
||||||
|
json.dumps({"processed_models": processed_hashes, "failed_models": []}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create model directories with files (simulating completed downloads)
|
||||||
|
for h in processed_hashes:
|
||||||
|
model_dir = tmp_path / h
|
||||||
|
model_dir.mkdir()
|
||||||
|
(model_dir / "image_0.png").write_text("data")
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["is_downloading"] is False
|
||||||
|
assert result["total_models"] == 3
|
||||||
|
assert result["pending_count"] == 0
|
||||||
|
assert result["processed_count"] == 3
|
||||||
|
assert result["needs_download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_finds_unprocessed_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models correctly identifies unprocessed models."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Create models - some processed, some not
|
||||||
|
processed_hash = "a" * 64
|
||||||
|
unprocessed_hash = "b" * 64
|
||||||
|
models = [
|
||||||
|
{"sha256": processed_hash, "model_name": "Processed Model"},
|
||||||
|
{"sha256": unprocessed_hash, "model_name": "Unprocessed Model"},
|
||||||
|
]
|
||||||
|
|
||||||
|
# Create progress file with only one model processed
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text(
|
||||||
|
json.dumps({"processed_models": [processed_hash], "failed_models": []}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create directory only for processed model
|
||||||
|
processed_dir = tmp_path / processed_hash
|
||||||
|
processed_dir.mkdir()
|
||||||
|
(processed_dir / "image_0.png").write_text("data")
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 2
|
||||||
|
assert result["pending_count"] == 1
|
||||||
|
assert result["processed_count"] == 1
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_skips_models_without_hash(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that models without sha256 are not counted as pending."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Models - one with hash, one without
|
||||||
|
models = [
|
||||||
|
{"sha256": "a" * 64, "model_name": "Hashed Model"},
|
||||||
|
{"sha256": None, "model_name": "No Hash Model"},
|
||||||
|
{"model_name": "Missing Hash Model"}, # No sha256 key at all
|
||||||
|
]
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 3
|
||||||
|
assert result["pending_count"] == 1 # Only the one with hash
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_multiple_model_types(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models aggregates counts across multiple model types."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
lora_models = [
|
||||||
|
{"sha256": "a" * 64, "model_name": "Lora 1"},
|
||||||
|
{"sha256": "b" * 64, "model_name": "Lora 2"},
|
||||||
|
]
|
||||||
|
checkpoint_models = [
|
||||||
|
{"sha256": "c" * 64, "model_name": "Checkpoint 1"},
|
||||||
|
]
|
||||||
|
embedding_models = [
|
||||||
|
{"sha256": "d" * 64, "model_name": "Embedding 1"},
|
||||||
|
{"sha256": "e" * 64, "model_name": "Embedding 2"},
|
||||||
|
{"sha256": "f" * 64, "model_name": "Embedding 3"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_patch_scanners(
|
||||||
|
monkeypatch,
|
||||||
|
lora_scanner=StubScanner(lora_models),
|
||||||
|
checkpoint_scanner=StubScanner(checkpoint_models),
|
||||||
|
embedding_scanner=StubScanner(embedding_models),
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora", "checkpoint", "embedding"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 6 # 2 + 1 + 3
|
||||||
|
assert result["pending_count"] == 6 # All unprocessed
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_returns_error_when_download_in_progress(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models returns special response when download is running."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
# Simulate download in progress
|
||||||
|
manager._is_downloading = True
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["is_downloading"] is True
|
||||||
|
assert result["needs_download"] is False
|
||||||
|
assert result["pending_count"] == 0
|
||||||
|
assert "already in progress" in result["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_empty_library(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models handles empty model library."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner([]))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 0
|
||||||
|
assert result["pending_count"] == 0
|
||||||
|
assert result["processed_count"] == 0
|
||||||
|
assert result["needs_download"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_reads_failed_models(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models correctly reports failed model count."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||||
|
|
||||||
|
# Create progress file with failed models
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text(
|
||||||
|
json.dumps({"processed_models": [], "failed_models": ["a" * 64, "b" * 64]}),
|
||||||
|
encoding="utf-8",
|
||||||
|
)
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["failed_count"] == 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_missing_progress_file(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models works correctly when no progress file exists."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
models = [
|
||||||
|
{"sha256": "a" * 64, "model_name": "Model 1"},
|
||||||
|
{"sha256": "b" * 64, "model_name": "Model 2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
# No progress file created
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 2
|
||||||
|
assert result["pending_count"] == 2 # All pending since no progress
|
||||||
|
assert result["processed_count"] == 0
|
||||||
|
assert result["failed_count"] == 0
|
||||||
|
assert result["needs_download"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.usefixtures("tmp_path")
|
||||||
|
async def test_check_pending_models_handles_corrupted_progress_file(
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path,
|
||||||
|
settings_manager,
|
||||||
|
):
|
||||||
|
"""Test that check_pending_models handles corrupted progress file gracefully."""
|
||||||
|
ws_manager = RecordingWebSocketManager()
|
||||||
|
manager = download_module.DownloadManager(ws_manager=ws_manager)
|
||||||
|
|
||||||
|
monkeypatch.setitem(settings_manager.settings, "example_images_path", str(tmp_path))
|
||||||
|
|
||||||
|
models = [{"sha256": "a" * 64, "model_name": "Model"}]
|
||||||
|
|
||||||
|
# Create corrupted progress file
|
||||||
|
progress_file = tmp_path / ".download_progress.json"
|
||||||
|
progress_file.write_text("not valid json", encoding="utf-8")
|
||||||
|
|
||||||
|
_patch_scanners(monkeypatch, lora_scanner=StubScanner(models))
|
||||||
|
|
||||||
|
result = await manager.check_pending_models(["lora"])
|
||||||
|
|
||||||
|
# Should still succeed, treating all as unprocessed
|
||||||
|
assert result["success"] is True
|
||||||
|
assert result["total_models"] == 1
|
||||||
|
assert result["pending_count"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def settings_manager():
|
||||||
|
return get_settings_manager()
|
||||||
Reference in New Issue
Block a user