mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 15:15:44 -03:00
fix: update model_id and model_version_id handling across various services for improved flexibility
This commit is contained in:
@@ -654,13 +654,13 @@ class MiscRoutes:
|
|||||||
exists = False
|
exists = False
|
||||||
model_type = None
|
model_type = None
|
||||||
|
|
||||||
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||||
exists = True
|
exists = True
|
||||||
model_type = 'lora'
|
model_type = 'lora'
|
||||||
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||||
exists = True
|
exists = True
|
||||||
model_type = 'checkpoint'
|
model_type = 'checkpoint'
|
||||||
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id):
|
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
|
||||||
exists = True
|
exists = True
|
||||||
model_type = 'embedding'
|
model_type = 'embedding'
|
||||||
|
|
||||||
|
|||||||
@@ -223,11 +223,11 @@ class CivitaiClient:
|
|||||||
logger.error(f"Error fetching model versions: {e}")
|
logger.error(f"Error fetching model versions: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]:
|
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
|
||||||
"""Get specific model version with additional metadata
|
"""Get specific model version with additional metadata
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: The Civitai model ID
|
model_id: The Civitai model ID (optional if version_id is provided)
|
||||||
version_id: Optional specific version ID to retrieve
|
version_id: Optional specific version ID to retrieve
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -235,37 +235,72 @@ class CivitaiClient:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
session = await self._ensure_fresh_session()
|
session = await self._ensure_fresh_session()
|
||||||
|
|
||||||
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
|
||||||
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
|
||||||
if response.status != 200:
|
|
||||||
return None
|
|
||||||
|
|
||||||
data = await response.json()
|
|
||||||
model_versions = data.get('modelVersions', [])
|
|
||||||
|
|
||||||
# Step 2: Determine the version_id to use
|
|
||||||
target_version_id = version_id
|
|
||||||
if target_version_id is None:
|
|
||||||
target_version_id = model_versions[0].get('id')
|
|
||||||
|
|
||||||
# Step 3: Get detailed version info using the version_id
|
|
||||||
headers = self._get_request_headers()
|
headers = self._get_request_headers()
|
||||||
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
|
|
||||||
if response.status != 200:
|
|
||||||
return None
|
|
||||||
|
|
||||||
version = await response.json()
|
# Case 1: Only version_id is provided
|
||||||
|
if model_id is None and version_id is not None:
|
||||||
|
# First get the version info to extract model_id
|
||||||
|
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
# Step 4: Enrich version_info with model data
|
version = await response.json()
|
||||||
# Add description and tags from model data
|
model_id = version.get('modelId')
|
||||||
version['model']['description'] = data.get("description")
|
|
||||||
version['model']['tags'] = data.get("tags", [])
|
|
||||||
|
|
||||||
# Add creator from model data
|
if not model_id:
|
||||||
version['creator'] = data.get("creator")
|
logger.error(f"No modelId found in version {version_id}")
|
||||||
|
return None
|
||||||
|
|
||||||
return version
|
# Now get the model data for additional metadata
|
||||||
|
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return version # Return version without additional metadata
|
||||||
|
|
||||||
|
model_data = await response.json()
|
||||||
|
|
||||||
|
# Enrich version with model data
|
||||||
|
version['model']['description'] = model_data.get("description")
|
||||||
|
version['model']['tags'] = model_data.get("tags", [])
|
||||||
|
version['creator'] = model_data.get("creator")
|
||||||
|
|
||||||
|
return version
|
||||||
|
|
||||||
|
# Case 2: model_id is provided (with or without version_id)
|
||||||
|
elif model_id is not None:
|
||||||
|
# Step 1: Get model data to find version_id if not provided and get additional metadata
|
||||||
|
async with session.get(f"{self.base_url}/models/{model_id}") as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = await response.json()
|
||||||
|
model_versions = data.get('modelVersions', [])
|
||||||
|
|
||||||
|
# Step 2: Determine the version_id to use
|
||||||
|
target_version_id = version_id
|
||||||
|
if target_version_id is None:
|
||||||
|
target_version_id = model_versions[0].get('id')
|
||||||
|
|
||||||
|
# Step 3: Get detailed version info using the version_id
|
||||||
|
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
version = await response.json()
|
||||||
|
|
||||||
|
# Step 4: Enrich version_info with model data
|
||||||
|
# Add description and tags from model data
|
||||||
|
version['model']['description'] = data.get("description")
|
||||||
|
version['model']['tags'] = data.get("tags", [])
|
||||||
|
|
||||||
|
# Add creator from model data
|
||||||
|
version['creator'] = data.get("creator")
|
||||||
|
|
||||||
|
return version
|
||||||
|
|
||||||
|
# Case 3: Neither model_id nor version_id provided
|
||||||
|
else:
|
||||||
|
logger.error("Either model_id or version_id must be provided")
|
||||||
|
return None
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"Error fetching model version: {e}")
|
logger.error(f"Error fetching model version: {e}")
|
||||||
|
|||||||
@@ -54,15 +54,15 @@ class DownloadManager:
|
|||||||
"""Get the checkpoint scanner from registry"""
|
"""Get the checkpoint scanner from registry"""
|
||||||
return await ServiceRegistry.get_checkpoint_scanner()
|
return await ServiceRegistry.get_checkpoint_scanner()
|
||||||
|
|
||||||
async def download_from_civitai(self, model_id: int, model_version_id: int,
|
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
|
||||||
save_dir: str = None, relative_path: str = '',
|
save_dir: str = None, relative_path: str = '',
|
||||||
progress_callback=None, use_default_paths: bool = False,
|
progress_callback=None, use_default_paths: bool = False,
|
||||||
download_id: str = None) -> Dict:
|
download_id: str = None) -> Dict:
|
||||||
"""Download model from Civitai with task tracking and concurrency control
|
"""Download model from Civitai with task tracking and concurrency control
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: Civitai model ID
|
model_id: Civitai model ID (optional if model_version_id is provided)
|
||||||
model_version_id: Civitai model version ID
|
model_version_id: Civitai model version ID (optional if model_id is provided)
|
||||||
save_dir: Directory to save the model
|
save_dir: Directory to save the model
|
||||||
relative_path: Relative path within save_dir
|
relative_path: Relative path within save_dir
|
||||||
progress_callback: Callback function for progress updates
|
progress_callback: Callback function for progress updates
|
||||||
@@ -72,6 +72,10 @@ class DownloadManager:
|
|||||||
Returns:
|
Returns:
|
||||||
Dict with download result
|
Dict with download result
|
||||||
"""
|
"""
|
||||||
|
# Validate that at least one identifier is provided
|
||||||
|
if not model_id and not model_version_id:
|
||||||
|
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
|
||||||
|
|
||||||
# Use provided download_id or generate new one
|
# Use provided download_id or generate new one
|
||||||
task_id = download_id or str(uuid.uuid4())
|
task_id = download_id or str(uuid.uuid4())
|
||||||
|
|
||||||
@@ -181,15 +185,20 @@ class DownloadManager:
|
|||||||
# Check both scanners
|
# Check both scanners
|
||||||
lora_scanner = await self._get_lora_scanner()
|
lora_scanner = await self._get_lora_scanner()
|
||||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||||
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
|
|
||||||
# Check lora scanner first
|
# Check lora scanner first
|
||||||
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
|
if await lora_scanner.check_model_version_exists(model_version_id):
|
||||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||||
|
|
||||||
# Check checkpoint scanner
|
# Check checkpoint scanner
|
||||||
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
|
if await checkpoint_scanner.check_model_version_exists(model_version_id):
|
||||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||||
|
|
||||||
|
# Check embedding scanner
|
||||||
|
if await embedding_scanner.check_model_version_exists(model_version_id):
|
||||||
|
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||||
|
|
||||||
# Get civitai client
|
# Get civitai client
|
||||||
civitai_client = await self._get_civitai_client()
|
civitai_client = await self._get_civitai_client()
|
||||||
|
|
||||||
@@ -211,23 +220,22 @@ class DownloadManager:
|
|||||||
|
|
||||||
# Case 2: model_version_id was None, check after getting version_info
|
# Case 2: model_version_id was None, check after getting version_info
|
||||||
if model_version_id is None:
|
if model_version_id is None:
|
||||||
version_model_id = version_info.get('modelId')
|
|
||||||
version_id = version_info.get('id')
|
version_id = version_info.get('id')
|
||||||
|
|
||||||
if model_type == 'lora':
|
if model_type == 'lora':
|
||||||
# Check lora scanner
|
# Check lora scanner
|
||||||
lora_scanner = await self._get_lora_scanner()
|
lora_scanner = await self._get_lora_scanner()
|
||||||
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
|
if await lora_scanner.check_model_version_exists(version_id):
|
||||||
return {'success': False, 'error': 'Model version already exists in lora library'}
|
return {'success': False, 'error': 'Model version already exists in lora library'}
|
||||||
elif model_type == 'checkpoint':
|
elif model_type == 'checkpoint':
|
||||||
# Check checkpoint scanner
|
# Check checkpoint scanner
|
||||||
checkpoint_scanner = await self._get_checkpoint_scanner()
|
checkpoint_scanner = await self._get_checkpoint_scanner()
|
||||||
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
|
if await checkpoint_scanner.check_model_version_exists(version_id):
|
||||||
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
|
||||||
elif model_type == 'embedding':
|
elif model_type == 'embedding':
|
||||||
# Embeddings are not checked in scanners, but we can still check if it exists
|
# Embeddings are not checked in scanners, but we can still check if it exists
|
||||||
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
|
||||||
if await embedding_scanner.check_model_version_exists(version_model_id, version_id):
|
if await embedding_scanner.check_model_version_exists(version_id):
|
||||||
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
return {'success': False, 'error': 'Model version already exists in embedding library'}
|
||||||
|
|
||||||
# Handle use_default_paths
|
# Handle use_default_paths
|
||||||
|
|||||||
@@ -1149,11 +1149,10 @@ class ModelScanner:
|
|||||||
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
|
||||||
del self._hash_index._duplicate_filenames[file_name]
|
del self._hash_index._duplicate_filenames[file_name]
|
||||||
|
|
||||||
async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool:
|
async def check_model_version_exists(self, model_version_id: int) -> bool:
|
||||||
"""Check if a specific model version exists in the cache
|
"""Check if a specific model version exists in the cache
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_id: Civitai model ID
|
|
||||||
model_version_id: Civitai model version ID
|
model_version_id: Civitai model version ID
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
@@ -1165,9 +1164,7 @@ class ModelScanner:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
for item in cache.raw_data:
|
for item in cache.raw_data:
|
||||||
if (item.get('civitai') and
|
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
|
||||||
item['civitai'].get('modelId') == model_id and
|
|
||||||
item['civitai'].get('id') == model_version_id):
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
|
|||||||
@@ -580,16 +580,19 @@ class ModelRouteUtils:
|
|||||||
})
|
})
|
||||||
|
|
||||||
# Check which identifier is provided and convert to int
|
# Check which identifier is provided and convert to int
|
||||||
try:
|
model_id = None
|
||||||
model_id = int(data.get('model_id'))
|
model_version_id = None
|
||||||
except (TypeError, ValueError):
|
|
||||||
return web.json_response({
|
if data.get('model_id'):
|
||||||
'success': False,
|
try:
|
||||||
'error': "Invalid model_id: Must be an integer"
|
model_id = int(data.get('model_id'))
|
||||||
}, status=400)
|
except (TypeError, ValueError):
|
||||||
|
return web.json_response({
|
||||||
|
'success': False,
|
||||||
|
'error': "Invalid model_id: Must be an integer"
|
||||||
|
}, status=400)
|
||||||
|
|
||||||
# Convert model_version_id to int if provided
|
# Convert model_version_id to int if provided
|
||||||
model_version_id = None
|
|
||||||
if data.get('model_version_id'):
|
if data.get('model_version_id'):
|
||||||
try:
|
try:
|
||||||
model_version_id = int(data.get('model_version_id'))
|
model_version_id = int(data.get('model_version_id'))
|
||||||
@@ -599,11 +602,11 @@ class ModelRouteUtils:
|
|||||||
'error': "Invalid model_version_id: Must be an integer"
|
'error': "Invalid model_version_id: Must be an integer"
|
||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
# Only model_id is required, model_version_id is optional
|
# At least one identifier is required
|
||||||
if not model_id:
|
if not model_id and not model_version_id:
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'success': False,
|
'success': False,
|
||||||
'error': "Missing required parameter: Please provide 'model_id'"
|
'error': "Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
|
||||||
}, status=400)
|
}, status=400)
|
||||||
|
|
||||||
use_default_paths = data.get('use_default_paths', False)
|
use_default_paths = data.get('use_default_paths', False)
|
||||||
|
|||||||
@@ -17,16 +17,16 @@ export function createModelApiClient(modelType) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let _singletonClient = null;
|
let _singletonClients = new Map();
|
||||||
|
|
||||||
export function getModelApiClient() {
|
export function getModelApiClient(modelType = null) {
|
||||||
const currentType = state.currentPageType;
|
const targetType = modelType || state.currentPageType;
|
||||||
|
|
||||||
if (!_singletonClient || _singletonClient.modelType !== currentType) {
|
if (!_singletonClients.has(targetType)) {
|
||||||
_singletonClient = createModelApiClient(currentType);
|
_singletonClients.set(targetType, createModelApiClient(targetType));
|
||||||
}
|
}
|
||||||
|
|
||||||
return _singletonClient;
|
return _singletonClients.get(targetType);
|
||||||
}
|
}
|
||||||
|
|
||||||
export function resetAndReload(updateFolders = false) {
|
export function resetAndReload(updateFolders = false) {
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import { showToast } from '../../utils/uiHelpers.js';
|
import { showToast } from '../../utils/uiHelpers.js';
|
||||||
|
import { getModelApiClient } from '../../api/modelApiFactory.js';
|
||||||
|
import { MODEL_TYPES } from '../../api/apiConfig.js';
|
||||||
|
|
||||||
export class DownloadManager {
|
export class DownloadManager {
|
||||||
constructor(importManager) {
|
constructor(importManager) {
|
||||||
@@ -199,30 +201,18 @@ export class DownloadManager {
|
|||||||
updateProgress(0, completedDownloads, lora.name);
|
updateProgress(0, completedDownloads, lora.name);
|
||||||
|
|
||||||
try {
|
try {
|
||||||
|
console.log(`lora:`, lora);
|
||||||
// Download the LoRA with download ID
|
// Download the LoRA with download ID
|
||||||
const response = await fetch('/api/download-model', {
|
const response = await getModelApiClient(MODEL_TYPES.LORA).downloadModel(
|
||||||
method: 'POST',
|
lora.modelId,
|
||||||
headers: { 'Content-Type': 'application/json' },
|
lora.modelVersionId,
|
||||||
body: JSON.stringify({
|
loraRoot,
|
||||||
model_id: lora.modelId,
|
targetPath.replace(loraRoot + '/', ''),
|
||||||
model_version_id: lora.id,
|
batchDownloadId
|
||||||
model_root: loraRoot,
|
);
|
||||||
relative_path: targetPath.replace(loraRoot + '/', ''),
|
|
||||||
download_id: batchDownloadId
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.success) {
|
||||||
const errorText = await response.text();
|
console.error(`Failed to download LoRA ${lora.name}: ${response.error}`);
|
||||||
console.error(`Failed to download LoRA ${lora.name}: ${errorText}`);
|
|
||||||
|
|
||||||
// Check if this is an early access error (status 401 is the key indicator)
|
|
||||||
if (response.status === 401) {
|
|
||||||
accessFailures++;
|
|
||||||
this.importManager.loadingManager.setStatus(
|
|
||||||
`Failed to download ${lora.name}: Access restricted`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
failedDownloads++;
|
failedDownloads++;
|
||||||
// Continue with next download
|
// Continue with next download
|
||||||
|
|||||||
Reference in New Issue
Block a user