mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-04-10 12:52:15 -03:00
fix(backup): add user-state backup UI and storage
This commit is contained in:
@@ -9,11 +9,14 @@ objects that can be composed by the route controller.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
from dataclasses import dataclass
|
||||
from typing import Awaitable, Callable, Dict, Mapping, Protocol
|
||||
|
||||
@@ -130,6 +133,22 @@ class MetadataArchiveManagerProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class BackupServiceProtocol(Protocol):
    """Structural interface for the user-state backup service.

    Route handlers depend on this protocol rather than a concrete
    implementation so the service can be swapped or stubbed in tests.
    """

    async def create_snapshot(
        self, *, snapshot_type: str = "manual", persist: bool = False
    ) -> dict: # pragma: no cover - protocol
        ...

    async def restore_snapshot(self, archive_path: str) -> dict: # pragma: no cover - protocol
        ...

    def get_status(self) -> dict: # pragma: no cover - protocol
        ...

    def get_available_snapshots(self) -> list[dict]: # pragma: no cover - protocol
        ...
|
||||
|
||||
|
||||
class NodeRegistry:
|
||||
"""Thread-safe registry for tracking LoRA nodes in active workflows."""
|
||||
|
||||
@@ -746,12 +765,17 @@ class ModelExampleFilesHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
async def _noop_backup_service() -> None:
    """Default factory used when no backup service has been registered."""
    return None
|
||||
|
||||
|
||||
@dataclass
class ServiceRegistryAdapter:
    """Bundle of async service factories injected into route handlers.

    Each attribute is a zero-argument callable returning an awaitable that
    resolves to the corresponding service instance.
    """

    get_lora_scanner: Callable[[], Awaitable]
    get_checkpoint_scanner: Callable[[], Awaitable]
    get_embedding_scanner: Callable[[], Awaitable]
    get_downloaded_version_history_service: Callable[[], Awaitable]
    # Defaults to a no-op factory so existing call sites that do not supply
    # a backup service keep working unchanged.
    get_backup_service: Callable[[], Awaitable] = _noop_backup_service
|
||||
|
||||
|
||||
class ModelLibraryHandler:
|
||||
@@ -1418,10 +1442,150 @@ class MetadataArchiveHandler:
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class BackupHandler:
    """Handler for user-state backup export/import."""

    def __init__(
        self,
        *,
        backup_service_factory: Callable[[], Awaitable[BackupServiceProtocol]] = ServiceRegistry.get_backup_service,
    ) -> None:
        # The factory is awaited per request so the service can be created lazily.
        self._backup_service_factory = backup_service_factory

    async def get_backup_status(self, request: web.Request) -> web.Response:
        """Report the backup service status plus any persisted snapshots."""
        try:
            service = await self._backup_service_factory()
            payload = {
                "success": True,
                "status": service.get_status(),
                "snapshots": service.get_available_snapshots(),
            }
            return web.json_response(payload)
        except Exception as exc: # pragma: no cover - defensive logging
            logger.error("Error getting backup status: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    async def export_backup(self, request: web.Request) -> web.Response:
        """Create a fresh manual snapshot and stream the archive to the client."""
        try:
            service = await self._backup_service_factory()
            snapshot = await service.create_snapshot(snapshot_type="manual", persist=False)
            disposition = f'attachment; filename="{snapshot["archive_name"]}"'
            return web.Response(
                body=snapshot["archive_bytes"],
                headers={
                    "Content-Type": "application/zip",
                    "Content-Disposition": disposition,
                },
            )
        except Exception as exc: # pragma: no cover - defensive logging
            logger.error("Error exporting backup: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)

    @staticmethod
    async def _receive_archive(request: web.Request, destination: str) -> bool:
        """Spool the uploaded archive to *destination*; return True if one arrived.

        Accepts either a multipart form (first file field wins) or a raw body.
        """
        if request.content_type.startswith("multipart/"):
            reader = await request.multipart()
            part = await reader.next()
            while part is not None:
                if getattr(part, "filename", None):
                    with open(destination, "wb") as handle:
                        while chunk := await part.read_chunk():
                            handle.write(chunk)
                    return True
                part = await reader.next()
            return False
        body = await request.read()
        if not body:
            return False
        with open(destination, "wb") as handle:
            handle.write(body)
        return True

    async def import_backup(self, request: web.Request) -> web.Response:
        """Restore user state from an uploaded snapshot archive."""
        temp_path: str | None = None
        try:
            # Spool the upload to disk so the service can open it as a zip file.
            fd, temp_path = tempfile.mkstemp(
                suffix=".zip", prefix="lora-manager-backup-"
            )
            os.close(fd)

            if not await self._receive_archive(request, temp_path):
                return web.json_response(
                    {"success": False, "error": "Missing backup archive"},
                    status=400,
                )

            # Defensive re-check: an empty spool file is treated as missing.
            if not temp_path or not os.path.exists(temp_path) or os.path.getsize(temp_path) == 0:
                return web.json_response(
                    {"success": False, "error": "Missing backup archive"},
                    status=400,
                )

            service = await self._backup_service_factory()
            result = await service.restore_snapshot(temp_path)
            return web.json_response({"success": True, **result})
        except (ValueError, zipfile.BadZipFile) as exc:
            # Malformed/invalid archives are a client error, not a server fault.
            logger.error("Error importing backup: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=400)
        except Exception as exc: # pragma: no cover - defensive logging
            logger.error("Error importing backup: %s", exc, exc_info=True)
            return web.json_response({"success": False, "error": str(exc)}, status=500)
        finally:
            # Always remove the spool file, even on success or early return.
            if temp_path and os.path.exists(temp_path):
                with contextlib.suppress(OSError):
                    os.remove(temp_path)
|
||||
|
||||
|
||||
class FileSystemHandler:
|
||||
    def __init__(self, settings_service=None) -> None:
        # Fall back to the global settings manager when none is injected.
        self._settings = settings_service or get_settings_manager()
|
||||
|
||||
    async def _open_path(self, path: str) -> web.Response:
        """Open *path* in the host's file browser, adapting to the platform.

        Returns a JSON response describing the outcome. In Docker the path is
        returned for manual copying ("clipboard" mode) instead of being opened.
        """
        path = os.path.abspath(path)
        if not os.path.isdir(path):
            return web.json_response(
                {"success": False, "error": "Folder does not exist"},
                status=404,
            )

        if os.name == "nt":
            # Native Windows: hand the folder to Explorer.
            subprocess.Popen(["explorer", path])
        elif os.name == "posix":
            if _is_docker():
                # No host file browser inside a container; let the client copy the path.
                return web.json_response(
                    {
                        "success": True,
                        "message": "Running in Docker: Path available for copying",
                        "path": path,
                        "mode": "clipboard",
                    }
                )
            if _is_wsl():
                # WSL: translate to a Windows path and use Windows Explorer.
                windows_path = _wsl_to_windows_path(path)
                if windows_path:
                    subprocess.Popen(["explorer.exe", windows_path])
                else:
                    logger.error(
                        "Failed to convert WSL path to Windows path: %s", path
                    )
                    return web.json_response(
                        {
                            "success": False,
                            "error": "Failed to open folder location: path conversion error",
                        },
                        status=500,
                    )
            # NOTE(review): indentation was lost in this extract — the darwin
            # branch is reconstructed as chained to the WSL check inside the
            # posix branch (macOS reports os.name == "posix"); confirm upstream.
            elif sys.platform == "darwin":
                subprocess.Popen(["open", path])
            else:
                subprocess.Popen(["xdg-open", path])

        return web.json_response(
            {"success": True, "message": f"Opened folder: {path}", "path": path}
        )
|
||||
|
||||
async def open_file_location(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
data = await request.json()
|
||||
@@ -1536,6 +1700,20 @@ class FileSystemHandler:
|
||||
logger.error("Failed to open settings location: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
async def open_backup_location(self, request: web.Request) -> web.Response:
|
||||
try:
|
||||
settings_file = getattr(self._settings, "settings_file", None)
|
||||
if not settings_file:
|
||||
return web.json_response(
|
||||
{"success": False, "error": "Settings file not found"}, status=404
|
||||
)
|
||||
|
||||
backup_dir = os.path.join(os.path.dirname(os.path.abspath(settings_file)), "backups")
|
||||
return await self._open_path(backup_dir)
|
||||
except Exception as exc: # pragma: no cover - defensive logging
|
||||
logger.error("Failed to open backup location: %s", exc, exc_info=True)
|
||||
return web.json_response({"success": False, "error": str(exc)}, status=500)
|
||||
|
||||
|
||||
class CustomWordsHandler:
|
||||
"""Handler for autocomplete via TagFTSIndex."""
|
||||
@@ -1840,6 +2018,7 @@ class MiscHandlerSet:
|
||||
node_registry: NodeRegistryHandler,
|
||||
model_library: ModelLibraryHandler,
|
||||
metadata_archive: MetadataArchiveHandler,
|
||||
backup: BackupHandler,
|
||||
filesystem: FileSystemHandler,
|
||||
custom_words: CustomWordsHandler,
|
||||
supporters: SupportersHandler,
|
||||
@@ -1855,6 +2034,7 @@ class MiscHandlerSet:
|
||||
self.node_registry = node_registry
|
||||
self.model_library = model_library
|
||||
self.metadata_archive = metadata_archive
|
||||
self.backup = backup
|
||||
self.filesystem = filesystem
|
||||
self.custom_words = custom_words
|
||||
self.supporters = supporters
|
||||
@@ -1886,9 +2066,13 @@ class MiscHandlerSet:
|
||||
"download_metadata_archive": self.metadata_archive.download_metadata_archive,
|
||||
"remove_metadata_archive": self.metadata_archive.remove_metadata_archive,
|
||||
"get_metadata_archive_status": self.metadata_archive.get_metadata_archive_status,
|
||||
"get_backup_status": self.backup.get_backup_status,
|
||||
"export_backup": self.backup.export_backup,
|
||||
"import_backup": self.backup.import_backup,
|
||||
"get_model_versions_status": self.model_library.get_model_versions_status,
|
||||
"open_file_location": self.filesystem.open_file_location,
|
||||
"open_settings_location": self.filesystem.open_settings_location,
|
||||
"open_backup_location": self.filesystem.open_backup_location,
|
||||
"search_custom_words": self.custom_words.search_custom_words,
|
||||
"get_supporters": self.supporters.get_supporters,
|
||||
"get_example_workflows": self.example_workflows.get_example_workflows,
|
||||
@@ -1907,4 +2091,5 @@ def build_service_registry_adapter() -> ServiceRegistryAdapter:
|
||||
get_checkpoint_scanner=ServiceRegistry.get_checkpoint_scanner,
|
||||
get_embedding_scanner=ServiceRegistry.get_embedding_scanner,
|
||||
get_downloaded_version_history_service=ServiceRegistry.get_downloaded_version_history_service,
|
||||
get_backup_service=ServiceRegistry.get_backup_service,
|
||||
)
|
||||
|
||||
@@ -62,6 +62,10 @@ MISC_ROUTE_DEFINITIONS: tuple[RouteDefinition, ...] = (
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/metadata-archive-status", "get_metadata_archive_status"
|
||||
),
|
||||
RouteDefinition("GET", "/api/lm/backup/status", "get_backup_status"),
|
||||
RouteDefinition("POST", "/api/lm/backup/export", "export_backup"),
|
||||
RouteDefinition("POST", "/api/lm/backup/import", "import_backup"),
|
||||
RouteDefinition("POST", "/api/lm/backup/open-location", "open_backup_location"),
|
||||
RouteDefinition(
|
||||
"GET", "/api/lm/model-versions-status", "get_model_versions_status"
|
||||
),
|
||||
|
||||
@@ -23,6 +23,7 @@ from .handlers.misc_handlers import (
|
||||
FileSystemHandler,
|
||||
HealthCheckHandler,
|
||||
LoraCodeHandler,
|
||||
BackupHandler,
|
||||
MetadataArchiveHandler,
|
||||
MiscHandlerSet,
|
||||
ModelExampleFilesHandler,
|
||||
@@ -116,6 +117,7 @@ class MiscRoutes:
|
||||
settings_service=self._settings,
|
||||
metadata_provider_updater=self._metadata_provider_updater,
|
||||
)
|
||||
backup = BackupHandler()
|
||||
filesystem = FileSystemHandler(settings_service=self._settings)
|
||||
node_registry_handler = NodeRegistryHandler(
|
||||
node_registry=self._node_registry,
|
||||
@@ -141,6 +143,7 @@ class MiscRoutes:
|
||||
node_registry=node_registry_handler,
|
||||
model_library=model_library,
|
||||
metadata_archive=metadata_archive,
|
||||
backup=backup,
|
||||
filesystem=filesystem,
|
||||
custom_words=custom_words,
|
||||
supporters=supporters,
|
||||
|
||||
Reference in New Issue
Block a user