checkpoint

This commit is contained in:
Will Miao
2025-02-10 23:40:38 +08:00
parent 343da36199
commit 2222731f36
10 changed files with 546 additions and 910 deletions

View File

@@ -16,6 +16,7 @@ class ApiRoutes:
def __init__(self):
self.scanner = LoraScanner()
self.civitai_client = CivitaiClient()
@classmethod
def setup_routes(cls, app: web.Application):
@@ -27,6 +28,9 @@ class ApiRoutes:
app.router.add_get('/api/loras', routes.get_loras)
app.router.add_post('/api/fetch-all-civitai', routes.fetch_all_civitai)
app.router.add_get('/ws/fetch-progress', ws_manager.handle_connection)
app.router.add_get('/api/lora-roots', routes.get_lora_roots)
app.router.add_get('/api/civitai/versions/{model_id}', routes.get_civitai_versions)
app.router.add_post('/api/download-lora', routes.download_lora)
async def delete_model(self, request: web.Request) -> web.Response:
"""Handle model deletion request"""
@@ -52,7 +56,6 @@ class ApiRoutes:
async def fetch_civitai(self, request: web.Request) -> web.Response:
"""Handle CivitAI metadata fetch request"""
client = CivitaiClient()
try:
data = await request.json()
metadata_path = os.path.splitext(data['file_path'])[0] + '.metadata.json'
@@ -63,19 +66,17 @@ class ApiRoutes:
return web.json_response({"success": True, "notice": "Not from CivitAI"})
# Fetch and update metadata
civitai_metadata = await client.get_model_by_hash(data["sha256"])
civitai_metadata = await self.civitai_client.get_model_by_hash(data["sha256"])
if not civitai_metadata:
return await self._handle_not_found_on_civitai(metadata_path, local_metadata)
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, client)
await self._update_model_metadata(metadata_path, local_metadata, civitai_metadata, self.civitai_client)
return web.json_response({"success": True})
except Exception as e:
logger.error(f"Error fetching from CivitAI: {e}", exc_info=True)
return web.json_response({"success": False, "error": str(e)}, status=500)
finally:
await client.close()
async def replace_preview(self, request: web.Request) -> web.Response:
"""Handle preview image replacement request"""
@@ -444,3 +445,64 @@ class ApiRoutes:
return False
finally:
await client.close()
async def get_lora_roots(self, request: web.Request) -> web.Response:
"""Get all configured LoRA root directories"""
return web.json_response({
'roots': config.loras_roots
})
async def get_civitai_versions(self, request: web.Request) -> web.Response:
"""Get available versions for a Civitai model"""
try:
model_id = request.match_info['model_id']
versions = await self.civitai_client.get_model_versions(model_id)
if not versions:
return web.Response(status=404, text="Model not found")
return web.json_response(versions)
except Exception as e:
logger.error(f"Error fetching model versions: {e}")
return web.Response(status=500, text=str(e))
async def download_lora(self, request: web.Request) -> web.Response:
"""Handle LoRA download request"""
try:
data = await request.json()
download_url = data.get('download_url')
version_info = data.get('version_info')
lora_root = data.get('lora_root')
new_folder = data.get('new_folder', '').strip()
if not download_url or not version_info or not lora_root:
return web.Response(status=400, text="Missing required parameters")
if not os.path.isdir(lora_root):
return web.Response(status=400, text="Invalid LoRA root directory")
# 构建保存路径
save_dir = os.path.join(lora_root, new_folder) if new_folder else lora_root
os.makedirs(save_dir, exist_ok=True)
# 使用提供的下载 URL 和版本信息
result = await self.civitai_client.download_model_with_info(
download_url=download_url,
version_info=version_info,
save_dir=save_dir
)
if result.get('success'):
# 更新缓存
await self.scanner.rescan_directory(save_dir)
return web.json_response(result)
else:
return web.Response(status=500, text=result.get('error', 'Download failed'))
except Exception as e:
logger.error(f"Error downloading LoRA: {e}")
return web.Response(status=500, text=str(e))
@classmethod
async def cleanup(cls):
"""Add cleanup method for application shutdown"""
if hasattr(cls, '_instance'):
await cls._instance.civitai_client.close()