mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 13:12:12 -03:00
545 lines
21 KiB
Python
545 lines
21 KiB
Python
import os
|
|
import logging
|
|
import toml
|
|
import git
|
|
import zipfile
|
|
import shutil
|
|
import tempfile
|
|
import asyncio
|
|
from aiohttp import web, ClientError
|
|
from typing import Dict, List
|
|
|
|
from py.utils.settings_paths import ensure_settings_file
|
|
from ..services.downloader import get_downloader
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
NETWORK_EXCEPTIONS = (ClientError, OSError, asyncio.TimeoutError)
|
|
|
|
|
|
class UpdateRoutes:
|
|
"""Routes for handling plugin update checks"""
|
|
|
|
@staticmethod
|
|
def setup_routes(app):
|
|
"""Register update check routes"""
|
|
app.router.add_get('/api/lm/check-updates', UpdateRoutes.check_updates)
|
|
app.router.add_get('/api/lm/version-info', UpdateRoutes.get_version_info)
|
|
app.router.add_post('/api/lm/perform-update', UpdateRoutes.perform_update)
|
|
|
|
@staticmethod
|
|
async def check_updates(request):
|
|
"""
|
|
Check for plugin updates by comparing local version with GitHub
|
|
Returns update status and version information
|
|
"""
|
|
try:
|
|
nightly = request.query.get('nightly', 'false').lower() == 'true'
|
|
|
|
# Read local version from pyproject.toml
|
|
local_version = UpdateRoutes._get_local_version()
|
|
|
|
# Get git info (commit hash, branch)
|
|
git_info = UpdateRoutes._get_git_info()
|
|
|
|
# Fetch remote version from GitHub
|
|
if nightly:
|
|
remote_version, changelog = await UpdateRoutes._get_nightly_version()
|
|
else:
|
|
remote_version, changelog = await UpdateRoutes._get_remote_version()
|
|
|
|
# Compare versions
|
|
if nightly:
|
|
# For nightly, compare commit hashes
|
|
update_available = UpdateRoutes._compare_nightly_versions(git_info, remote_version)
|
|
else:
|
|
# For stable, compare semantic versions
|
|
update_available = UpdateRoutes._compare_versions(
|
|
local_version.replace('v', ''),
|
|
remote_version.replace('v', '')
|
|
)
|
|
|
|
return web.json_response({
|
|
'success': True,
|
|
'current_version': local_version,
|
|
'latest_version': remote_version,
|
|
'update_available': update_available,
|
|
'changelog': changelog,
|
|
'git_info': git_info,
|
|
'nightly': nightly
|
|
})
|
|
|
|
except NETWORK_EXCEPTIONS as e:
|
|
logger.warning("Network unavailable during update check: %s", e)
|
|
return web.json_response({
|
|
'success': False,
|
|
'error': 'Network unavailable for update check'
|
|
})
|
|
except Exception as e:
|
|
logger.error(f"Failed to check for updates: {e}", exc_info=True)
|
|
return web.json_response({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@staticmethod
|
|
async def get_version_info(request):
|
|
"""
|
|
Returns the current version in the format 'version-short_hash'
|
|
"""
|
|
try:
|
|
# Read local version from pyproject.toml
|
|
local_version = UpdateRoutes._get_local_version().replace('v', '')
|
|
|
|
# Get git info (commit hash, branch)
|
|
git_info = UpdateRoutes._get_git_info()
|
|
short_hash = git_info['short_hash']
|
|
|
|
# Format: version-short_hash
|
|
version_string = f"{local_version}-{short_hash}"
|
|
|
|
return web.json_response({
|
|
'success': True,
|
|
'version': version_string
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to get version info: {e}", exc_info=True)
|
|
return web.json_response({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@staticmethod
|
|
async def perform_update(request):
|
|
"""
|
|
Perform Git-based update to latest release tag or main branch.
|
|
If .git is missing, fallback to ZIP download.
|
|
"""
|
|
try:
|
|
body = await request.json() if request.has_body else {}
|
|
nightly = body.get('nightly', False)
|
|
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
|
|
|
settings_path = ensure_settings_file(logger)
|
|
settings_backup = None
|
|
if os.path.exists(settings_path):
|
|
with open(settings_path, 'r', encoding='utf-8') as f:
|
|
settings_backup = f.read()
|
|
logger.info("Backed up settings.json")
|
|
|
|
git_folder = os.path.join(plugin_root, '.git')
|
|
if os.path.exists(git_folder):
|
|
# Git update
|
|
success, new_version = await UpdateRoutes._perform_git_update(plugin_root, nightly)
|
|
else:
|
|
# Fallback: Download ZIP and replace files
|
|
success, new_version = await UpdateRoutes._download_and_replace_zip(plugin_root)
|
|
|
|
if settings_backup and success:
|
|
with open(settings_path, 'w', encoding='utf-8') as f:
|
|
f.write(settings_backup)
|
|
logger.info("Restored settings.json")
|
|
|
|
if success:
|
|
return web.json_response({
|
|
'success': True,
|
|
'message': f'Successfully updated to {new_version}',
|
|
'new_version': new_version
|
|
})
|
|
else:
|
|
return web.json_response({
|
|
'success': False,
|
|
'error': 'Failed to complete update'
|
|
})
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to perform update: {e}", exc_info=True)
|
|
return web.json_response({
|
|
'success': False,
|
|
'error': str(e)
|
|
})
|
|
|
|
@staticmethod
|
|
async def _download_and_replace_zip(plugin_root: str) -> tuple[bool, str]:
|
|
"""
|
|
Download latest release ZIP from GitHub and replace plugin files.
|
|
Skips settings.json and civitai folder. Writes extracted file list to .tracking.
|
|
"""
|
|
repo_owner = "willmiao"
|
|
repo_name = "ComfyUI-Lora-Manager"
|
|
github_api = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
|
|
|
try:
|
|
downloader = await get_downloader()
|
|
|
|
# Get release info
|
|
success, data = await downloader.make_request(
|
|
'GET',
|
|
github_api,
|
|
use_auth=False
|
|
)
|
|
if not success:
|
|
logger.error(f"Failed to fetch release info: {data}")
|
|
return False, ""
|
|
|
|
zip_url = data.get("zipball_url")
|
|
version = data.get("tag_name", "unknown")
|
|
|
|
# Download ZIP to temporary file
|
|
with tempfile.NamedTemporaryFile(delete=False, suffix=".zip") as tmp_zip:
|
|
tmp_zip_path = tmp_zip.name
|
|
|
|
success, result = await downloader.download_file(
|
|
url=zip_url,
|
|
save_path=tmp_zip_path,
|
|
use_auth=False,
|
|
allow_resume=False
|
|
)
|
|
|
|
if not success:
|
|
logger.error(f"Failed to download ZIP: {result}")
|
|
return False, ""
|
|
|
|
zip_path = tmp_zip_path
|
|
|
|
# Skip both settings.json and civitai folder
|
|
UpdateRoutes._clean_plugin_folder(plugin_root, skip_files=['settings.json', 'civitai'])
|
|
|
|
# Extract ZIP to temp dir
|
|
with tempfile.TemporaryDirectory() as tmp_dir:
|
|
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
|
zip_ref.extractall(tmp_dir)
|
|
# Find extracted folder (GitHub ZIP contains a root folder)
|
|
extracted_root = next(os.scandir(tmp_dir)).path
|
|
|
|
# Copy files, skipping settings.json and civitai folder
|
|
for item in os.listdir(extracted_root):
|
|
if item == 'settings.json' or item == 'civitai':
|
|
continue
|
|
src = os.path.join(extracted_root, item)
|
|
dst = os.path.join(plugin_root, item)
|
|
if os.path.isdir(src):
|
|
if os.path.exists(dst):
|
|
shutil.rmtree(dst)
|
|
shutil.copytree(src, dst, ignore=shutil.ignore_patterns('settings.json', 'civitai'))
|
|
else:
|
|
shutil.copy2(src, dst)
|
|
|
|
# Write .tracking file: list all files under extracted_root, relative to extracted_root
|
|
# for ComfyUI Manager to work properly
|
|
tracking_info_file = os.path.join(plugin_root, '.tracking')
|
|
tracking_files = []
|
|
for root, dirs, files in os.walk(extracted_root):
|
|
# Skip civitai folder and its contents
|
|
rel_root = os.path.relpath(root, extracted_root)
|
|
if rel_root == 'civitai' or rel_root.startswith('civitai' + os.sep):
|
|
continue
|
|
for file in files:
|
|
rel_path = os.path.relpath(os.path.join(root, file), extracted_root)
|
|
# Skip settings.json and any file under civitai
|
|
if rel_path == 'settings.json' or rel_path.startswith('civitai' + os.sep):
|
|
continue
|
|
tracking_files.append(rel_path.replace("\\", "/"))
|
|
with open(tracking_info_file, "w", encoding='utf-8') as file:
|
|
file.write('\n'.join(tracking_files))
|
|
|
|
os.remove(zip_path)
|
|
logger.info(f"Updated plugin via ZIP to {version}")
|
|
return True, version
|
|
|
|
except Exception as e:
|
|
logger.error(f"ZIP update failed: {e}", exc_info=True)
|
|
return False, ""
|
|
|
|
def _clean_plugin_folder(plugin_root, skip_files=None):
|
|
skip_files = skip_files or []
|
|
for item in os.listdir(plugin_root):
|
|
if item in skip_files:
|
|
continue
|
|
path = os.path.join(plugin_root, item)
|
|
if os.path.isdir(path):
|
|
shutil.rmtree(path)
|
|
else:
|
|
os.remove(path)
|
|
|
|
@staticmethod
|
|
async def _get_nightly_version() -> tuple[str, List[str]]:
|
|
"""
|
|
Fetch latest commit from main branch
|
|
"""
|
|
repo_owner = "willmiao"
|
|
repo_name = "ComfyUI-Lora-Manager"
|
|
|
|
# Use GitHub API to fetch the latest commit from main branch
|
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/commits/main"
|
|
|
|
try:
|
|
downloader = await get_downloader()
|
|
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
|
|
|
if not success:
|
|
logger.warning(f"Failed to fetch GitHub commit: {data}")
|
|
return "main", []
|
|
|
|
commit_sha = data.get('sha', '')[:7] # Short hash
|
|
commit_message = data.get('commit', {}).get('message', '')
|
|
|
|
# Format as "main-{short_hash}"
|
|
version = f"main-{commit_sha}"
|
|
|
|
# Use commit message as changelog
|
|
changelog = [commit_message] if commit_message else []
|
|
|
|
return version, changelog
|
|
|
|
except NETWORK_EXCEPTIONS as e:
|
|
logger.warning("Unable to reach GitHub for nightly version: %s", e)
|
|
return "main", []
|
|
except Exception as e:
|
|
logger.error(f"Error fetching nightly version: {e}", exc_info=True)
|
|
return "main", []
|
|
|
|
@staticmethod
|
|
def _compare_nightly_versions(local_git_info: Dict[str, str], remote_version: str) -> bool:
|
|
"""
|
|
Compare local commit hash with remote main branch
|
|
"""
|
|
try:
|
|
local_hash = local_git_info.get('short_hash', 'unknown')
|
|
if local_hash == 'unknown':
|
|
return True # Assume update available if we can't get local hash
|
|
|
|
# Extract remote hash from version string (format: "main-{hash}")
|
|
if '-' in remote_version:
|
|
remote_hash = remote_version.split('-')[-1]
|
|
return local_hash != remote_hash
|
|
|
|
return True # Default to update available
|
|
|
|
except Exception as e:
|
|
logger.error(f"Error comparing nightly versions: {e}")
|
|
return False
|
|
|
|
@staticmethod
|
|
async def _perform_git_update(plugin_root: str, nightly: bool = False) -> tuple[bool, str]:
|
|
"""
|
|
Perform Git-based update using GitPython
|
|
|
|
Args:
|
|
plugin_root: Path to the plugin root directory
|
|
nightly: Whether to update to main branch or latest release
|
|
|
|
Returns:
|
|
tuple: (success, new_version)
|
|
"""
|
|
try:
|
|
# Open the Git repository
|
|
repo = git.Repo(plugin_root)
|
|
|
|
# Fetch latest changes
|
|
origin = repo.remotes.origin
|
|
origin.fetch()
|
|
|
|
if nightly:
|
|
# Switch to main branch and pull latest
|
|
main_branch = 'main'
|
|
if main_branch not in [branch.name for branch in repo.branches]:
|
|
# Create local main branch if it doesn't exist
|
|
repo.create_head(main_branch, origin.refs.main)
|
|
|
|
repo.heads[main_branch].checkout()
|
|
origin.pull(main_branch)
|
|
|
|
# Get new commit hash
|
|
new_version = f"main-{repo.head.commit.hexsha[:7]}"
|
|
|
|
else:
|
|
# Get latest release tag
|
|
tags = sorted(repo.tags, key=lambda t: t.commit.committed_datetime, reverse=True)
|
|
if not tags:
|
|
logger.error("No tags found in repository")
|
|
return False, ""
|
|
|
|
latest_tag = tags[0]
|
|
|
|
# Checkout to latest tag
|
|
repo.git.checkout(latest_tag.name)
|
|
|
|
new_version = latest_tag.name
|
|
|
|
logger.info(f"Successfully updated to {new_version}")
|
|
return True, new_version
|
|
|
|
except git.exc.GitError as e:
|
|
logger.error(f"Git error during update: {e}")
|
|
return False, ""
|
|
except Exception as e:
|
|
logger.error(f"Error during Git update: {e}")
|
|
return False, ""
|
|
|
|
@staticmethod
|
|
def _get_local_version() -> str:
|
|
"""Get local plugin version from pyproject.toml"""
|
|
try:
|
|
# Find the plugin's pyproject.toml file
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
|
pyproject_path = os.path.join(plugin_root, 'pyproject.toml')
|
|
|
|
# Read and parse the toml file
|
|
if os.path.exists(pyproject_path):
|
|
with open(pyproject_path, 'r', encoding='utf-8') as f:
|
|
project_data = toml.load(f)
|
|
version = project_data.get('project', {}).get('version', '0.0.0')
|
|
return f"v{version}"
|
|
else:
|
|
logger.warning(f"pyproject.toml not found at {pyproject_path}")
|
|
return "v0.0.0"
|
|
|
|
except Exception as e:
|
|
logger.error(f"Failed to get local version: {e}", exc_info=True)
|
|
return "v0.0.0"
|
|
|
|
@staticmethod
|
|
def _get_git_info() -> Dict[str, str]:
|
|
"""Get Git repository information"""
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
plugin_root = os.path.dirname(os.path.dirname(current_dir))
|
|
|
|
git_info = {
|
|
'commit_hash': 'unknown',
|
|
'short_hash': 'stable',
|
|
'branch': 'unknown',
|
|
'commit_date': 'unknown'
|
|
}
|
|
|
|
try:
|
|
# Check if we're in a git repository
|
|
if not os.path.exists(os.path.join(plugin_root, '.git')):
|
|
return git_info
|
|
|
|
repo = git.Repo(plugin_root)
|
|
commit = repo.head.commit
|
|
git_info['commit_hash'] = commit.hexsha
|
|
git_info['short_hash'] = commit.hexsha[:7]
|
|
git_info['branch'] = repo.active_branch.name if not repo.head.is_detached else 'detached'
|
|
git_info['commit_date'] = commit.committed_datetime.strftime('%Y-%m-%d')
|
|
except Exception as e:
|
|
logger.warning(f"Error getting git info: {e}")
|
|
|
|
return git_info
|
|
|
|
@staticmethod
|
|
async def _get_remote_version() -> tuple[str, List[str]]:
|
|
"""
|
|
Fetch remote version from GitHub
|
|
Returns:
|
|
tuple: (version string, changelog list)
|
|
"""
|
|
repo_owner = "willmiao"
|
|
repo_name = "ComfyUI-Lora-Manager"
|
|
|
|
# Use GitHub API to fetch the latest release
|
|
github_url = f"https://api.github.com/repos/{repo_owner}/{repo_name}/releases/latest"
|
|
|
|
try:
|
|
downloader = await get_downloader()
|
|
success, data = await downloader.make_request('GET', github_url, custom_headers={'Accept': 'application/vnd.github+json'})
|
|
|
|
if not success:
|
|
logger.warning(f"Failed to fetch GitHub release: {data}")
|
|
return "v0.0.0", []
|
|
|
|
version = data.get('tag_name', '')
|
|
if not version.startswith('v'):
|
|
version = f"v{version}"
|
|
|
|
# Extract changelog from release notes
|
|
body = data.get('body', '')
|
|
changelog = UpdateRoutes._parse_changelog(body)
|
|
|
|
return version, changelog
|
|
|
|
except NETWORK_EXCEPTIONS as e:
|
|
logger.warning("Unable to reach GitHub for release info: %s", e)
|
|
return "v0.0.0", []
|
|
except Exception as e:
|
|
logger.error(f"Error fetching remote version: {e}", exc_info=True)
|
|
return "v0.0.0", []
|
|
|
|
@staticmethod
|
|
def _parse_changelog(release_notes: str) -> List[str]:
|
|
"""
|
|
Parse GitHub release notes to extract changelog items
|
|
|
|
Args:
|
|
release_notes: GitHub release notes markdown text
|
|
|
|
Returns:
|
|
List of changelog items
|
|
"""
|
|
changelog = []
|
|
|
|
# Simple parsing - extract bullet points
|
|
lines = release_notes.split('\n')
|
|
for line in lines:
|
|
line = line.strip()
|
|
# Look for bullet points or numbered items
|
|
if line.startswith('- ') or line.startswith('* '):
|
|
item = line[2:].strip()
|
|
if item:
|
|
changelog.append(item)
|
|
# Match numbered items like "1. Item"
|
|
elif len(line) > 2 and line[0].isdigit() and line[1:].startswith('. '):
|
|
item = line[line.index('. ')+2:].strip()
|
|
if item:
|
|
changelog.append(item)
|
|
|
|
# If we couldn't parse specific items, use the whole text (limited)
|
|
if not changelog and release_notes:
|
|
# Limit to first 500 chars and add ellipsis
|
|
summary = release_notes.strip()[:500]
|
|
if len(release_notes) > 500:
|
|
summary += "..."
|
|
changelog.append(summary)
|
|
|
|
return changelog
|
|
|
|
@staticmethod
|
|
def _compare_versions(version1: str, version2: str) -> bool:
|
|
"""
|
|
Compare two semantic version strings
|
|
Returns True if version2 is newer than version1
|
|
Ignores any suffixes after '-' (e.g., -bugfix, -alpha)
|
|
"""
|
|
try:
|
|
# Clean version strings - remove any suffix after '-'
|
|
v1_clean = version1.split('-')[0]
|
|
v2_clean = version2.split('-')[0]
|
|
|
|
# Split versions into components
|
|
v1_parts = [int(x) for x in v1_clean.split('.')]
|
|
v2_parts = [int(x) for x in v2_clean.split('.')]
|
|
|
|
# Ensure both have 3 components (major.minor.patch)
|
|
while len(v1_parts) < 3:
|
|
v1_parts.append(0)
|
|
while len(v2_parts) < 3:
|
|
v2_parts.append(0)
|
|
|
|
# Compare version components
|
|
for i in range(3):
|
|
if v2_parts[i] > v1_parts[i]:
|
|
return True
|
|
elif v2_parts[i] < v1_parts[i]:
|
|
return False
|
|
|
|
# Versions are equal
|
|
return False
|
|
except Exception as e:
|
|
logger.error(f"Error comparing versions: {e}", exc_info=True)
|
|
return False
|