fix: update model_id and model_version_id handling across various services for improved flexibility

This commit is contained in:
Will Miao
2025-08-11 15:31:49 +08:00
parent b03420faac
commit 96517cbdef
7 changed files with 126 additions and 93 deletions

View File

@@ -654,13 +654,13 @@ class MiscRoutes:
exists = False exists = False
model_type = None model_type = None
if await lora_scanner.check_model_version_exists(model_id, model_version_id): if await lora_scanner.check_model_version_exists(model_version_id):
exists = True exists = True
model_type = 'lora' model_type = 'lora'
elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): elif checkpoint_scanner and await checkpoint_scanner.check_model_version_exists(model_version_id):
exists = True exists = True
model_type = 'checkpoint' model_type = 'checkpoint'
elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_id, model_version_id): elif embedding_scanner and await embedding_scanner.check_model_version_exists(model_version_id):
exists = True exists = True
model_type = 'embedding' model_type = 'embedding'

View File

@@ -223,11 +223,11 @@ class CivitaiClient:
logger.error(f"Error fetching model versions: {e}") logger.error(f"Error fetching model versions: {e}")
return None return None
async def get_model_version(self, model_id: int, version_id: int = None) -> Optional[Dict]: async def get_model_version(self, model_id: int = None, version_id: int = None) -> Optional[Dict]:
"""Get specific model version with additional metadata """Get specific model version with additional metadata
Args: Args:
model_id: The Civitai model ID model_id: The Civitai model ID (optional if version_id is provided)
version_id: Optional specific version ID to retrieve version_id: Optional specific version ID to retrieve
Returns: Returns:
@@ -235,37 +235,72 @@ class CivitaiClient:
""" """
try: try:
session = await self._ensure_fresh_session() session = await self._ensure_fresh_session()
# Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return None
data = await response.json()
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 3: Get detailed version info using the version_id
headers = self._get_request_headers() headers = self._get_request_headers()
async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
if response.status != 200: # Case 1: Only version_id is provided
return None if model_id is None and version_id is not None:
# First get the version info to extract model_id
async with session.get(f"{self.base_url}/model-versions/{version_id}", headers=headers) as response:
if response.status != 200:
return None
version = await response.json()
model_id = version.get('modelId')
if not model_id:
logger.error(f"No modelId found in version {version_id}")
return None
version = await response.json() # Now get the model data for additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return version # Return version without additional metadata
model_data = await response.json()
# Enrich version with model data
version['model']['description'] = model_data.get("description")
version['model']['tags'] = model_data.get("tags", [])
version['creator'] = model_data.get("creator")
return version
# Case 2: model_id is provided (with or without version_id)
elif model_id is not None:
# Step 1: Get model data to find version_id if not provided and get additional metadata
async with session.get(f"{self.base_url}/models/{model_id}") as response:
if response.status != 200:
return None
data = await response.json()
model_versions = data.get('modelVersions', [])
# Step 2: Determine the version_id to use
target_version_id = version_id
if target_version_id is None:
target_version_id = model_versions[0].get('id')
# Step 4: Enrich version_info with model data # Step 3: Get detailed version info using the version_id
# Add description and tags from model data async with session.get(f"{self.base_url}/model-versions/{target_version_id}", headers=headers) as response:
version['model']['description'] = data.get("description") if response.status != 200:
version['model']['tags'] = data.get("tags", []) return None
# Add creator from model data version = await response.json()
version['creator'] = data.get("creator")
# Step 4: Enrich version_info with model data
return version # Add description and tags from model data
version['model']['description'] = data.get("description")
version['model']['tags'] = data.get("tags", [])
# Add creator from model data
version['creator'] = data.get("creator")
return version
# Case 3: Neither model_id nor version_id provided
else:
logger.error("Either model_id or version_id must be provided")
return None
except Exception as e: except Exception as e:
logger.error(f"Error fetching model version: {e}") logger.error(f"Error fetching model version: {e}")

View File

@@ -54,15 +54,15 @@ class DownloadManager:
"""Get the checkpoint scanner from registry""" """Get the checkpoint scanner from registry"""
return await ServiceRegistry.get_checkpoint_scanner() return await ServiceRegistry.get_checkpoint_scanner()
async def download_from_civitai(self, model_id: int, model_version_id: int, async def download_from_civitai(self, model_id: int = None, model_version_id: int = None,
save_dir: str = None, relative_path: str = '', save_dir: str = None, relative_path: str = '',
progress_callback=None, use_default_paths: bool = False, progress_callback=None, use_default_paths: bool = False,
download_id: str = None) -> Dict: download_id: str = None) -> Dict:
"""Download model from Civitai with task tracking and concurrency control """Download model from Civitai with task tracking and concurrency control
Args: Args:
model_id: Civitai model ID model_id: Civitai model ID (optional if model_version_id is provided)
model_version_id: Civitai model version ID model_version_id: Civitai model version ID (optional if model_id is provided)
save_dir: Directory to save the model save_dir: Directory to save the model
relative_path: Relative path within save_dir relative_path: Relative path within save_dir
progress_callback: Callback function for progress updates progress_callback: Callback function for progress updates
@@ -72,6 +72,10 @@ class DownloadManager:
Returns: Returns:
Dict with download result Dict with download result
""" """
# Validate that at least one identifier is provided
if not model_id and not model_version_id:
return {'success': False, 'error': 'Either model_id or model_version_id must be provided'}
# Use provided download_id or generate new one # Use provided download_id or generate new one
task_id = download_id or str(uuid.uuid4()) task_id = download_id or str(uuid.uuid4())
@@ -181,14 +185,19 @@ class DownloadManager:
# Check both scanners # Check both scanners
lora_scanner = await self._get_lora_scanner() lora_scanner = await self._get_lora_scanner()
checkpoint_scanner = await self._get_checkpoint_scanner() checkpoint_scanner = await self._get_checkpoint_scanner()
embedding_scanner = await ServiceRegistry.get_embedding_scanner()
# Check lora scanner first # Check lora scanner first
if await lora_scanner.check_model_version_exists(model_id, model_version_id): if await lora_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in lora library'} return {'success': False, 'error': 'Model version already exists in lora library'}
# Check checkpoint scanner # Check checkpoint scanner
if await checkpoint_scanner.check_model_version_exists(model_id, model_version_id): if await checkpoint_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'} return {'success': False, 'error': 'Model version already exists in checkpoint library'}
# Check embedding scanner
if await embedding_scanner.check_model_version_exists(model_version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'}
# Get civitai client # Get civitai client
civitai_client = await self._get_civitai_client() civitai_client = await self._get_civitai_client()
@@ -211,23 +220,22 @@ class DownloadManager:
# Case 2: model_version_id was None, check after getting version_info # Case 2: model_version_id was None, check after getting version_info
if model_version_id is None: if model_version_id is None:
version_model_id = version_info.get('modelId')
version_id = version_info.get('id') version_id = version_info.get('id')
if model_type == 'lora': if model_type == 'lora':
# Check lora scanner # Check lora scanner
lora_scanner = await self._get_lora_scanner() lora_scanner = await self._get_lora_scanner()
if await lora_scanner.check_model_version_exists(version_model_id, version_id): if await lora_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in lora library'} return {'success': False, 'error': 'Model version already exists in lora library'}
elif model_type == 'checkpoint': elif model_type == 'checkpoint':
# Check checkpoint scanner # Check checkpoint scanner
checkpoint_scanner = await self._get_checkpoint_scanner() checkpoint_scanner = await self._get_checkpoint_scanner()
if await checkpoint_scanner.check_model_version_exists(version_model_id, version_id): if await checkpoint_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in checkpoint library'} return {'success': False, 'error': 'Model version already exists in checkpoint library'}
elif model_type == 'embedding': elif model_type == 'embedding':
# Embeddings are not checked in scanners, but we can still check if it exists # Embeddings are not checked in scanners, but we can still check if it exists
embedding_scanner = await ServiceRegistry.get_embedding_scanner() embedding_scanner = await ServiceRegistry.get_embedding_scanner()
if await embedding_scanner.check_model_version_exists(version_model_id, version_id): if await embedding_scanner.check_model_version_exists(version_id):
return {'success': False, 'error': 'Model version already exists in embedding library'} return {'success': False, 'error': 'Model version already exists in embedding library'}
# Handle use_default_paths # Handle use_default_paths

View File

@@ -1149,13 +1149,12 @@ class ModelScanner:
if len(self._hash_index._duplicate_filenames[file_name]) <= 1: if len(self._hash_index._duplicate_filenames[file_name]) <= 1:
del self._hash_index._duplicate_filenames[file_name] del self._hash_index._duplicate_filenames[file_name]
async def check_model_version_exists(self, model_id: int, model_version_id: int) -> bool: async def check_model_version_exists(self, model_version_id: int) -> bool:
"""Check if a specific model version exists in the cache """Check if a specific model version exists in the cache
Args: Args:
model_id: Civitai model ID
model_version_id: Civitai model version ID model_version_id: Civitai model version ID
Returns: Returns:
bool: True if the model version exists, False otherwise bool: True if the model version exists, False otherwise
""" """
@@ -1163,13 +1162,11 @@ class ModelScanner:
cache = await self.get_cached_data() cache = await self.get_cached_data()
if not cache or not cache.raw_data: if not cache or not cache.raw_data:
return False return False
for item in cache.raw_data: for item in cache.raw_data:
if (item.get('civitai') and if item.get('civitai') and item['civitai'].get('id') == model_version_id:
item['civitai'].get('modelId') == model_id and
item['civitai'].get('id') == model_version_id):
return True return True
return False return False
except Exception as e: except Exception as e:
logger.error(f"Error checking model version existence: {e}") logger.error(f"Error checking model version existence: {e}")

View File

@@ -580,16 +580,19 @@ class ModelRouteUtils:
}) })
# Check which identifier is provided and convert to int # Check which identifier is provided and convert to int
try: model_id = None
model_id = int(data.get('model_id')) model_version_id = None
except (TypeError, ValueError):
return web.json_response({ if data.get('model_id'):
'success': False, try:
'error': "Invalid model_id: Must be an integer" model_id = int(data.get('model_id'))
}, status=400) except (TypeError, ValueError):
return web.json_response({
'success': False,
'error': "Invalid model_id: Must be an integer"
}, status=400)
# Convert model_version_id to int if provided # Convert model_version_id to int if provided
model_version_id = None
if data.get('model_version_id'): if data.get('model_version_id'):
try: try:
model_version_id = int(data.get('model_version_id')) model_version_id = int(data.get('model_version_id'))
@@ -599,11 +602,11 @@ class ModelRouteUtils:
'error': "Invalid model_version_id: Must be an integer" 'error': "Invalid model_version_id: Must be an integer"
}, status=400) }, status=400)
# Only model_id is required, model_version_id is optional # At least one identifier is required
if not model_id: if not model_id and not model_version_id:
return web.json_response({ return web.json_response({
'success': False, 'success': False,
'error': "Missing required parameter: Please provide 'model_id'" 'error': "Missing required parameter: Please provide either 'model_id' or 'model_version_id'"
}, status=400) }, status=400)
use_default_paths = data.get('use_default_paths', False) use_default_paths = data.get('use_default_paths', False)

View File

@@ -17,16 +17,16 @@ export function createModelApiClient(modelType) {
} }
} }
let _singletonClient = null; let _singletonClients = new Map();
export function getModelApiClient() { export function getModelApiClient(modelType = null) {
const currentType = state.currentPageType; const targetType = modelType || state.currentPageType;
if (!_singletonClient || _singletonClient.modelType !== currentType) { if (!_singletonClients.has(targetType)) {
_singletonClient = createModelApiClient(currentType); _singletonClients.set(targetType, createModelApiClient(targetType));
} }
return _singletonClient; return _singletonClients.get(targetType);
} }
export function resetAndReload(updateFolders = false) { export function resetAndReload(updateFolders = false) {

View File

@@ -1,4 +1,6 @@
import { showToast } from '../../utils/uiHelpers.js'; import { showToast } from '../../utils/uiHelpers.js';
import { getModelApiClient } from '../../api/modelApiFactory.js';
import { MODEL_TYPES } from '../../api/apiConfig.js';
export class DownloadManager { export class DownloadManager {
constructor(importManager) { constructor(importManager) {
@@ -199,39 +201,27 @@ export class DownloadManager {
updateProgress(0, completedDownloads, lora.name); updateProgress(0, completedDownloads, lora.name);
try { try {
console.log(`lora:`, lora);
// Download the LoRA with download ID // Download the LoRA with download ID
const response = await fetch('/api/download-model', { const response = await getModelApiClient(MODEL_TYPES.LORA).downloadModel(
method: 'POST', lora.modelId,
headers: { 'Content-Type': 'application/json' }, lora.modelVersionId,
body: JSON.stringify({ loraRoot,
model_id: lora.modelId, targetPath.replace(loraRoot + '/', ''),
model_version_id: lora.id, batchDownloadId
model_root: loraRoot, );
relative_path: targetPath.replace(loraRoot + '/', ''),
download_id: batchDownloadId
})
});
if (!response.ok) { if (!response.success) {
const errorText = await response.text(); console.error(`Failed to download LoRA ${lora.name}: ${response.error}`);
console.error(`Failed to download LoRA ${lora.name}: ${errorText}`);
// Check if this is an early access error (status 401 is the key indicator)
if (response.status === 401) {
accessFailures++;
this.importManager.loadingManager.setStatus(
`Failed to download ${lora.name}: Access restricted`
);
}
failedDownloads++; failedDownloads++;
// Continue with next download // Continue with next download
} else { } else {
completedDownloads++; completedDownloads++;
// Update progress to show completion of current LoRA // Update progress to show completion of current LoRA
updateProgress(100, completedDownloads, ''); updateProgress(100, completedDownloads, '');
if (completedDownloads + failedDownloads < this.importManager.downloadableLoRAs.length) { if (completedDownloads + failedDownloads < this.importManager.downloadableLoRAs.length) {
this.importManager.loadingManager.setStatus( this.importManager.loadingManager.setStatus(
`Completed ${completedDownloads}/${this.importManager.downloadableLoRAs.length} LoRAs. Starting next download...` `Completed ${completedDownloads}/${this.importManager.downloadableLoRAs.length} LoRAs. Starting next download...`