diff --git a/py/routes/api_routes.py b/py/routes/api_routes.py
index 611f65c6..b7bba034 100644
--- a/py/routes/api_routes.py
+++ b/py/routes/api_routes.py
@@ -43,6 +43,7 @@ class ApiRoutes:
         app.on_startup.append(lambda _: routes.initialize_services())
 
         app.router.add_post('/api/delete_model', routes.delete_model)
+        app.router.add_post('/api/exclude_model', routes.exclude_model)  # Add new exclude endpoint
         app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
         app.router.add_post('/api/replace_preview', routes.replace_preview)
         app.router.add_get('/api/loras', routes.get_loras)
@@ -81,6 +82,12 @@ class ApiRoutes:
             self.scanner = await ServiceRegistry.get_lora_scanner()
         return await ModelRouteUtils.handle_delete_model(request, self.scanner)
 
+    async def exclude_model(self, request: web.Request) -> web.Response:
+        """Handle model exclusion request"""
+        if self.scanner is None:
+            self.scanner = await ServiceRegistry.get_lora_scanner()
+        return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
+
     async def fetch_civitai(self, request: web.Request) -> web.Response:
         """Handle CivitAI metadata fetch request"""
         if self.scanner is None:
diff --git a/py/routes/checkpoints_routes.py b/py/routes/checkpoints_routes.py
index 46752f28..1e361f71 100644
--- a/py/routes/checkpoints_routes.py
+++ b/py/routes/checkpoints_routes.py
@@ -49,6 +49,7 @@ class CheckpointsRoutes:
 
         # Add new routes for model management similar to LoRA routes
         app.router.add_post('/api/checkpoints/delete', self.delete_model)
+        app.router.add_post('/api/checkpoints/exclude', self.exclude_model)  # Add new exclude endpoint
         app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
         app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
         app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
@@ -499,6 +500,10 @@ class CheckpointsRoutes:
     async def delete_model(self, request: web.Request) -> web.Response:
         """Handle checkpoint model deletion request"""
         return await ModelRouteUtils.handle_delete_model(request, self.scanner)
+
+    async def exclude_model(self, request: web.Request) -> web.Response:
+        """Handle checkpoint model exclusion request"""
+        return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
 
     async def fetch_civitai(self, request: web.Request) -> web.Response:
         """Handle CivitAI metadata fetch request for checkpoints"""
diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py
index cff319c0..38c28de7 100644
--- a/py/services/model_scanner.py
+++ b/py/services/model_scanner.py
@@ -38,6 +38,7 @@ class ModelScanner:
         self._hash_index = hash_index or ModelHashIndex()
         self._tags_count = {}  # Dictionary to store tag counts
         self._is_initializing = False  # Flag to track initialization state
+        self._excluded_models = []  # List to track excluded models
 
         # Register this service
         asyncio.create_task(self._register_service())
@@ -394,6 +395,9 @@ class ModelScanner:
                 if file_path in cached_paths:
                     found_paths.add(file_path)
                     continue
+
+                if file_path in self._excluded_models:
+                    continue
 
                 # Try case-insensitive match on Windows
                 if os.name == 'nt':
@@ -406,7 +410,7 @@ class ModelScanner:
                             break
                     if matched:
                         continue
-                
+
                 # This is a new file to process
                 new_files.append(file_path)
 
@@ -586,6 +590,11 @@ class ModelScanner:
             model_data = metadata.to_dict()
 
+            # Skip excluded models
+            if model_data.get('exclude', False):
+                self._excluded_models.append(model_data['file_path'])
+                return None
+
             await self._fetch_missing_metadata(file_path, model_data)
             rel_path = os.path.relpath(file_path, root_path)
             folder = os.path.dirname(rel_path)
 
@@ -905,6 +914,10 @@ class ModelScanner:
             logger.error(f"Error getting model info by name: {e}", exc_info=True)
             return None
 
+    def get_excluded_models(self) -> List[str]:
+        """Get list of excluded model file paths"""
+        return self._excluded_models.copy()
+
     async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
         """Update preview URL in cache for a specific lora
 
@@ -918,4 +931,4 @@ class ModelScanner:
         if self._cache is None:
             return False
 
-        return await self._cache.update_preview_url(file_path, preview_url)
\ No newline at end of file
+        return await self._cache.update_preview_url(file_path, preview_url)
diff --git a/py/utils/models.py b/py/utils/models.py
index f0f7fc94..17ef6900 100644
--- a/py/utils/models.py
+++ b/py/utils/models.py
@@ -23,6 +23,7 @@ class BaseModelMetadata:
     modelDescription: str = ""  # Full model description
     civitai_deleted: bool = False  # Whether deleted from Civitai
     favorite: bool = False  # Whether the model is a favorite
+    exclude: bool = False  # Whether to exclude this model from the cache
 
     def __post_init__(self):
         # Initialize empty lists to avoid mutable default parameter issue
diff --git a/py/utils/routes_common.py b/py/utils/routes_common.py
index f8b5f649..3d219440 100644
--- a/py/utils/routes_common.py
+++ b/py/utils/routes_common.py
@@ -425,6 +425,65 @@ class ModelRouteUtils:
             logger.error(f"Error replacing preview: {e}", exc_info=True)
             return web.Response(text=str(e), status=500)
 
+    @staticmethod
+    async def handle_exclude_model(request: web.Request, scanner) -> web.Response:
+        """Handle model exclusion request
+
+        Args:
+            request: The aiohttp request
+            scanner: The model scanner instance with cache management methods
+
+        Returns:
+            web.Response: The HTTP response
+        """
+        try:
+            data = await request.json()
+            file_path = data.get('file_path')
+            if not file_path:
+                return web.Response(text='Model path is required', status=400)
+
+            # Update metadata to mark as excluded
+            metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
+            metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
+            metadata['exclude'] = True
+
+            # Save updated metadata
+            with open(metadata_path, 'w', encoding='utf-8') as f:
+                json.dump(metadata, f, indent=2, ensure_ascii=False)
+
+            # Update cache
+            cache = await scanner.get_cached_data()
+
+            # Find and remove model from cache
+            model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
+            if model_to_remove:
+                # Update tags count
+                for tag in model_to_remove.get('tags', []):
+                    if tag in scanner._tags_count:
+                        scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
+                        if scanner._tags_count[tag] == 0:
+                            del scanner._tags_count[tag]
+
+                # Remove from hash index if available
+                if hasattr(scanner, '_hash_index') and scanner._hash_index:
+                    scanner._hash_index.remove_by_path(file_path)
+
+                # Remove from cache data
+                cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
+                await cache.resort()
+
+            # Add to excluded models list
+            scanner._excluded_models.append(file_path)
+
+            return web.json_response({
+                'success': True,
+                'message': f"Model {os.path.basename(file_path)} excluded"
+            })
+
+        except Exception as e:
+            logger.error(f"Error excluding model: {e}", exc_info=True)
+            return web.Response(text=str(e), status=500)
+
     @staticmethod
     async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
         """Handle model download request
@@ -501,4 +560,4 @@ class ModelRouteUtils:
         )
 
         logger.error(f"Error downloading {model_type}: {error_message}")
-        return web.Response(status=500, text=error_message)
\ No newline at end of file
+        return web.Response(status=500, text=error_message)