mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
fix(preview): resolve CORS error when setting CivitAI remote media as preview
- Add new endpoint POST /api/lm/{prefix}/set-preview-from-url to handle
remote image downloads server-side, avoiding CORS issues
- Use rewrite_preview_url() to download optimized smaller images (450px width)
- Use Downloader service for reliable downloads with retry logic and proxy support
- Update frontend to call new endpoint instead of fetching images in browser
fixes #837
This commit is contained in:
@@ -74,18 +74,14 @@ class ModelPageView:
|
|||||||
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
|
os.path.dirname(os.path.dirname(os.path.dirname(current_file)))
|
||||||
)
|
)
|
||||||
supporters_path = os.path.join(root_dir, "data", "supporters.json")
|
supporters_path = os.path.join(root_dir, "data", "supporters.json")
|
||||||
|
|
||||||
if os.path.exists(supporters_path):
|
if os.path.exists(supporters_path):
|
||||||
with open(supporters_path, "r", encoding="utf-8") as f:
|
with open(supporters_path, "r", encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self._logger.debug(f"Failed to load supporters data: {e}")
|
self._logger.debug(f"Failed to load supporters data: {e}")
|
||||||
|
|
||||||
return {
|
return {"specialThanks": [], "allSupporters": [], "totalCount": 0}
|
||||||
"specialThanks": [],
|
|
||||||
"allSupporters": [],
|
|
||||||
"totalCount": 0
|
|
||||||
}
|
|
||||||
|
|
||||||
def _get_app_version(self) -> str:
|
def _get_app_version(self) -> str:
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@@ -404,20 +400,26 @@ class ModelManagementHandler:
|
|||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"success": False, "error": "Model not found in cache"}, status=404
|
{"success": False, "error": "Model not found in cache"}, status=404
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if hash needs to be calculated (lazy hash for checkpoints)
|
# Check if hash needs to be calculated (lazy hash for checkpoints)
|
||||||
sha256 = model_data.get("sha256")
|
sha256 = model_data.get("sha256")
|
||||||
hash_status = model_data.get("hash_status", "completed")
|
hash_status = model_data.get("hash_status", "completed")
|
||||||
|
|
||||||
if not sha256 or hash_status != "completed":
|
if not sha256 or hash_status != "completed":
|
||||||
# For checkpoints, calculate hash on-demand
|
# For checkpoints, calculate hash on-demand
|
||||||
scanner = self._service.scanner
|
scanner = self._service.scanner
|
||||||
if hasattr(scanner, 'calculate_hash_for_model'):
|
if hasattr(scanner, "calculate_hash_for_model"):
|
||||||
self._logger.info(f"Lazy hash calculation triggered for {file_path}")
|
self._logger.info(
|
||||||
|
f"Lazy hash calculation triggered for {file_path}"
|
||||||
|
)
|
||||||
sha256 = await scanner.calculate_hash_for_model(file_path)
|
sha256 = await scanner.calculate_hash_for_model(file_path)
|
||||||
if not sha256:
|
if not sha256:
|
||||||
return web.json_response(
|
return web.json_response(
|
||||||
{"success": False, "error": "Failed to calculate SHA256 hash"}, status=500
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": "Failed to calculate SHA256 hash",
|
||||||
|
},
|
||||||
|
status=500,
|
||||||
)
|
)
|
||||||
# Update model_data with new hash
|
# Update model_data with new hash
|
||||||
model_data["sha256"] = sha256
|
model_data["sha256"] = sha256
|
||||||
@@ -545,6 +547,153 @@ class ModelManagementHandler:
|
|||||||
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
|
self._logger.error("Error replacing preview: %s", exc, exc_info=True)
|
||||||
return web.Response(text=str(exc), status=500)
|
return web.Response(text=str(exc), status=500)
|
||||||
|
|
||||||
|
async def set_preview_from_url(self, request: web.Request) -> web.Response:
|
||||||
|
"""Set a preview image from a remote URL (e.g., CivitAI)."""
|
||||||
|
try:
|
||||||
|
from ...utils.civitai_utils import rewrite_preview_url
|
||||||
|
from ...services.downloader import get_downloader
|
||||||
|
|
||||||
|
data = await request.json()
|
||||||
|
model_path = data.get("model_path")
|
||||||
|
image_url = data.get("image_url")
|
||||||
|
nsfw_level = data.get("nsfw_level", 0)
|
||||||
|
|
||||||
|
if not model_path:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Model path is required"}, status=400
|
||||||
|
)
|
||||||
|
|
||||||
|
if not image_url:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Image URL is required"}, status=400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Rewrite URL to use optimized rendition if it's a Civitai URL
|
||||||
|
optimized_url, was_rewritten = rewrite_preview_url(
|
||||||
|
image_url, media_type="image"
|
||||||
|
)
|
||||||
|
if was_rewritten and optimized_url:
|
||||||
|
self._logger.info(
|
||||||
|
f"Rewritten preview URL to optimized version: {optimized_url}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
optimized_url = image_url
|
||||||
|
|
||||||
|
# Download the image using the Downloader service
|
||||||
|
self._logger.info(
|
||||||
|
f"Downloading preview from {optimized_url} for {model_path}"
|
||||||
|
)
|
||||||
|
downloader = await get_downloader()
|
||||||
|
success, preview_data, headers = await downloader.download_to_memory(
|
||||||
|
optimized_url, use_auth=False, return_headers=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if not success:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed to download image: {preview_data}",
|
||||||
|
},
|
||||||
|
status=502,
|
||||||
|
)
|
||||||
|
|
||||||
|
# preview_data is bytes when success is True
|
||||||
|
preview_bytes = (
|
||||||
|
preview_data
|
||||||
|
if isinstance(preview_data, bytes)
|
||||||
|
else preview_data.encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
# Determine content type from response headers
|
||||||
|
content_type = (
|
||||||
|
headers.get("Content-Type", "image/jpeg") if headers else "image/jpeg"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract original filename from URL
|
||||||
|
original_filename = None
|
||||||
|
if "?" in image_url:
|
||||||
|
url_path = image_url.split("?")[0]
|
||||||
|
else:
|
||||||
|
url_path = image_url
|
||||||
|
original_filename = url_path.split("/")[-1] if "/" in url_path else None
|
||||||
|
|
||||||
|
result = await self._preview_service.replace_preview(
|
||||||
|
model_path=model_path,
|
||||||
|
preview_data=preview_data,
|
||||||
|
content_type=content_type,
|
||||||
|
original_filename=original_filename,
|
||||||
|
nsfw_level=nsfw_level,
|
||||||
|
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
|
||||||
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"preview_url": config.get_preview_static_url(
|
||||||
|
result["preview_path"]
|
||||||
|
),
|
||||||
|
"preview_nsfw_level": result["preview_nsfw_level"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error setting preview from URL: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
|
if not image_url:
|
||||||
|
return web.json_response(
|
||||||
|
{"success": False, "error": "Image URL is required"}, status=400
|
||||||
|
)
|
||||||
|
|
||||||
|
# Download the image from the remote URL
|
||||||
|
self._logger.info(f"Downloading preview from {image_url} for {model_path}")
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
async with session.get(image_url) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": False,
|
||||||
|
"error": f"Failed to download image: HTTP {response.status}",
|
||||||
|
},
|
||||||
|
status=502,
|
||||||
|
)
|
||||||
|
|
||||||
|
content_type = response.headers.get("Content-Type", "image/jpeg")
|
||||||
|
preview_data = await response.read()
|
||||||
|
|
||||||
|
# Extract original filename from URL
|
||||||
|
original_filename = None
|
||||||
|
if "?" in image_url:
|
||||||
|
url_path = image_url.split("?")[0]
|
||||||
|
else:
|
||||||
|
url_path = image_url
|
||||||
|
original_filename = (
|
||||||
|
url_path.split("/")[-1] if "/" in url_path else None
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._preview_service.replace_preview(
|
||||||
|
model_path=model_path,
|
||||||
|
preview_data=preview_bytes,
|
||||||
|
content_type=content_type,
|
||||||
|
original_filename=original_filename,
|
||||||
|
nsfw_level=nsfw_level,
|
||||||
|
update_preview_in_cache=self._service.scanner.update_preview_in_cache,
|
||||||
|
metadata_loader=self._metadata_sync.load_local_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
return web.json_response(
|
||||||
|
{
|
||||||
|
"success": True,
|
||||||
|
"preview_url": config.get_preview_static_url(
|
||||||
|
result["preview_path"]
|
||||||
|
),
|
||||||
|
"preview_nsfw_level": result["preview_nsfw_level"],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
self._logger.error("Error setting preview from URL: %s", exc, exc_info=True)
|
||||||
|
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||||
|
|
||||||
async def save_metadata(self, request: web.Request) -> web.Response:
|
async def save_metadata(self, request: web.Request) -> web.Response:
|
||||||
try:
|
try:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
@@ -835,9 +984,7 @@ class ModelQueryHandler:
|
|||||||
# Format response
|
# Format response
|
||||||
group = {"hash": sha256, "models": []}
|
group = {"hash": sha256, "models": []}
|
||||||
for model in sorted_models:
|
for model in sorted_models:
|
||||||
group["models"].append(
|
group["models"].append(await self._service.format_response(model))
|
||||||
await self._service.format_response(model)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Only include groups with 2+ models after filtering
|
# Only include groups with 2+ models after filtering
|
||||||
if len(group["models"]) > 1:
|
if len(group["models"]) > 1:
|
||||||
@@ -866,7 +1013,9 @@ class ModelQueryHandler:
|
|||||||
"favorites_only": request.query.get("favorites_only", "").lower() == "true",
|
"favorites_only": request.query.get("favorites_only", "").lower() == "true",
|
||||||
}
|
}
|
||||||
|
|
||||||
def _apply_duplicate_filters(self, models: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
|
def _apply_duplicate_filters(
|
||||||
|
self, models: List[Dict[str, Any]], filters: Dict[str, Any]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Apply filters to a list of models within a duplicate group."""
|
"""Apply filters to a list of models within a duplicate group."""
|
||||||
result = models
|
result = models
|
||||||
|
|
||||||
@@ -907,7 +1056,9 @@ class ModelQueryHandler:
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def _sort_duplicate_group(self, models: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
|
def _sort_duplicate_group(
|
||||||
|
self, models: List[Dict[str, Any]]
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
"""Sort models: originals first (left), copies (with -????. pattern) last (right)."""
|
"""Sort models: originals first (left), copies (with -????. pattern) last (right)."""
|
||||||
if len(models) <= 1:
|
if len(models) <= 1:
|
||||||
return models
|
return models
|
||||||
@@ -1192,10 +1343,13 @@ class ModelDownloadHandler:
|
|||||||
data["source"] = source
|
data["source"] = source
|
||||||
if file_params_json:
|
if file_params_json:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
data["file_params"] = json.loads(file_params_json)
|
data["file_params"] = json.loads(file_params_json)
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
self._logger.warning("Invalid file_params JSON: %s", file_params_json)
|
self._logger.warning(
|
||||||
|
"Invalid file_params JSON: %s", file_params_json
|
||||||
|
)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
future = loop.create_future()
|
future = loop.create_future()
|
||||||
@@ -1926,7 +2080,8 @@ class ModelUpdateHandler:
|
|||||||
from dataclasses import replace
|
from dataclasses import replace
|
||||||
|
|
||||||
new_record = replace(
|
new_record = replace(
|
||||||
record, versions=list(version_map.values()),
|
record,
|
||||||
|
versions=list(version_map.values()),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optionally persist to database for caching
|
# Optionally persist to database for caching
|
||||||
@@ -2141,6 +2296,7 @@ class ModelUpdateHandler:
|
|||||||
if version.early_access_ends_at:
|
if version.early_access_ends_at:
|
||||||
try:
|
try:
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
ea_date = datetime.fromisoformat(
|
ea_date = datetime.fromisoformat(
|
||||||
version.early_access_ends_at.replace("Z", "+00:00")
|
version.early_access_ends_at.replace("Z", "+00:00")
|
||||||
)
|
)
|
||||||
@@ -2148,7 +2304,7 @@ class ModelUpdateHandler:
|
|||||||
except (ValueError, AttributeError):
|
except (ValueError, AttributeError):
|
||||||
# If date parsing fails, treat as active EA (conservative)
|
# If date parsing fails, treat as active EA (conservative)
|
||||||
is_early_access = True
|
is_early_access = True
|
||||||
elif getattr(version, 'is_early_access', False):
|
elif getattr(version, "is_early_access", False):
|
||||||
# Fallback to basic EA flag from bulk API
|
# Fallback to basic EA flag from bulk API
|
||||||
is_early_access = True
|
is_early_access = True
|
||||||
|
|
||||||
@@ -2228,6 +2384,7 @@ class ModelHandlerSet:
|
|||||||
"fetch_all_civitai": self.civitai.fetch_all_civitai,
|
"fetch_all_civitai": self.civitai.fetch_all_civitai,
|
||||||
"relink_civitai": self.management.relink_civitai,
|
"relink_civitai": self.management.relink_civitai,
|
||||||
"replace_preview": self.management.replace_preview,
|
"replace_preview": self.management.replace_preview,
|
||||||
|
"set_preview_from_url": self.management.set_preview_from_url,
|
||||||
"save_metadata": self.management.save_metadata,
|
"save_metadata": self.management.save_metadata,
|
||||||
"add_tags": self.management.add_tags,
|
"add_tags": self.management.add_tags,
|
||||||
"rename_model": self.management.rename_model,
|
"rename_model": self.management.rename_model,
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Route registrar for model endpoints."""
|
"""Route registrar for model endpoints."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@@ -27,6 +28,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
RouteDefinition("POST", "/api/lm/{prefix}/fetch-all-civitai", "fetch_all_civitai"),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
RouteDefinition("POST", "/api/lm/{prefix}/relink-civitai", "relink_civitai"),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
|
RouteDefinition("POST", "/api/lm/{prefix}/replace-preview", "replace_preview"),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST", "/api/lm/{prefix}/set-preview-from-url", "set_preview_from_url"
|
||||||
|
),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
|
RouteDefinition("POST", "/api/lm/{prefix}/save-metadata", "save_metadata"),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
|
RouteDefinition("POST", "/api/lm/{prefix}/add-tags", "add_tags"),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
|
RouteDefinition("POST", "/api/lm/{prefix}/rename", "rename_model"),
|
||||||
@@ -36,7 +40,9 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
|
RouteDefinition("POST", "/api/lm/{prefix}/move_models_bulk", "move_models_bulk"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
RouteDefinition("POST", "/api/lm/{prefix}/auto-organize", "auto_organize_models"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"),
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/auto-organize-progress", "get_auto_organize_progress"
|
||||||
|
),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
RouteDefinition("GET", "/api/lm/{prefix}/top-tags", "get_top_tags"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
RouteDefinition("GET", "/api/lm/{prefix}/base-models", "get_base_models"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
|
RouteDefinition("GET", "/api/lm/{prefix}/model-types", "get_model_types"),
|
||||||
@@ -44,30 +50,60 @@ COMMON_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
|||||||
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
RouteDefinition("GET", "/api/lm/{prefix}/roots", "get_model_roots"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
RouteDefinition("GET", "/api/lm/{prefix}/folders", "get_folders"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
|
RouteDefinition("GET", "/api/lm/{prefix}/folder-tree", "get_folder_tree"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"),
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/unified-folder-tree", "get_unified_folder_tree"
|
||||||
|
),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
|
RouteDefinition("GET", "/api/lm/{prefix}/find-duplicates", "find_duplicate_models"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"),
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/find-filename-conflicts", "find_filename_conflicts"
|
||||||
|
),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
|
RouteDefinition("GET", "/api/lm/{prefix}/get-notes", "get_model_notes"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
|
RouteDefinition("GET", "/api/lm/{prefix}/preview-url", "get_model_preview_url"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
|
RouteDefinition("GET", "/api/lm/{prefix}/civitai-url", "get_model_civitai_url"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
|
RouteDefinition("GET", "/api/lm/{prefix}/metadata", "get_model_metadata"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/model-description", "get_model_description"),
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/model-description", "get_model_description"
|
||||||
|
),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
|
RouteDefinition("GET", "/api/lm/{prefix}/relative-paths", "get_relative_paths"),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"),
|
RouteDefinition(
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/version/{modelVersionId}", "get_civitai_model_by_version"),
|
"GET", "/api/lm/{prefix}/civitai/versions/{model_id}", "get_civitai_versions"
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"),
|
),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"),
|
RouteDefinition(
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/fetch-missing-license", "fetch_missing_civitai_license_data"),
|
"GET",
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"),
|
"/api/lm/{prefix}/civitai/model/version/{modelVersionId}",
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"),
|
"get_civitai_model_by_version",
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"),
|
),
|
||||||
RouteDefinition("GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"),
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/civitai/model/hash/{hash}", "get_civitai_model_by_hash"
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST", "/api/lm/{prefix}/updates/refresh", "refresh_model_updates"
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST",
|
||||||
|
"/api/lm/{prefix}/updates/fetch-missing-license",
|
||||||
|
"fetch_missing_civitai_license_data",
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST", "/api/lm/{prefix}/updates/ignore", "set_model_update_ignore"
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"POST", "/api/lm/{prefix}/updates/ignore-version", "set_version_update_ignore"
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/updates/status/{model_id}", "get_model_update_status"
|
||||||
|
),
|
||||||
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/{prefix}/updates/versions/{model_id}", "get_model_versions"
|
||||||
|
),
|
||||||
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
RouteDefinition("POST", "/api/lm/download-model", "download_model"),
|
||||||
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
RouteDefinition("GET", "/api/lm/download-model-get", "download_model_get"),
|
||||||
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
RouteDefinition("GET", "/api/lm/cancel-download-get", "cancel_download_get"),
|
||||||
RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"),
|
RouteDefinition("GET", "/api/lm/pause-download", "pause_download_get"),
|
||||||
RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"),
|
RouteDefinition("GET", "/api/lm/resume-download", "resume_download_get"),
|
||||||
RouteDefinition("GET", "/api/lm/download-progress/{download_id}", "get_download_progress"),
|
RouteDefinition(
|
||||||
|
"GET", "/api/lm/download-progress/{download_id}", "get_download_progress"
|
||||||
|
),
|
||||||
RouteDefinition("POST", "/api/lm/{prefix}/cancel-task", "cancel_task"),
|
RouteDefinition("POST", "/api/lm/{prefix}/cancel-task", "cancel_task"),
|
||||||
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
RouteDefinition("GET", "/{prefix}", "handle_models_page"),
|
||||||
)
|
)
|
||||||
@@ -94,12 +130,18 @@ class ModelRouteRegistrar:
|
|||||||
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
|
definitions: Iterable[RouteDefinition] = COMMON_ROUTE_DEFINITIONS,
|
||||||
) -> None:
|
) -> None:
|
||||||
for definition in definitions:
|
for definition in definitions:
|
||||||
self._bind_route(definition.method, definition.build_path(prefix), handler_lookup[definition.handler_name])
|
self._bind_route(
|
||||||
|
definition.method,
|
||||||
|
definition.build_path(prefix),
|
||||||
|
handler_lookup[definition.handler_name],
|
||||||
|
)
|
||||||
|
|
||||||
def add_route(self, method: str, path: str, handler: Callable) -> None:
|
def add_route(self, method: str, path: str, handler: Callable) -> None:
|
||||||
self._bind_route(method, path, handler)
|
self._bind_route(method, path, handler)
|
||||||
|
|
||||||
def add_prefixed_route(self, method: str, path_template: str, prefix: str, handler: Callable) -> None:
|
def add_prefixed_route(
|
||||||
|
self, method: str, path_template: str, prefix: str, handler: Callable
|
||||||
|
) -> None:
|
||||||
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
|
self._bind_route(method, path_template.replace("{prefix}", prefix), handler)
|
||||||
|
|
||||||
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
def _bind_route(self, method: str, path: str, handler: Callable) -> None:
|
||||||
|
|||||||
@@ -86,6 +86,7 @@ export function getApiEndpoints(modelType) {
|
|||||||
|
|
||||||
// Preview management
|
// Preview management
|
||||||
replacePreview: `/api/lm/${modelType}/replace-preview`,
|
replacePreview: `/api/lm/${modelType}/replace-preview`,
|
||||||
|
setPreviewFromUrl: `/api/lm/${modelType}/set-preview-from-url`,
|
||||||
|
|
||||||
// Query operations
|
// Query operations
|
||||||
scan: `/api/lm/${modelType}/scan`,
|
scan: `/api/lm/${modelType}/scan`,
|
||||||
|
|||||||
@@ -307,6 +307,56 @@ export class BaseModelApiClient {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Set a preview from a remote URL (e.g., CivitAI)
|
||||||
|
* @param {string} filePath - Path to the model file
|
||||||
|
* @param {string} imageUrl - Remote image URL
|
||||||
|
* @param {number} nsfwLevel - NSFW level for the preview
|
||||||
|
*/
|
||||||
|
async setPreviewFromUrl(filePath, imageUrl, nsfwLevel = 0) {
|
||||||
|
try {
|
||||||
|
state.loadingManager.showSimpleLoading('Setting preview from URL...');
|
||||||
|
|
||||||
|
const response = await fetch(this.apiConfig.endpoints.setPreviewFromUrl, {
|
||||||
|
method: 'POST',
|
||||||
|
headers: { 'Content-Type': 'application/json' },
|
||||||
|
body: JSON.stringify({
|
||||||
|
model_path: filePath,
|
||||||
|
image_url: imageUrl,
|
||||||
|
nsfw_level: nsfwLevel
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error('Failed to set preview from URL');
|
||||||
|
}
|
||||||
|
|
||||||
|
const data = await response.json();
|
||||||
|
const pageState = this.getPageState();
|
||||||
|
|
||||||
|
const timestamp = Date.now();
|
||||||
|
if (pageState.previewVersions) {
|
||||||
|
pageState.previewVersions.set(filePath, timestamp);
|
||||||
|
|
||||||
|
const storageKey = `${this.modelType}_preview_versions`;
|
||||||
|
saveMapToStorage(storageKey, pageState.previewVersions);
|
||||||
|
}
|
||||||
|
|
||||||
|
const updateData = {
|
||||||
|
preview_url: data.preview_url,
|
||||||
|
preview_nsfw_level: data.preview_nsfw_level
|
||||||
|
};
|
||||||
|
|
||||||
|
state.virtualScroller.updateSingleItem(filePath, updateData);
|
||||||
|
showToast('toast.api.previewUpdated', {}, 'success');
|
||||||
|
} catch (error) {
|
||||||
|
console.error('Error setting preview from URL:', error);
|
||||||
|
showToast('toast.api.previewUploadFailed', {}, 'error');
|
||||||
|
} finally {
|
||||||
|
state.loadingManager.hide();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async saveModelMetadata(filePath, data) {
|
async saveModelMetadata(filePath, data) {
|
||||||
try {
|
try {
|
||||||
state.loadingManager.showSimpleLoading('Saving metadata...');
|
state.loadingManager.showSimpleLoading('Saving metadata...');
|
||||||
|
|||||||
@@ -527,17 +527,18 @@ function initSetPreviewHandlers(container) {
|
|||||||
const response = await fetch(mediaElement.dataset.localSrc);
|
const response = await fetch(mediaElement.dataset.localSrc);
|
||||||
const blob = await response.blob();
|
const blob = await response.blob();
|
||||||
const file = new File([blob], 'preview.jpg', { type: blob.type });
|
const file = new File([blob], 'preview.jpg', { type: blob.type });
|
||||||
|
|
||||||
// Use the existing baseModelApi uploadPreview method with nsfw level
|
// Use the existing baseModelApi uploadPreview method with nsfw level
|
||||||
await apiClient.uploadPreview(modelFilePath, file, modelType, nsfwLevel);
|
await apiClient.uploadPreview(modelFilePath, file, nsfwLevel);
|
||||||
} else {
|
} else {
|
||||||
// We need to download the remote file first
|
// Remote file - send URL to backend to download (avoids CORS issues)
|
||||||
const response = await fetch(mediaElement.src);
|
const imageUrl = mediaElement.src || mediaElement.dataset.remoteSrc;
|
||||||
const blob = await response.blob();
|
if (!imageUrl) {
|
||||||
const file = new File([blob], 'preview.jpg', { type: blob.type });
|
throw new Error('No image URL available');
|
||||||
|
}
|
||||||
// Use the existing baseModelApi uploadPreview method with nsfw level
|
|
||||||
await apiClient.uploadPreview(modelFilePath, file, modelType, nsfwLevel);
|
// Use the new setPreviewFromUrl method
|
||||||
|
await apiClient.setPreviewFromUrl(modelFilePath, imageUrl, nsfwLevel);
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
console.error('Error setting preview:', error);
|
console.error('Error setting preview:', error);
|
||||||
|
|||||||
@@ -31,7 +31,9 @@ from py.utils.metadata_manager import MetadataManager
|
|||||||
class DummyRoutes(BaseModelRoutes):
|
class DummyRoutes(BaseModelRoutes):
|
||||||
template_name = "dummy.html"
|
template_name = "dummy.html"
|
||||||
|
|
||||||
def setup_specific_routes(self, registrar, prefix: str) -> None: # pragma: no cover - no extra routes in smoke tests
|
def setup_specific_routes(
|
||||||
|
self, registrar, prefix: str
|
||||||
|
) -> None: # pragma: no cover - no extra routes in smoke tests
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def __init__(self, service=None):
|
def __init__(self, service=None):
|
||||||
@@ -59,7 +61,9 @@ class NullUpdateRecord:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def in_library_version_ids(self) -> list[int]:
|
def in_library_version_ids(self) -> list[int]:
|
||||||
return [version.version_id for version in self.versions if version.is_in_library]
|
return [
|
||||||
|
version.version_id for version in self.versions if version.is_in_library
|
||||||
|
]
|
||||||
|
|
||||||
def has_update(self) -> bool:
|
def has_update(self) -> bool:
|
||||||
return False
|
return False
|
||||||
@@ -86,7 +90,9 @@ class NullModelUpdateService:
|
|||||||
)
|
)
|
||||||
for version_id in version_ids
|
for version_id in version_ids
|
||||||
]
|
]
|
||||||
return NullUpdateRecord(model_type=model_type, model_id=model_id, versions=versions)
|
return NullUpdateRecord(
|
||||||
|
model_type=model_type, model_id=model_id, versions=versions
|
||||||
|
)
|
||||||
|
|
||||||
async def set_should_ignore(self, model_type, model_id, should_ignore):
|
async def set_should_ignore(self, model_type, model_id, should_ignore):
|
||||||
return NullUpdateRecord(
|
return NullUpdateRecord(
|
||||||
@@ -95,7 +101,9 @@ class NullModelUpdateService:
|
|||||||
should_ignore_model=should_ignore,
|
should_ignore_model=should_ignore,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def set_version_should_ignore(self, model_type, model_id, version_id, should_ignore):
|
async def set_version_should_ignore(
|
||||||
|
self, model_type, model_id, version_id, should_ignore
|
||||||
|
):
|
||||||
return await self.set_should_ignore(model_type, model_id, should_ignore)
|
return await self.set_should_ignore(model_type, model_id, should_ignore)
|
||||||
|
|
||||||
async def get_record(self, *args, **kwargs):
|
async def get_record(self, *args, **kwargs):
|
||||||
@@ -167,7 +175,9 @@ def download_manager_stub():
|
|||||||
|
|
||||||
|
|
||||||
def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
||||||
mock_service.paginated_items = [{"file_path": "/tmp/demo.safetensors", "name": "Demo"}]
|
mock_service.paginated_items = [
|
||||||
|
{"file_path": "/tmp/demo.safetensors", "name": "Demo"}
|
||||||
|
]
|
||||||
|
|
||||||
async def scenario():
|
async def scenario():
|
||||||
client = await create_test_client(mock_service)
|
client = await create_test_client(mock_service)
|
||||||
@@ -176,7 +186,13 @@ def test_list_models_returns_formatted_items(mock_service, mock_scanner):
|
|||||||
payload = await response.json()
|
payload = await response.json()
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
assert payload["items"] == [{"file_path": "/tmp/demo.safetensors", "name": "Demo", "formatted": True}]
|
assert payload["items"] == [
|
||||||
|
{
|
||||||
|
"file_path": "/tmp/demo.safetensors",
|
||||||
|
"name": "Demo",
|
||||||
|
"formatted": True,
|
||||||
|
}
|
||||||
|
]
|
||||||
assert payload["total"] == 1
|
assert payload["total"] == 1
|
||||||
assert mock_service.formatted == payload["items"]
|
assert mock_service.formatted == payload["items"]
|
||||||
finally:
|
finally:
|
||||||
@@ -220,7 +236,9 @@ def test_routes_return_service_not_ready_when_unattached():
|
|||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_delete_model_updates_cache_and_hash_index(mock_service, mock_scanner, tmp_path: Path):
|
def test_delete_model_updates_cache_and_hash_index(
|
||||||
|
mock_service, mock_scanner, tmp_path: Path
|
||||||
|
):
|
||||||
model_path = tmp_path / "sample.safetensors"
|
model_path = tmp_path / "sample.safetensors"
|
||||||
model_path.write_bytes(b"model")
|
model_path.write_bytes(b"model")
|
||||||
mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]
|
mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]
|
||||||
@@ -271,17 +289,23 @@ def test_replace_preview_writes_file_and_updates_cache(
|
|||||||
)
|
)
|
||||||
|
|
||||||
form = FormData()
|
form = FormData()
|
||||||
form.add_field("preview_file", b"binary-data", filename="preview.png", content_type="image/png")
|
form.add_field(
|
||||||
|
"preview_file", b"binary-data", filename="preview.png", content_type="image/png"
|
||||||
|
)
|
||||||
form.add_field("model_path", str(model_path))
|
form.add_field("model_path", str(model_path))
|
||||||
form.add_field("nsfw_level", "2")
|
form.add_field("nsfw_level", "2")
|
||||||
|
|
||||||
async def scenario():
|
async def scenario():
|
||||||
client = await create_test_client(mock_service)
|
client = await create_test_client(mock_service)
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/lm/test-models/replace-preview", data=form)
|
response = await client.post(
|
||||||
|
"/api/lm/test-models/replace-preview", data=form
|
||||||
|
)
|
||||||
payload = await response.json()
|
payload = await response.json()
|
||||||
|
|
||||||
expected_preview = str((tmp_path / "preview-model.webp")).replace(os.sep, "/")
|
expected_preview = str((tmp_path / "preview-model.webp")).replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
assert payload["success"] is True
|
assert payload["success"] is True
|
||||||
assert payload["preview_url"] == "/static/preview-model.webp"
|
assert payload["preview_url"] == "/static/preview-model.webp"
|
||||||
@@ -299,6 +323,66 @@ def test_replace_preview_writes_file_and_updates_cache(
|
|||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_preview_from_url_downloads_and_updates_cache(
|
||||||
|
mock_service,
|
||||||
|
mock_scanner,
|
||||||
|
monkeypatch: pytest.MonkeyPatch,
|
||||||
|
tmp_path: Path,
|
||||||
|
):
|
||||||
|
"""Test that set_preview_from_url endpoint downloads remote images and sets them as preview."""
|
||||||
|
model_path = tmp_path / "url-preview-model.safetensors"
|
||||||
|
model_path.write_bytes(b"model")
|
||||||
|
metadata_path = tmp_path / "url-preview-model.metadata.json"
|
||||||
|
metadata_path.write_text(json.dumps({"file_path": str(model_path)}))
|
||||||
|
|
||||||
|
mock_scanner._cache.raw_data = [{"file_path": str(model_path)}]
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
config,
|
||||||
|
"get_preview_static_url",
|
||||||
|
lambda preview_path: f"/static/{Path(preview_path).name}",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def scenario():
|
||||||
|
client = await create_test_client(mock_service)
|
||||||
|
try:
|
||||||
|
# Mock the Downloader to return a test image
|
||||||
|
from py.services import downloader
|
||||||
|
|
||||||
|
class FakeDownloader:
|
||||||
|
async def download_to_memory(
|
||||||
|
self, url, use_auth=False, return_headers=True
|
||||||
|
):
|
||||||
|
return True, b"fake-image-data", {"Content-Type": "image/jpeg"}
|
||||||
|
|
||||||
|
async def fake_get_downloader():
|
||||||
|
return FakeDownloader()
|
||||||
|
|
||||||
|
monkeypatch.setattr(downloader, "get_downloader", fake_get_downloader)
|
||||||
|
|
||||||
|
response = await client.post(
|
||||||
|
"/api/lm/test-models/set-preview-from-url",
|
||||||
|
json={
|
||||||
|
"model_path": str(model_path),
|
||||||
|
"image_url": "https://example.com/image.jpg",
|
||||||
|
"nsfw_level": 3,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
payload = await response.json()
|
||||||
|
|
||||||
|
expected_preview = str((tmp_path / "url-preview-model.webp")).replace(
|
||||||
|
os.sep, "/"
|
||||||
|
)
|
||||||
|
assert response.status == 200
|
||||||
|
assert payload["success"] is True
|
||||||
|
assert payload["preview_url"] == "/static/url-preview-model.webp"
|
||||||
|
assert Path(expected_preview).exists()
|
||||||
|
finally:
|
||||||
|
await client.close()
|
||||||
|
|
||||||
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_fetch_civitai_hydrates_metadata_before_sync(
|
def test_fetch_civitai_hydrates_metadata_before_sync(
|
||||||
mock_service,
|
mock_service,
|
||||||
mock_scanner,
|
mock_scanner,
|
||||||
@@ -370,9 +454,15 @@ def test_fetch_civitai_hydrates_metadata_before_sync(
|
|||||||
save_calls: list[tuple[str, dict]] = []
|
save_calls: list[tuple[str, dict]] = []
|
||||||
captured: dict[str, dict] = {}
|
captured: dict[str, dict] = {}
|
||||||
|
|
||||||
monkeypatch.setattr(MetadataManager, "load_metadata", staticmethod(fake_load_metadata))
|
monkeypatch.setattr(
|
||||||
monkeypatch.setattr(MetadataManager, "save_metadata", staticmethod(fake_save_metadata))
|
MetadataManager, "load_metadata", staticmethod(fake_load_metadata)
|
||||||
monkeypatch.setattr(MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model)
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
MetadataManager, "save_metadata", staticmethod(fake_save_metadata)
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
MetadataSyncService, "fetch_and_update_model", fake_fetch_and_update_model
|
||||||
|
)
|
||||||
|
|
||||||
async def scenario():
|
async def scenario():
|
||||||
client = await create_test_client(mock_service)
|
client = await create_test_client(mock_service)
|
||||||
@@ -386,7 +476,10 @@ def test_fetch_civitai_hydrates_metadata_before_sync(
|
|||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
assert payload["success"] is True
|
assert payload["success"] is True
|
||||||
assert captured["model_data"]["custom_field"] == "preserve"
|
assert captured["model_data"]["custom_field"] == "preserve"
|
||||||
assert captured["model_data"]["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
assert (
|
||||||
|
captured["model_data"]["civitai"]["images"][0]["url"]
|
||||||
|
== "https://example.com/existing.png"
|
||||||
|
)
|
||||||
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
|
assert captured["model_data"]["civitai"]["trainedWords"] == ["keep"]
|
||||||
assert captured["model_data"]["civitai"]["id"] == 99
|
assert captured["model_data"]["civitai"]["id"] == 99
|
||||||
finally:
|
finally:
|
||||||
@@ -398,7 +491,10 @@ def test_fetch_civitai_hydrates_metadata_before_sync(
|
|||||||
saved_path, saved_payload = save_calls[0]
|
saved_path, saved_payload = save_calls[0]
|
||||||
assert saved_path == str(metadata_path)
|
assert saved_path == str(metadata_path)
|
||||||
assert saved_payload["custom_field"] == "preserve"
|
assert saved_payload["custom_field"] == "preserve"
|
||||||
assert saved_payload["civitai"]["images"][0]["url"] == "https://example.com/existing.png"
|
assert (
|
||||||
|
saved_payload["civitai"]["images"][0]["url"]
|
||||||
|
== "https://example.com/existing.png"
|
||||||
|
)
|
||||||
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
|
assert saved_payload["civitai"]["trainedWords"] == ["keep"]
|
||||||
assert saved_payload["civitai"]["id"] == 99
|
assert saved_payload["civitai"]["id"] == 99
|
||||||
assert saved_payload["legacy_field"] == "legacy"
|
assert saved_payload["legacy_field"] == "legacy"
|
||||||
@@ -432,11 +528,22 @@ def test_download_model_invokes_download_manager(
|
|||||||
assert call_args["download_id"] == payload["download_id"]
|
assert call_args["download_id"] == payload["download_id"]
|
||||||
progress = ws_manager.get_download_progress(payload["download_id"])
|
progress = ws_manager.get_download_progress(payload["download_id"])
|
||||||
assert progress is not None
|
assert progress is not None
|
||||||
expected_progress = round(download_manager_stub.last_progress_snapshot.percent_complete)
|
expected_progress = round(
|
||||||
|
download_manager_stub.last_progress_snapshot.percent_complete
|
||||||
|
)
|
||||||
assert progress["progress"] == expected_progress
|
assert progress["progress"] == expected_progress
|
||||||
assert progress["bytes_downloaded"] == download_manager_stub.last_progress_snapshot.bytes_downloaded
|
assert (
|
||||||
assert progress["total_bytes"] == download_manager_stub.last_progress_snapshot.total_bytes
|
progress["bytes_downloaded"]
|
||||||
assert progress["bytes_per_second"] == download_manager_stub.last_progress_snapshot.bytes_per_second
|
== download_manager_stub.last_progress_snapshot.bytes_downloaded
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
progress["total_bytes"]
|
||||||
|
== download_manager_stub.last_progress_snapshot.total_bytes
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
progress["bytes_per_second"]
|
||||||
|
== download_manager_stub.last_progress_snapshot.bytes_per_second
|
||||||
|
)
|
||||||
assert "timestamp" in progress
|
assert "timestamp" in progress
|
||||||
|
|
||||||
progress_response = await client.get(
|
progress_response = await client.get(
|
||||||
@@ -526,21 +633,30 @@ def test_auto_organize_progress_returns_latest_snapshot(mock_service):
|
|||||||
async def scenario():
|
async def scenario():
|
||||||
client = await create_test_client(mock_service)
|
client = await create_test_client(mock_service)
|
||||||
try:
|
try:
|
||||||
await ws_manager.broadcast_auto_organize_progress({"status": "processing", "percent": 50})
|
await ws_manager.broadcast_auto_organize_progress(
|
||||||
|
{"status": "processing", "percent": 50}
|
||||||
|
)
|
||||||
|
|
||||||
response = await client.get("/api/lm/test-models/auto-organize-progress")
|
response = await client.get("/api/lm/test-models/auto-organize-progress")
|
||||||
payload = await response.json()
|
payload = await response.json()
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
assert payload == {"success": True, "progress": {"status": "processing", "percent": 50}}
|
assert payload == {
|
||||||
|
"success": True,
|
||||||
|
"progress": {"status": "processing", "percent": 50},
|
||||||
|
}
|
||||||
finally:
|
finally:
|
||||||
await client.close()
|
await client.close()
|
||||||
|
|
||||||
asyncio.run(scenario())
|
asyncio.run(scenario())
|
||||||
|
|
||||||
|
|
||||||
def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.MonkeyPatch):
|
def test_auto_organize_route_emits_progress(
|
||||||
async def fake_auto_organize(self, file_paths=None, progress_callback=None, exclusion_patterns=None):
|
mock_service, monkeypatch: pytest.MonkeyPatch
|
||||||
|
):
|
||||||
|
async def fake_auto_organize(
|
||||||
|
self, file_paths=None, progress_callback=None, exclusion_patterns=None
|
||||||
|
):
|
||||||
result = AutoOrganizeResult()
|
result = AutoOrganizeResult()
|
||||||
result.total = 1
|
result.total = 1
|
||||||
result.processed = 1
|
result.processed = 1
|
||||||
@@ -549,8 +665,12 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo
|
|||||||
result.failure_count = 0
|
result.failure_count = 0
|
||||||
result.operation_type = "bulk"
|
result.operation_type = "bulk"
|
||||||
if progress_callback is not None:
|
if progress_callback is not None:
|
||||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "started"})
|
await progress_callback.on_progress(
|
||||||
await progress_callback.on_progress({"type": "auto_organize_progress", "status": "completed"})
|
{"type": "auto_organize_progress", "status": "started"}
|
||||||
|
)
|
||||||
|
await progress_callback.on_progress(
|
||||||
|
{"type": "auto_organize_progress", "status": "completed"}
|
||||||
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
@@ -562,7 +682,9 @@ def test_auto_organize_route_emits_progress(mock_service, monkeypatch: pytest.Mo
|
|||||||
async def scenario():
|
async def scenario():
|
||||||
client = await create_test_client(mock_service)
|
client = await create_test_client(mock_service)
|
||||||
try:
|
try:
|
||||||
response = await client.post("/api/lm/test-models/auto-organize", json={"file_paths": []})
|
response = await client.post(
|
||||||
|
"/api/lm/test-models/auto-organize", json={"file_paths": []}
|
||||||
|
)
|
||||||
payload = await response.json()
|
payload = await response.json()
|
||||||
|
|
||||||
assert response.status == 200
|
assert response.status == 200
|
||||||
|
|||||||
Reference in New Issue
Block a user