checkpoint

This commit is contained in:
Will Miao
2025-02-13 11:34:27 +08:00
parent 2222731f36
commit b7aca9b6fc
11 changed files with 303 additions and 86 deletions

1
.gitignore vendored
View File

@@ -1 +1,2 @@
__pycache__/ __pycache__/
settings.json

View File

@@ -8,6 +8,7 @@ from ..config import config
from ..services.lora_scanner import LoraScanner from ..services.lora_scanner import LoraScanner
from operator import itemgetter from operator import itemgetter
from ..services.websocket_manager import ws_manager from ..services.websocket_manager import ws_manager
from ..services.settings_manager import settings # 添加这行
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -31,6 +32,7 @@ class ApiRoutes:
app.router.add_get('/api/lora-roots', routes.get_lora_roots) app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions) app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
app.router.add_post('/api/download-lora', routes.download_lora) app.router.add_post('/api/download-lora', routes.download_lora)
app.router.add_post('/api/settings', routes.update_settings)
async def delete_model(self, request: web.Request) -> web.Response: async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request""" """Handle model deletion request"""
@@ -491,8 +493,8 @@ class ApiRoutes:
) )
if result.get('success'): if result.get('success'):
# 更新缓存 # 更新缓存 - 使用正确的扫描方法
await self.scanner.rescan_directory(save_dir) await self.scanner.scan_directory(save_dir) # Changed from rescan_directory to scan_directory
return web.json_response(result) return web.json_response(result)
else: else:
return web.Response(status=500, text=result.get('error', 'Download failed')) return web.Response(status=500, text=result.get('error', 'Download failed'))
@@ -501,6 +503,20 @@ class ApiRoutes:
logger.error(f"Error downloading LoRA: {e}") logger.error(f"Error downloading LoRA: {e}")
return web.Response(status=500, text=str(e)) return web.Response(status=500, text=str(e))
async def update_settings(self, request: web.Request) -> web.Response:
"""Update application settings"""
try:
data = await request.json()
# Validate and update settings
if 'civitai_api_key' in data:
settings.set('civitai_api_key', data['civitai_api_key'])
return web.json_response({'success': True})
except Exception as e:
logger.error(f"Error updating settings: {e}", exc_info=True) # 添加 exc_info=True 以获取完整堆栈
return web.Response(status=500, text=str(e))
@classmethod @classmethod
async def cleanup(cls): async def cleanup(cls):
"""Add cleanup method for application shutdown""" """Add cleanup method for application shutdown"""

View File

@@ -5,6 +5,7 @@ from typing import Dict, List
import logging import logging
from ..services.lora_scanner import LoraScanner from ..services.lora_scanner import LoraScanner
from ..config import config from ..config import config
from ..services.settings_manager import settings # Add this import
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
logging.getLogger('asyncio').setLevel(logging.CRITICAL) logging.getLogger('asyncio').setLevel(logging.CRITICAL)
@@ -60,7 +61,8 @@ class LoraRoutes:
template = self.template_env.get_template('loras.html') template = self.template_env.get_template('loras.html')
rendered = template.render( rendered = template.render(
folders=[], # 空文件夹列表 folders=[], # 空文件夹列表
is_initializing=True # 新增标志 is_initializing=True, # 新增标志
settings=settings # Pass settings to template
) )
else: else:
# 正常流程 # 正常流程
@@ -68,7 +70,8 @@ class LoraRoutes:
template = self.template_env.get_template('loras.html') template = self.template_env.get_template('loras.html')
rendered = template.render( rendered = template.render(
folders=cache.folders, folders=cache.folders,
is_initializing=False is_initializing=False,
settings=settings # Pass settings to template
) )
return web.Response( return web.Response(

View File

@@ -2,7 +2,9 @@ import aiohttp
import os import os
import json import json
import logging import logging
from typing import Optional, Dict from email.parser import Parser
from typing import Optional, Dict, Tuple
from urllib.parse import unquote
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -18,9 +20,74 @@ class CivitaiClient:
async def session(self) -> aiohttp.ClientSession: async def session(self) -> aiohttp.ClientSession:
"""Lazy initialize the session""" """Lazy initialize the session"""
if self._session is None: if self._session is None:
self._session = aiohttp.ClientSession() connector = aiohttp.TCPConnector(ssl=True)
trust_env = True # 允许使用系统环境变量中的代理设置
self._session = aiohttp.ClientSession(connector=connector, trust_env=trust_env)
return self._session return self._session
def _parse_content_disposition(self, header: str) -> str:
"""Parse filename from content-disposition header"""
if not header:
return None
# Handle quoted filenames
if 'filename="' in header:
start = header.index('filename="') + 10
end = header.index('"', start)
return unquote(header[start:end])
# Fallback to original parsing
disposition = Parser().parsestr(f'Content-Disposition: {header}')
filename = disposition.get_param('filename')
if filename:
return unquote(filename)
return None
def _get_request_headers(self) -> dict:
"""Get request headers with optional API key"""
headers = {
'User-Agent': 'ComfyUI-LoRA-Manager/1.0',
'Content-Type': 'application/json'
}
from .settings_manager import settings
api_key = settings.get('civitai_api_key')
if (api_key):
headers['Authorization'] = f'Bearer {api_key}'
return headers
async def _download_file(self, url: str, save_dir: str, default_filename: str) -> Tuple[bool, str]:
"""Download file with content-disposition support"""
session = await self.session
try:
headers = self._get_request_headers()
async with session.get(url, headers=headers, allow_redirects=True) as response:
if response.status != 200:
return False, f"Download failed with status {response.status}"
# Get filename from content-disposition header
content_disposition = response.headers.get('Content-Disposition')
filename = self._parse_content_disposition(content_disposition)
if not filename:
filename = default_filename
save_path = os.path.join(save_dir, filename)
# Stream download to file
with open(save_path, 'wb') as f:
while True:
chunk = await response.content.read(8192)
if not chunk:
break
f.write(chunk)
return True, save_path
except Exception as e:
logger.error(f"Download error: {e}")
return False, str(e)
async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]: async def get_model_by_hash(self, model_hash: str) -> Optional[Dict]:
try: try:
session = await self.session session = await self.session
@@ -60,92 +127,24 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def download_model_version(self, version_id: str, save_dir: str) -> Dict:
"""Download a specific model version"""
try:
session = await self.session
# First get version info
url = f"{self.base_url}/model-versions/{version_id}"
async with session.get(url, headers=self.headers) as response:
if response.status != 200:
return {'success': False, 'error': 'Version not found'}
version_data = await response.json()
download_url = version_data.get('downloadUrl')
if not download_url:
return {'success': False, 'error': 'No download URL found'}
# Download the file
file_name = version_data.get('files', [{}])[0].get('name', f'lora_{version_id}.safetensors')
save_path = os.path.join(save_dir, file_name)
async with session.get(download_url, headers=self.headers) as response:
if response.status != 200:
return {'success': False, 'error': 'Download failed'}
with open(save_path, 'wb') as f:
while True:
chunk = await response.content.read(8192)
if not chunk:
break
f.write(chunk)
# Create metadata file
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
metadata = {
'model_name': version_data.get('model', {}).get('name', file_name),
'civitai': version_data,
'preview_url': None,
'from_civitai': True
}
# Download preview image if available
images = version_data.get('images', [])
if images:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
await self.download_preview_image(images[0]['url'], preview_path)
metadata['preview_url'] = preview_path
# Save metadata
with open(metadata_path, 'w', encoding='utf-8') as f:
json.dump(metadata, f, indent=2, ensure_ascii=False)
return {
'success': True,
'file_path': save_path,
'metadata': metadata
}
except Exception as e:
logger.error(f"Error downloading model version: {e}")
return {'success': False, 'error': str(e)}
async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict: async def download_model_with_info(self, download_url: str, version_info: dict, save_dir: str) -> Dict:
"""Download model using provided version info and URL""" """Download model using provided version info and URL"""
try: try:
session = await self.session # Generate default filename
default_filename = f"lora_{version_info['id']}.safetensors"
logger.info(f"Downloading model: {version_info.get('name', 'Unknown')}")
# Use provided filename or generate one # Download the model file
file_name = version_info.get('files', [{}])[0].get('name', f'lora_{version_info["id"]}.safetensors') success, result = await self._download_file(download_url, save_dir, default_filename)
save_path = os.path.join(save_dir, file_name) if not success:
return {'success': False, 'error': result}
# Download the file save_path = result
async with session.get(download_url, headers=self.headers) as response:
if response.status != 200:
return {'success': False, 'error': 'Download failed'}
with open(save_path, 'wb') as f:
while True:
chunk = await response.content.read(8192)
if not chunk:
break
f.write(chunk)
# Create metadata file # Create metadata file
metadata_path = os.path.splitext(save_path)[0] + '.metadata.json' metadata_path = os.path.splitext(save_path)[0] + '.metadata.json'
metadata = { metadata = {
'model_name': version_info.get('model', {}).get('name', file_name), 'model_name': version_info.get('name', os.path.basename(save_path)),
'civitai': version_info, 'civitai': version_info,
'preview_url': None, 'preview_url': None,
'from_civitai': True 'from_civitai': True
@@ -157,7 +156,7 @@ class CivitaiClient:
preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png' preview_ext = '.mp4' if images[0].get('type') == 'video' else '.png'
preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext preview_path = os.path.splitext(save_path)[0] + '.preview' + preview_ext
await self.download_preview_image(images[0]['url'], preview_path) await self.download_preview_image(images[0]['url'], preview_path)
metadata['preview_url'] = preview_path metadata['preview_url'] = preview_path.replace(os.sep, '/')
# Save metadata # Save metadata
with open(metadata_path, 'w', encoding='utf-8') as f: with open(metadata_path, 'w', encoding='utf-8') as f:
@@ -165,7 +164,7 @@ class CivitaiClient:
return { return {
'success': True, 'success': True,
'file_path': save_path, 'file_path': save_path.replace(os.sep, '/'),
'metadata': metadata 'metadata': metadata
} }

View File

@@ -0,0 +1,46 @@
import os
import json
import logging
from typing import Any, Dict
logger = logging.getLogger(__name__)
class SettingsManager:
def __init__(self):
self.settings_file = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings.json')
self.settings = self._load_settings()
def _load_settings(self) -> Dict[str, Any]:
"""Load settings from file"""
if os.path.exists(self.settings_file):
try:
with open(self.settings_file, 'r', encoding='utf-8') as f:
return json.load(f)
except Exception as e:
logger.error(f"Error loading settings: {e}")
return self._get_default_settings()
def _get_default_settings(self) -> Dict[str, Any]:
"""Return default settings"""
return {
"civitai_api_key": ""
}
def get(self, key: str, default: Any = None) -> Any:
"""Get setting value"""
return self.settings.get(key, default)
def set(self, key: str, value: Any) -> None:
"""Set setting value and save"""
self.settings[key] = value
self._save_settings()
def _save_settings(self) -> None:
"""Save settings to file"""
try:
with open(self.settings_file, 'w', encoding='utf-8') as f:
json.dump(self.settings, f, indent=2)
except Exception as e:
logger.error(f"Error saving settings: {e}")
settings = SettingsManager()

View File

@@ -1155,3 +1155,60 @@ body.modal-open {
max-height: 200px; max-height: 200px;
overflow-y: auto; overflow-y: auto;
} }
/* Settings styles */
.settings-toggle {
width: 36px;
height: 36px;
border-radius: 50%;
background: var(--card-bg);
border: 1px solid var(--border-color);
color: var(--text-color);
display: flex;
align-items: center;
justify-content: center;
cursor: pointer;
transition: all 0.2s ease;
}
.settings-toggle:hover {
background: var(--lora-accent);
color: white;
transform: translateY(-2px);
}
.settings-modal {
max-width: 500px;
}
.api-key-input {
position: relative;
display: flex;
align-items: center;
}
.api-key-input input {
padding-right: 40px;
}
.api-key-input .toggle-visibility {
position: absolute;
right: 8px;
background: none;
border: none;
color: var(--text-color);
opacity: 0.6;
cursor: pointer;
padding: 4px 8px;
}
.api-key-input .toggle-visibility:hover {
opacity: 1;
}
.input-help {
font-size: 0.85em;
color: var(--text-color);
opacity: 0.8;
margin-top: 4px;
}

View File

@@ -21,6 +21,7 @@ import { initializeInfiniteScroll } from './utils/infiniteScroll.js';
import { showDeleteModal, confirmDelete, closeDeleteModal } from './utils/modalUtils.js'; import { showDeleteModal, confirmDelete, closeDeleteModal } from './utils/modalUtils.js';
import { SearchManager } from './utils/search.js'; import { SearchManager } from './utils/search.js';
import { DownloadManager } from './managers/DownloadManager.js'; import { DownloadManager } from './managers/DownloadManager.js';
import { SettingsManager, toggleApiKeyVisibility } from './managers/SettingsManager.js';
// Export all functions that need global access // Export all functions that need global access
window.loadMoreLoras = loadMoreLoras; window.loadMoreLoras = loadMoreLoras;
@@ -39,6 +40,8 @@ window.refreshLoras = refreshLoras;
window.openCivitai = openCivitai; window.openCivitai = openCivitai;
window.showToast = showToast window.showToast = showToast
window.toggleFolderTags = toggleFolderTags; window.toggleFolderTags = toggleFolderTags;
window.settingsManager = new SettingsManager();
window.toggleApiKeyVisibility = toggleApiKeyVisibility;
// Initialize everything when DOM is ready // Initialize everything when DOM is ready
document.addEventListener('DOMContentLoaded', () => { document.addEventListener('DOMContentLoaded', () => {

View File

@@ -34,6 +34,15 @@ export class ModalManager {
} }
}); });
// Add settingsModal registration
this.registerModal('settingsModal', {
element: document.getElementById('settingsModal'),
onClose: () => {
this.getModal('settingsModal').element.style.display = 'none';
document.body.classList.remove('modal-open');
}
});
document.addEventListener('keydown', this.boundHandleEscape); document.addEventListener('keydown', this.boundHandleEscape);
this.initialized = true; this.initialized = true;
} }

View File

@@ -0,0 +1,52 @@
import { modalManager } from './ModalManager.js';
import { showToast } from '../utils/uiHelpers.js';
export class SettingsManager {
constructor() {
this.initialized = false;
}
showSettings() {
console.log('Opening settings modal...'); // Debug log
modalManager.showModal('settingsModal');
}
async saveSettings() {
const apiKey = document.getElementById('civitaiApiKey').value;
try {
const response = await fetch('/api/settings', {
method: 'POST',
headers: {
'Content-Type': 'application/json',
},
body: JSON.stringify({
civitai_api_key: apiKey
})
});
if (!response.ok) {
throw new Error('Failed to save settings');
}
showToast('Settings saved successfully', 'success');
modalManager.closeModal('settingsModal');
} catch (error) {
showToast('Failed to save settings: ' + error.message, 'error');
}
}
}
// Helper function for toggling API key visibility
export function toggleApiKeyVisibility(button) {
const input = button.parentElement.querySelector('input');
const icon = button.querySelector('i');
if (input.type === 'password') {
input.type = 'text';
icon.className = 'fas fa-eye-slash';
} else {
input.type = 'password';
icon.className = 'fas fa-eye';
}
}

View File

@@ -68,3 +68,31 @@
</div> </div>
</div> </div>
</div> </div>
<!-- Settings Modal -->
<div id="settingsModal" class="modal">
<div class="modal-content settings-modal">
<button class="close" onclick="modalManager.closeModal('settingsModal')">&times;</button>
<h2>Settings</h2>
<div class="settings-form">
<div class="input-group">
<label for="civitaiApiKey">Civitai API Key:</label>
<div class="api-key-input">
<input type="password"
id="civitaiApiKey"
placeholder="Enter your Civitai API key"
value="{{ settings.get('civitai_api_key', '') }}" />
<button class="toggle-visibility" onclick="toggleApiKeyVisibility(this)">
<i class="fas fa-eye"></i>
</button>
</div>
<div class="input-help">
Used for authentication when downloading models from Civitai
</div>
</div>
</div>
<div class="modal-actions">
<button class="primary-btn" onclick="settingsManager.saveSettings()">Save</button>
</div>
</div>
</div>

View File

@@ -39,6 +39,9 @@
<img src="/loras_static/images/theme-toggle-light.svg" alt="Theme" class="theme-icon light-icon"> <img src="/loras_static/images/theme-toggle-light.svg" alt="Theme" class="theme-icon light-icon">
<img src="/loras_static/images/theme-toggle-dark.svg" alt="Theme" class="theme-icon dark-icon"> <img src="/loras_static/images/theme-toggle-dark.svg" alt="Theme" class="theme-icon dark-icon">
</div> </div>
<div class="settings-toggle" onclick="settingsManager.showSettings()" title="Settings">
<i class="fas fa-cog"></i>
</div>
</div> </div>
{% include 'components/modals.html' %} {% include 'components/modals.html' %}