mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-21 21:22:11 -03:00
feat: Add Civitai model version retrieval for Checkpoints and update error handling in download managers
This commit is contained in:
@@ -358,10 +358,19 @@ class ApiRoutes:
|
||||
self.civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
versions = await self.civitai_client.get_model_versions(model_id)
|
||||
if not versions:
|
||||
response = await self.civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be LORA
|
||||
if model_type.lower() != 'lora':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected LORA, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the model file (type="Model") in the files list
|
||||
@@ -372,9 +381,9 @@ class ApiRoutes:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_lora_hash(sha256)
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_lora_path_by_hash(sha256)
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
|
||||
@@ -45,6 +45,7 @@ class CheckpointsRoutes:
|
||||
app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints)
|
||||
app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info)
|
||||
app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots)
|
||||
app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route
|
||||
|
||||
# Add new routes for model management similar to LoRA routes
|
||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||
@@ -565,3 +566,56 @@ class CheckpointsRoutes:
|
||||
except Exception as e:
|
||||
logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True)
|
||||
return web.Response(text=str(e), status=500)
|
||||
|
||||
async def get_civitai_versions(self, request: web.Request) -> web.Response:
|
||||
"""Get available versions for a Civitai checkpoint model with local availability info"""
|
||||
try:
|
||||
if self.scanner is None:
|
||||
self.scanner = await ServiceRegistry.get_checkpoint_scanner()
|
||||
|
||||
# Get the civitai client from service registry
|
||||
civitai_client = await ServiceRegistry.get_civitai_client()
|
||||
|
||||
model_id = request.match_info['model_id']
|
||||
response = await civitai_client.get_model_versions(model_id)
|
||||
if not response or not response.get('modelVersions'):
|
||||
return web.Response(status=404, text="Model not found")
|
||||
|
||||
versions = response.get('modelVersions', [])
|
||||
model_type = response.get('type', '')
|
||||
|
||||
# Check model type - should be Checkpoint
|
||||
if model_type.lower() != 'checkpoint':
|
||||
return web.json_response({
|
||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||
}, status=400)
|
||||
|
||||
# Check local availability for each version
|
||||
for version in versions:
|
||||
# Find the primary model file (type="Model" and primary=true) in the files list
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model' and file.get('primary') == True), None)
|
||||
|
||||
# If no primary file found, try to find any model file
|
||||
if not model_file:
|
||||
model_file = next((file for file in version.get('files', [])
|
||||
if file.get('type') == 'Model'), None)
|
||||
|
||||
if model_file:
|
||||
sha256 = model_file.get('hashes', {}).get('SHA256')
|
||||
if sha256:
|
||||
# Set existsLocally and localPath at the version level
|
||||
version['existsLocally'] = self.scanner.has_hash(sha256)
|
||||
if version['existsLocally']:
|
||||
version['localPath'] = self.scanner.get_path_by_hash(sha256)
|
||||
|
||||
# Also set the model file size at the version level for easier access
|
||||
version['modelSizeKB'] = model_file.get('sizeKB')
|
||||
else:
|
||||
# No model file found in this version
|
||||
version['existsLocally'] = False
|
||||
|
||||
return web.json_response(versions)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching checkpoint model versions: {e}")
|
||||
return web.Response(status=500, text=str(e))
|
||||
|
||||
@@ -172,7 +172,11 @@ class CivitaiClient:
|
||||
if response.status != 200:
|
||||
return None
|
||||
data = await response.json()
|
||||
return data.get('modelVersions', [])
|
||||
# Also return model type along with versions
|
||||
return {
|
||||
'modelVersions': data.get('modelVersions', []),
|
||||
'type': data.get('type', '')
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching model versions: {e}")
|
||||
return None
|
||||
|
||||
@@ -293,15 +293,15 @@ class LoraScanner(ModelScanner):
|
||||
# Lora-specific hash index functionality
|
||||
def has_lora_hash(self, sha256: str) -> bool:
|
||||
"""Check if a LoRA with given hash exists"""
|
||||
return self._hash_index.has_hash(sha256.lower())
|
||||
return self.has_hash(sha256)
|
||||
|
||||
def get_lora_path_by_hash(self, sha256: str) -> Optional[str]:
|
||||
"""Get file path for a LoRA by its hash"""
|
||||
return self._hash_index.get_path(sha256.lower())
|
||||
return self.get_path_by_hash(sha256)
|
||||
|
||||
def get_lora_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||
"""Get hash for a LoRA by its file path"""
|
||||
return self._hash_index.get_hash(file_path)
|
||||
return self.get_hash_by_path(file_path)
|
||||
|
||||
async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]:
|
||||
"""Get top tags sorted by count"""
|
||||
|
||||
@@ -76,8 +76,12 @@ export class CheckpointDownloadManager {
|
||||
throw new Error('Invalid Civitai URL format');
|
||||
}
|
||||
|
||||
const response = await fetch(`/api/civitai/versions/${modelId}`);
|
||||
const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`);
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||
throw new Error('This model is not a Checkpoint. Please switch to the LoRAs page to download LoRA models.');
|
||||
}
|
||||
throw new Error('Failed to fetch model versions');
|
||||
}
|
||||
|
||||
|
||||
@@ -80,6 +80,10 @@ export class DownloadManager {
|
||||
|
||||
const response = await fetch(`/api/civitai/versions/${modelId}`);
|
||||
if (!response.ok) {
|
||||
const errorData = await response.json().catch(() => ({}));
|
||||
if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) {
|
||||
throw new Error('This model is not a LoRA. Please switch to the Checkpoints page to download checkpoint models.');
|
||||
}
|
||||
throw new Error('Failed to fetch model versions');
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user