mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
Add hard reset and clean operations to ensure a clean working directory before switching branches or checking out release tags. This prevents local changes from interfering with the update process and ensures consistent behavior across both nightly and release update paths.
555 lines
21 KiB
Python
555 lines
21 KiB
Python
import os
|
|
import logging
|
|
import toml
|
|
import git
|
|
import zipfile
|
|
import shutil
|
|
import tempfile
|
|
import asyncio
|
|
from aiohttp import web, ClientError
|
|
from typing import Dict, List
|
|
|
|
from ..utils.settings_paths import ensure_settings_file
|
|
from ..services.downloader import get_downloader
|
|
|
|
logger = logging.getLogger(__name__)

# Exception types that the handlers below treat as "network unavailable":
# they are caught separately and logged as a warning without a traceback,
# while any other exception is logged as an error.
NETWORK_EXCEPTIONS = (
    ClientError,
    OSError,
    asyncio.TimeoutError,
)
|
|
|
|
|
|
class UpdateRoutes:
    """Routes for checking for and applying plugin updates.

    All handlers are static methods registered by :meth:`setup_routes`.
    Release and commit information is read from the GitHub API. The local
    install is updated with Git when ``.git`` exists in the plugin root,
    otherwise by replacing files from the latest release ZIP.
    """

    # GitHub repository that is queried for releases/commits and downloaded from.
    _REPO_OWNER = "willmiao"
    _REPO_NAME = "ComfyUI-Lora-Manager"

    # Top-level entries of the plugin folder that a ZIP update never deletes,
    # overwrites, or lists in ``.tracking`` (user settings and local data).
    _ZIP_PRESERVED_ITEMS = ('settings.json', 'civitai', 'model_cache')

    @staticmethod
    def setup_routes(app):
        """Register the update-related routes on the aiohttp *app*."""
        app.router.add_get('/api/lm/check-updates', UpdateRoutes.check_updates)
        app.router.add_get('/api/lm/version-info', UpdateRoutes.get_version_info)
        app.router.add_post('/api/lm/perform-update', UpdateRoutes.perform_update)

    @staticmethod
    def _github_api_url(path: str) -> str:
        """Return the GitHub API URL for *path* under this plugin's repository."""
        return (
            f"https://api.github.com/repos/"
            f"{UpdateRoutes._REPO_OWNER}/{UpdateRoutes._REPO_NAME}/{path}"
        )

    @staticmethod
    def _get_plugin_root() -> str:
        """Return the plugin root directory (two levels above this file's directory)."""
        current_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.dirname(os.path.dirname(current_dir))

    @staticmethod
    async def check_updates(request):
        """Compare the installed version with the latest one on GitHub.

        Query parameters:
            nightly: ``"true"`` compares the local commit with the head of
                ``main``. Any other value compares the ``pyproject.toml``
                version with the latest GitHub release.

        Returns:
            On success, a JSON response with ``success``, ``current_version``,
            ``latest_version``, ``update_available``, ``changelog``,
            ``git_info`` and ``nightly``. On failure, ``success: False`` and
            ``error``.
        """
        try:
            nightly = request.query.get('nightly', 'false').lower() == 'true'

            # Read local version from pyproject.toml
            local_version = UpdateRoutes._get_local_version()

            # Get git info (commit hash, branch)
            git_info = UpdateRoutes._get_git_info()

            if nightly:
                remote_version, changelog = await UpdateRoutes._get_nightly_version()
                # Nightly builds are identified by commit hash, not version number.
                update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
            else:
                remote_version, changelog = await UpdateRoutes._get_remote_version()
                # Stable builds are compared as semantic versions.
                update_available = UpdateRoutes._compare_versions(
                    local_version.replace('v', ''),
                    remote_version.replace('v', '')
                )

            return web.json_response({
                'success': True,
                'current_version': local_version,
                'latest_version': remote_version,
                'update_available': update_available,
                'changelog': changelog,
                'git_info': git_info,
                'nightly': nightly
            })

        except NETWORK_EXCEPTIONS as e:
            logger.warning("Network unavailable during update check: %s", e)
            return web.json_response({
                'success': False,
                'error': 'Network unavailable for update check'
            })
        except Exception as e:
            logger.error(f"Failed to check for updates: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            })

    @staticmethod
    async def get_version_info(request):
        """Return the current version formatted as ``'<version>-<short_hash>'``.

        ``short_hash`` is ``'stable'`` when the plugin is not a Git checkout
        (see :meth:`_get_git_info`).
        """
        try:
            local_version = UpdateRoutes._get_local_version().replace('v', '')
            short_hash = UpdateRoutes._get_git_info()['short_hash']

            return web.json_response({
                'success': True,
                'version': f"{local_version}-{short_hash}"
            })

        except Exception as e:
            logger.error(f"Failed to get version info: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            })

    @staticmethod
    async def perform_update(request):
        """Update the plugin to the latest release tag or to ``main``.

        Uses Git when the plugin root contains ``.git``, and otherwise falls
        back to downloading the latest release ZIP. The settings file is
        backed up first and written back afterwards, whether or not the
        update succeeded.

        JSON body (optional):
            nightly: truthy to update to ``main`` instead of the latest
                release tag. This only applies to Git installs.
        """
        settings_path = None
        settings_backup = None
        try:
            body = await request.json() if request.can_read_body else {}
            if not isinstance(body, dict):
                body = {}
            nightly = body.get('nightly', False)

            plugin_root = UpdateRoutes._get_plugin_root()

            settings_path = ensure_settings_file(logger)
            if os.path.exists(settings_path):
                with open(settings_path, 'r', encoding='utf-8') as f:
                    settings_backup = f.read()
                logger.info("Backed up settings.json")

            if os.path.exists(os.path.join(plugin_root, '.git')):
                success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
            else:
                # Fallback: download the release ZIP and replace files.
                success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)

        except Exception as e:
            logger.error(f"Failed to perform update: {e}", exc_info=True)
            return web.json_response({
                'success': False,
                'error': str(e)
            })
        finally:
            # Restore even on failure: `git clean -fd` or a partially applied
            # ZIP update may already have removed or changed the file.
            if settings_backup is not None:
                UpdateRoutes._restore_settings(settings_path, settings_backup)

        if success:
            return web.json_response({
                'success': True,
                'message': f'Successfully updated to {new_version}',
                'new_version': new_version
            })
        return web.json_response({
            'success': False,
            'error': 'Failed to complete update'
        })

    @staticmethod
    def _restore_settings(settings_path: str, content: str) -> None:
        """Write the backed-up *content* back to *settings_path*.

        Errors are logged, not raised, so that a failed restore never masks
        the update result.
        """
        try:
            settings_dir = os.path.dirname(settings_path)
            if settings_dir:
                os.makedirs(settings_dir, exist_ok=True)
            with open(settings_path, 'w', encoding='utf-8') as f:
                f.write(content)
            logger.info("Restored settings.json")
        except OSError as e:
            logger.error(f"Failed to restore settings.json: {e}", exc_info=True)

    @staticmethod
    async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
        """Download the latest release ZIP from GitHub and install it.

        Entries in ``_ZIP_PRESERVED_ITEMS`` are kept. The installed file list
        is written to ``.tracking``. The temporary ZIP is always removed.

        Returns:
            ``(True, tag_name)`` on success, ``(False, "")`` on any failure.
        """
        github_api = UpdateRoutes._github_api_url("releases/latest")
        tmp_zip_path = None

        try:
            downloader = await get_downloader()

            success, data = await downloader.make_request(
                'GET',
                github_api,
                use_auth=False
            )
            if not success:
                logger.error(f"Failed to fetch release info: {data}")
                return False, ""

            zip_url = data.get("zipball_url")
            version = data.get("tag_name", "unknown")
            if not zip_url:
                logger.error("Latest release info contains no zipball_url")
                return False, ""

            with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
                tmp_zip_path = tmp_zip.name

            success, result = await downloader.download_file(
                url=zip_url,
                save_path=tmp_zip_path,
                use_auth=False,
                allow_resume=False
            )
            if not success:
                logger.error(f"Failed to download ZIP: {result}")
                return False, ""

            # Extraction and file replacement are blocking; keep them off the event loop.
            await asyncio.to_thread(UpdateRoutes._install_release_zip, tmp_zip_path, plugin_root)

            logger.info(f"Updated plugin via ZIP to {version}")
            return True, version

        except Exception as e:
            logger.error(f"ZIP update failed: {e}", exc_info=True)
            return False, ""
        finally:
            if tmp_zip_path and os.path.exists(tmp_zip_path):
                try:
                    os.remove(tmp_zip_path)
                except OSError as e:
                    logger.warning("Could not remove temporary ZIP %s: %s", tmp_zip_path, e)

    @staticmethod
    def _install_release_zip(zip_path: str, plugin_root: str) -> None:
        """Replace the contents of *plugin_root* with the release in *zip_path*.

        Blocking. The archive is fully extracted and validated before anything
        in *plugin_root* is deleted, so an unreadable ZIP leaves the existing
        installation untouched.

        Raises:
            zipfile.BadZipFile: if *zip_path* is not a valid ZIP archive.
            ValueError: if the archive does not contain exactly one root folder.
        """
        preserved = UpdateRoutes._ZIP_PRESERVED_ITEMS

        with tempfile.TemporaryDirectory() as tmp_dir:
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(tmp_dir)

            # GitHub ZIPs contain a single root folder holding the repository.
            entries = os.listdir(tmp_dir)
            if len(entries) != 1 or not os.path.isdir(os.path.join(tmp_dir, entries[0])):
                raise ValueError(f"Unexpected release ZIP layout: {entries}")
            extracted_root = os.path.join(tmp_dir, entries[0])

            UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=preserved)

            for item in os.listdir(extracted_root):
                if item in preserved:
                    continue
                src = os.path.join(extracted_root, item)
                dst = os.path.join(plugin_root, item)
                if os.path.isdir(src):
                    if os.path.exists(dst):
                        shutil.rmtree(dst)
                    shutil.copytree(src, dst)
                else:
                    shutil.copy2(src, dst)

            UpdateRoutes._write_tracking_file(extracted_root, plugin_root)

    @staticmethod
    def _write_tracking_file(extracted_root: str, plugin_root: str) -> None:
        """Write ``<plugin_root>/.tracking`` for ComfyUI Manager.

        The file lists every installed file relative to *extracted_root*, one
        per line with ``/`` separators. Preserved top-level entries are
        excluded because they were not copied.
        """
        preserved = UpdateRoutes._ZIP_PRESERVED_ITEMS
        tracking_files = []

        for root, dirs, files in os.walk(extracted_root):
            if os.path.relpath(root, extracted_root) == '.':
                # Prune preserved top-level folders so os.walk never enters them.
                dirs[:] = [d for d in dirs if d not in preserved]
                files = [f for f in files if f not in preserved]
            for name in files:
                rel_path = os.path.relpath(os.path.join(root, name), extracted_root)
                tracking_files.append(rel_path.replace("\\", "/"))

        tracking_info_file = os.path.join(plugin_root, '.tracking')
        with open(tracking_info_file, "w", encoding='utf-8') as fh:
            fh.write('\n'.join(tracking_files))

    @staticmethod
    def _clean_plugin_folder(plugin_root, skip_files=None):
        """Delete every top-level entry of *plugin_root* except *skip_files*."""
        skip_files = skip_files or []
        for item in os.listdir(plugin_root):
            if item in skip_files:
                continue
            path = os.path.join(plugin_root, item)
            # rmtree refuses symlinks, so a symlink to a directory is unlinked
            # like a file (its target is left alone).
            if os.path.isdir(path) and not os.path.islink(path):
                shutil.rmtree(path)
            else:
                os.remove(path)

    @staticmethod
    async def _get_nightly_version() -> tuple[str, List[str]]:
        """Fetch the latest commit on ``main`` from GitHub.

        Returns:
            ``("main-<short_hash>", changelog)`` on success, where the
            changelog is the commit message (or empty). Returns
            ``("main", [])`` when the commit cannot be fetched.
        """
        github_url = UpdateRoutes._github_api_url("commits/main")

        try:
            downloader = await get_downloader()
            success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})

            if not success:
                logger.warning(f"Failed to fetch GitHub commit: {data}")
                return "main", []

            # Guard against missing or null fields in the API response.
            commit_sha = (data.get('sha') or '')[:7]
            if not commit_sha:
                logger.warning("GitHub commit response contained no SHA")
                return "main", []
            commit_message = (data.get('commit') or {}).get('message') or ''

            changelog = [commit_message] if commit_message else []
            return f"main-{commit_sha}", changelog

        except NETWORK_EXCEPTIONS as e:
            logger.warning("Unable to reach GitHub for nightly version: %s", e)
            return "main", []
        except Exception as e:
            logger.error(f"Error fetching nightly version: {e}", exc_info=True)
            return "main", []

    @staticmethod
    def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
        """Return True if the remote ``main`` commit differs from the local one.

        Args:
            local_git_info: The dict returned by :meth:`_get_git_info`.
            remote_version: ``"main-<short_hash>"``, or plain ``"main"`` when
                the remote commit could not be determined.

        Returns ``False`` when the remote hash is unknown, so that a failed
        lookup is never reported as an available update. Returns ``True``
        when the local hash is ``'unknown'``.
        """
        try:
            _, sep, remote_hash = remote_version.rpartition('-')
            if not sep or not remote_hash:
                return False

            local_hash = local_git_info.get('short_hash', 'unknown')
            if local_hash == 'unknown':
                return True  # Assume an update is available if the local hash is unknown.

            return local_hash != remote_hash

        except Exception as e:
            logger.error(f"Error comparing nightly versions: {e}")
            return False

    @staticmethod
    async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
        """Perform a Git-based update using GitPython.

        The blocking Git work runs in a worker thread so that fetch/pull does
        not stall the aiohttp event loop.

        Args:
            plugin_root: Path to the plugin root directory.
            nightly: Update to the ``main`` branch instead of the latest tag.

        Returns:
            tuple: ``(success, new_version)``. ``new_version`` is ``""`` on failure.
        """
        try:
            new_version = await asyncio.to_thread(UpdateRoutes._run_git_update, plugin_root, nightly)
        except git.exc.GitError as e:
            logger.error(f"Git error during update: {e}", exc_info=True)
            return False, ""
        except Exception as e:
            logger.error(f"Error during Git update: {e}", exc_info=True)
            return False, ""

        if not new_version:
            return False, ""

        logger.info(f"Successfully updated to {new_version}")
        return True, new_version

    @staticmethod
    def _run_git_update(plugin_root: str, nightly: bool) -> str:
        """Run the blocking Git update.

        Returns:
            The new version string (``"main-<short_hash>"`` or the tag name),
            or ``""`` if no release tag exists.

        Raises:
            git.exc.GitError: on any Git failure.
        """
        with git.Repo(plugin_root) as repo:
            origin = repo.remotes.origin
            origin.fetch()

            # Discard local modifications and untracked (non-ignored) files so
            # they cannot interfere with the branch switch or tag checkout.
            repo.git.reset('--hard')
            repo.git.clean('-fd')

            if nightly:
                main_branch = 'main'
                if main_branch not in [branch.name for branch in repo.branches]:
                    # Create local main branch if it doesn't exist
                    repo.create_head(main_branch, origin.refs.main)

                repo.heads[main_branch].checkout()
                origin.pull(main_branch)

                return f"main-{repo.head.commit.hexsha[:7]}"

            # "Latest" tag = tag whose commit has the newest committer date.
            tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
            if not tags:
                logger.error("No tags found in repository")
                return ""

            latest_tag = tags[0]
            repo.git.checkout(latest_tag.name)
            return latest_tag.name

    @staticmethod
    def _get_local_version() -> str:
        """Return the local plugin version from ``pyproject.toml`` as ``'v<version>'``.

        Returns ``'v0.0.0'`` if the file is missing or cannot be parsed.
        """
        try:
            pyproject_path = os.path.join(UpdateRoutes._get_plugin_root(), 'pyproject.toml')

            if not os.path.exists(pyproject_path):
                logger.warning(f"pyproject.toml not found at {pyproject_path}")
                return "v0.0.0"

            with open(pyproject_path, 'r', encoding='utf-8') as f:
                project_data = toml.load(f)
            version = project_data.get('project', {}).get('version', '0.0.0')
            return f"v{version}"

        except Exception as e:
            logger.error(f"Failed to get local version: {e}", exc_info=True)
            return "v0.0.0"

    @staticmethod
    def _get_git_info() -> Dict[str, str]:
        """Return commit/branch information about the plugin's Git checkout.

        Keys: ``commit_hash``, ``short_hash``, ``branch`` and ``commit_date``.
        When the plugin is not a Git checkout (or reading it fails), the
        defaults are kept: ``short_hash`` is ``'stable'`` and the other keys
        are ``'unknown'``.
        """
        plugin_root = UpdateRoutes._get_plugin_root()

        git_info = {
            'commit_hash': 'unknown',
            'short_hash': 'stable',
            'branch': 'unknown',
            'commit_date': 'unknown'
        }

        try:
            if not os.path.exists(os.path.join(plugin_root, '.git')):
                return git_info

            # Close the repo so GitPython's helper processes/handles are released.
            with git.Repo(plugin_root) as repo:
                commit = repo.head.commit
                git_info['commit_hash'] = commit.hexsha
                git_info['short_hash'] = commit.hexsha[:7]
                git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
                git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
        except Exception as e:
            logger.warning(f"Error getting git info: {e}")

        return git_info

    @staticmethod
    async def _get_remote_version() -> tuple[str, List[str]]:
        """Fetch the latest release version and changelog from GitHub.

        Returns:
            tuple: ``(version, changelog)``. ``version`` is always prefixed
            with ``'v'``. Returns ``("v0.0.0", [])`` on any failure.
        """
        github_url = UpdateRoutes._github_api_url("releases/latest")

        try:
            downloader = await get_downloader()
            success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})

            if not success:
                logger.warning(f"Failed to fetch GitHub release: {data}")
                return "v0.0.0", []

            version = data.get('tag_name') or ''
            if not version:
                logger.warning("GitHub release response contained no tag_name")
                return "v0.0.0", []
            if not version.startswith('v'):
                version = f"v{version}"

            # ``body`` may be null/absent for releases without notes.
            changelog = UpdateRoutes._parse_changelog(data.get('body') or '')

            return version, changelog

        except NETWORK_EXCEPTIONS as e:
            logger.warning("Unable to reach GitHub for release info: %s", e)
            return "v0.0.0", []
        except Exception as e:
            logger.error(f"Error fetching remote version: {e}", exc_info=True)
            return "v0.0.0", []

    @staticmethod
    def _parse_changelog(release_notes: str) -> List[str]:
        """Extract changelog items from GitHub release-notes markdown.

        Collects ``- item``, ``* item`` and numbered ``N. item`` lines (any
        number of digits). If none are found, returns the stripped notes,
        truncated to 500 characters with ``...`` appended when cut.

        Args:
            release_notes: Release-notes markdown. ``None`` and ``""`` yield ``[]``.

        Returns:
            List of changelog items.
        """
        if not release_notes:
            return []

        changelog = []
        for raw_line in release_notes.split('\n'):
            line = raw_line.strip()
            if line.startswith('- ') or line.startswith('* '):
                item = line[2:].strip()
            else:
                # Numbered items like "1. Item" or "12. Item".
                number, sep, rest = line.partition('. ')
                if not (sep and number.isdigit()):
                    continue
                item = rest.strip()
            if item:
                changelog.append(item)

        if not changelog:
            summary = release_notes.strip()
            if len(summary) > 500:
                summary = summary[:500] + "..."
            if summary:
                changelog.append(summary)

        return changelog

    @staticmethod
    def _compare_versions(version1: str, version2: str) -> bool:
        """Return True if semantic version *version2* is newer than *version1*.

        Any suffix after ``'-'`` (e.g. ``-bugfix``, ``-alpha``) is ignored and
        missing components are treated as 0. Only major.minor.patch are
        compared. Unparseable input is logged and returns ``False``.
        """
        try:
            v1_parts = [int(x) for x in version1.split('-')[0].split('.')]
            v2_parts = [int(x) for x in version2.split('-')[0].split('.')]

            # Pad to major.minor.patch.
            v1_parts += [0] * (3 - len(v1_parts))
            v2_parts += [0] * (3 - len(v2_parts))

            for a, b in zip(v1_parts[:3], v2_parts[:3]):
                if b != a:
                    return b > a

            # Versions are equal
            return False
        except Exception as e:
            logger.error(f"Error comparing versions: {e}", exc_info=True)
            return False