diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py index 876e67a1..a2969b60 100644 --- a/py/routes/api_routes.py +++ b/py/routes/api_routes.py @@ -358,10 +358,19 @@ class ApiRoutes: self.civitai_client = await ServiceRegistry.get_civitai_client() model_id = request.match_info['model_id'] - versions = await self.civitai_client.get_model_versions(model_id) - if not versions: + response = await self.civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): return web.Response(status=404, text="Model not found") + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be LORA + if model_type.lower() != 'lora': + return web.json_response({ + 'error': f"Model type mismatch. Expected LORA, got {model_type}" + }, status=400) + # Check local availability for each version for version in versions: # Find the model file (type="Model") in the files list @@ -372,9 +381,9 @@ class ApiRoutes: sha256 = model_file.get('hashes', {}).get('SHA256') if sha256: # Set existsLocally and localPath at the version level - version['existsLocally'] = self.scanner.has_lora_hash(sha256) + version['existsLocally'] = self.scanner.has_hash(sha256) if version['existsLocally']: - version['localPath'] = self.scanner.get_lora_path_by_hash(sha256) + version['localPath'] = self.scanner.get_path_by_hash(sha256) # Also set the model file size at the version level for easier access version['modelSizeKB'] = model_file.get('sizeKB') diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py index 838adba3..8cc35555 100644 --- a/py/routes/checkpoints_routes.py +++ b/py/routes/checkpoints_routes.py @@ -45,6 +45,7 @@ class CheckpointsRoutes: app.router.add_get('/api/checkpoints/scan', self.scan_checkpoints) app.router.add_get('/api/checkpoints/info/{name}', self.get_checkpoint_info) app.router.add_get('/api/checkpoints/roots', self.get_checkpoint_roots) + app.router.add_get('/api/checkpoints/civitai/versions/{model_id}', self.get_civitai_versions) # Add new route # Add new routes for model management similar to LoRA routes app.router.add_post('/api/checkpoints/delete', self.delete_model) @@ -565,3 +566,56 @@ class CheckpointsRoutes: except Exception as e: logger.error(f"Error saving checkpoint metadata: {e}", exc_info=True) return web.Response(text=str(e), status=500) + + async def get_civitai_versions(self, request: web.Request) -> web.Response: + """Get available versions for a Civitai checkpoint model with local availability info""" + try: + if self.scanner is None: + self.scanner = await ServiceRegistry.get_checkpoint_scanner() + + # Get the civitai client from service registry + civitai_client = await ServiceRegistry.get_civitai_client() + + model_id = request.match_info['model_id'] + response = await civitai_client.get_model_versions(model_id) + if not response or not response.get('modelVersions'): + return web.Response(status=404, text="Model not found") + + versions = response.get('modelVersions', []) + model_type = response.get('type', '') + + # Check model type - should be Checkpoint + if model_type.lower() != 'checkpoint': + return web.json_response({ + 'error': f"Model type mismatch. Expected Checkpoint, got {model_type}" + }, status=400) + + # Check local availability for each version + for version in versions: + # Find the primary model file (type="Model" and primary=true) in the files list + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model' and file.get('primary') == True), None) + + # If no primary file found, try to find any model file + if not model_file: + model_file = next((file for file in version.get('files', []) + if file.get('type') == 'Model'), None) + + if model_file: + sha256 = model_file.get('hashes', {}).get('SHA256') + if sha256: + # Set existsLocally and localPath at the version level + version['existsLocally'] = self.scanner.has_hash(sha256) + if version['existsLocally']: + version['localPath'] = self.scanner.get_path_by_hash(sha256) + + # Also set the model file size at the version level for easier access + version['modelSizeKB'] = model_file.get('sizeKB') + else: + # No model file found in this version + version['existsLocally'] = False + + return web.json_response(versions) + except Exception as e: + logger.error(f"Error fetching checkpoint model versions: {e}") + return web.Response(status=500, text=str(e)) diff --git a/py/services/civitai_client.py b/py/services/civitai_client.py index 24227f8d..7f548db2 100644 --- a/py/services/civitai_client.py +++ b/py/services/civitai_client.py @@ -172,7 +172,11 @@ class CivitaiClient: if response.status != 200: return None data = await response.json() - return data.get('modelVersions', []) + # Also return model type along with versions + return { + 'modelVersions': data.get('modelVersions', []), + 'type': data.get('type', '') + } except Exception as e: logger.error(f"Error fetching model versions: {e}") return None diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 29908ef9..94dee45d 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -293,15 +293,15 @@ class LoraScanner(ModelScanner): # Lora-specific hash index functionality def has_lora_hash(self, sha256: str) -> bool: """Check if a LoRA with given hash exists""" - return self._hash_index.has_hash(sha256.lower()) + return self.has_hash(sha256) def get_lora_path_by_hash(self, sha256: str) -> Optional[str]: """Get file path for a LoRA by its hash""" - return self._hash_index.get_path(sha256.lower()) + return self.get_path_by_hash(sha256) def get_lora_hash_by_path(self, file_path: str) -> Optional[str]: """Get hash for a LoRA by its file path""" - return self._hash_index.get_hash(file_path) + return self.get_hash_by_path(file_path) async def get_top_tags(self, limit: int = 20) -> List[Dict[str, any]]: """Get top tags sorted by count""" diff --git a/static/js/managers/CheckpointDownloadManager.js b/static/js/managers/CheckpointDownloadManager.js index 5dfb235c..5a13f116 100644 --- a/static/js/managers/CheckpointDownloadManager.js +++ b/static/js/managers/CheckpointDownloadManager.js @@ -76,8 +76,12 @@ export class CheckpointDownloadManager { throw new Error('Invalid Civitai URL format'); } - const response = await fetch(`/api/civitai/versions/${modelId}`); + const response = await fetch(`/api/checkpoints/civitai/versions/${modelId}`); if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { + throw new Error('This model is not a Checkpoint. Please switch to the LoRAs page to download LoRA models.'); + } throw new Error('Failed to fetch model versions'); } diff --git a/static/js/managers/DownloadManager.js b/static/js/managers/DownloadManager.js index c0891bdf..b489b69a 100644 --- a/static/js/managers/DownloadManager.js +++ b/static/js/managers/DownloadManager.js @@ -80,6 +80,10 @@ export class DownloadManager { const response = await fetch(`/api/civitai/versions/${modelId}`); if (!response.ok) { + const errorData = await response.json().catch(() => ({})); + if (errorData && errorData.error && errorData.error.includes('Model type mismatch')) { + throw new Error('This model is not a LoRA. Please switch to the Checkpoints page to download checkpoint models.'); + } throw new Error('Failed to fetch model versions'); }