fix: update model_id and model_version_id handling across various services for improved flexibility

This commit is contained in:
Will Miao
2025-08-11 15:31:49 +08:00
parent b03420faac
commit 96517cbdef
7 changed files with 126 additions and 93 deletions

View File

@@ -654,13 +654,13 @@ class MiscRoutes:
exists = False
model_type = None
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
if await lora_scanner.check_model_version_exists(model_version_id):
exists = True
model_type = 'lora'
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
exists = True
model_type = 'checkpoint'
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id):
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
exists = True
model_type = 'embedding'

View File

@@ -223,11 +223,11 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}")
return None
async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]:
async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata
Args:
model_id: The Civitai model ID
model_id: The Civitai model ID (optional if version_id is provided)
version_id: Optional specific version ID to retrieve
Returns:
@@ -235,37 +235,72 @@ class CivitaiClient:
"""
try:
session = await self._ensure_fresh_session()
# Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return None
data = await response.json()
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 3: Get detailed version info using the version_id
headers = self._get_request_headers()
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
if response.status != 200:
return None
# Case 1: Only version_id is provided
if model_id is None and version_id is not None:
# First get the version info to extract model_id
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
if response.status != 200:
return None
version = await response.json()
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
version = await response.json()
# Now get the model data for additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return version # Return version without additional metadata
model_data = await response.json()
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
return version
# Case 2: model_id is provided (with or without version_id)
elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return None
data = await response.json()
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 4: Enrich version_info with model data
# Add description and tags from model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
return version
# Step 3: Get detailed version info using the version_id
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
if response.status != 200:
return None
version = await response.json()
# Step 4: Enrich version_info with model data
# Add description and tags from model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e:
logger.error(f"Error fetching model version: {e}")

View File

@@ -54,15 +54,15 @@ class DownloadManager:
"""Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, model_id: int, model_version_id: int,
async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control
Args:
model_id: Civitai model ID
model_version_id: Civitai model version ID
model_id: Civitai model ID (optional if model_version_id is provided)
model_version_id: Civitai model version ID (optional if model_id is provided)
save_dir: Directory to save the model
relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates
@@ -72,6 +72,10 @@ class DownloadManager:
Returns:
Dict with download result
"""
# Validate that at least one identifier is provided
if not model_id and not model_version_id:
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
# Use provided download_id or generate new one
task_id = download_id or str(uuid.uuid4())
@@ -181,14 +185,19 @@ class DownloadManager:
# Check both scanners
lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Check lora scanner first
if await lora_scanner.check_model_version_exists(model_id, model_version_id):
if await lora_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
# Check checkpoint scanner
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id):
if await checkpoint_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Check embedding scanner
if await embedding_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get civitai client
civitai_client = await self._get_civitai_client()
@@ -211,23 +220,22 @@ class DownloadManager:
# Case 2: model_version_id was None, check after getting version_info
if model_version_id is None:
version_model_id = version_info.get('modelId')
version_id = version_info.get('id')
if model_type == 'lora':
# Check lora scanner
lora_scanner = await self._get_lora_scanner()
if await lora_scanner.check_model_version_exists(version_model_id, version_id):
if await lora_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in lora library'}
elif model_type == 'checkpoint':
# Check checkpoint scanner
checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id):
if await checkpoint_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'}
elif model_type == 'embedding':
# Embeddings are not checked in scanners, but we can still check if it exists
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
if await embedding_scanner.check_model_version_exists(version_model_id, version_id):
if await embedding_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Handle use_default_paths

View File

@@ -1149,13 +1149,12 @@ class ModelScanner:
if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
del self._hash_index._duplicate_filenames[file_name]
async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool:
async def check_model_version_exists(self, model_version_id: int) -> bool:
"""Check if a specific model version exists in the cache
Args:
model_id: Civitai model ID
model_version_id: Civitai model version ID
Returns:
bool: True if the model version exists, False otherwise
"""
@@ -1163,13 +1162,11 @@ class ModelScanner:
cache = await self.get_cached_data()
if not cache or not cache.raw_data:
return False
for item in cache.raw_data:
if (item.get('civitai') and
item['civitai'].get('modelId') == model_id and
item['civitai'].get('id') == model_version_id):
if item.get('civitai') and item['civitai'].get('id') == model_version_id:
return True
return False
except Exception as e:
logger.error(f"Error checking model version existence: {e}")

View File

@@ -580,16 +580,19 @@ class ModelRouteUtils:
})
# Check which identifier is provided and convert to int
try:
model_id = int(data.get('model_id'))
except (TypeError, ValueError):
return web.json_response({
'success': False,
'error': "Invalid model_id: Must be an integer"
}, status=400)
model_id = None
model_version_id = None
if data.get('model_id'):
try:
model_id = int(data.get('model_id'))
except (TypeError, ValueError):
return web.json_response({
'success': False,
'error': "Invalid model_id: Must be an integer"
}, status=400)
# Convert model_version_id to int if provided
model_version_id = None
if data.get('model_version_id'):
try:
model_version_id = int(data.get('model_version_id'))
@@ -599,11 +602,11 @@ class ModelRouteUtils:
'error': "Invalid model_version_id: Must be an integer"
}, status=400)
# Only model_id is required, model_version_id is optional
if not model_id:
# At least one identifier is required
if not model_id and not model_version_id:
return web.json_response({
'success': False,
'error': "Missing required parameter: Please provide 'model_id'"
'error': "Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
}, status=400)
use_default_paths = data.get('use_default_paths', False)

View File

@@ -17,16 +17,16 @@ export function createModelApiClient(modelType) {
}
}
let _singletonClient = null;
let _singletonClients = new Map();
export function getModelApiClient() {
const currentType = state.currentPageType;
export function getModelApiClient(modelType = null) {
const targetType = modelType || state.currentPageType;
if (!_singletonClient || _singletonClient.modelType !== currentType) {
_singletonClient = createModelApiClient(currentType);
if (!_singletonClients.has(targetType)) {
_singletonClients.set(targetType, createModelApiClient(targetType));
}
return _singletonClient;
return _singletonClients.get(targetType);
}
export function resetAndReload(updateFolders = false) {

View File

@@ -1,4 +1,6 @@
import { showToast } from '../../utils/uiHelpers.js';
import { getModelApiClient } from '../../api/modelApiFactory.js';
import { MODEL_TYPES } from '../../api/apiConfig.js';
export class DownloadManager {
constructor(importManager) {
@@ -199,39 +201,27 @@ export class DownloadManager {
updateProgress(0, completedDownloads, lora.name);
try {
console.log(`lora:`, lora);
// Download the LoRA with download ID
const response = await fetch('/api/download-model', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model_id: lora.modelId,
model_version_id: lora.id,
model_root: loraRoot,
relative_path: targetPath.replace(loraRoot + '/', ''),
download_id: batchDownloadId
})
});
const response = await getModelApiClient(MODEL_TYPES.LORA).downloadModel(
lora.modelId,
lora.modelVersionId,
loraRoot,
targetPath.replace(loraRoot + '/', ''),
batchDownloadId
);
if (!response.ok) {
const errorText = await response.text();
console.error(`Failed to download LoRA ${lora.name}: ${errorText}`);
// Check if this is an early access error (status 401 is the key indicator)
if (response.status === 401) {
accessFailures++;
this.importManager.loadingManager.setStatus(
`Failed to download ${lora.name}: Access restricted`
);
}
if (!response.success) {
console.error(`Failed to download LoRA ${lora.name}: ${response.error}`);
failedDownloads++;
// Continue with next download
} else {
completedDownloads++;
// Update progress to show completion of current LoRA
updateProgress(100, completedDownloads, '');
if (completedDownloads + failedDownloads < this.importManager.downloadableLoRAs.length) {
this.importManager.loadingManager.setStatus(
`Completed ${completedDownloads}/${this.importManager.downloadableLoRAs.length} LoRAs. Starting next download...`