mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: add model exclusion functionality with new API endpoints and metadata handling
This commit is contained in:
@@ -43,6 +43,7 @@ class ApiRoutes:
|
|||||||
app.on_startup.append(lambda _: routes.initialize_services())
|
app.on_startup.append(lambda _: routes.initialize_services())
|
||||||
|
|
||||||
app.router.add_post('/api/delete_model', routes.delete_model)
|
app.router.add_post('/api/delete_model', routes.delete_model)
|
||||||
|
app.router.add_post('/api/exclude_model', routes.exclude_model) # Add new exclude endpoint
|
||||||
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
app.router.add_post('/api/fetch-civitai', routes.fetch_civitai)
|
||||||
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
app.router.add_post('/api/replace_preview', routes.replace_preview)
|
||||||
app.router.add_get('/api/loras', routes.get_loras)
|
app.router.add_get('/api/loras', routes.get_loras)
|
||||||
@@ -81,6 +82,12 @@ class ApiRoutes:
|
|||||||
self.scanner = await ServiceRegistry.get_lora_scanner()
|
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||||
|
|
||||||
|
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle model exclusion request"""
|
||||||
|
if self.scanner is None:
|
||||||
|
self.scanner = await ServiceRegistry.get_lora_scanner()
|
||||||
|
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
||||||
|
|
||||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||||
"""Handle CivitAI metadata fetch request"""
|
"""Handle CivitAI metadata fetch request"""
|
||||||
if self.scanner is None:
|
if self.scanner is None:
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ class CheckpointsRoutes:
|
|||||||
|
|
||||||
# Add new routes for model management similar to LoRA routes
|
# Add new routes for model management similar to LoRA routes
|
||||||
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
app.router.add_post('/api/checkpoints/delete', self.delete_model)
|
||||||
|
app.router.add_post('/api/checkpoints/exclude', self.exclude_model) # Add new exclude endpoint
|
||||||
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
app.router.add_post('/api/checkpoints/fetch-civitai', self.fetch_civitai)
|
||||||
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
app.router.add_post('/api/checkpoints/replace-preview', self.replace_preview)
|
||||||
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
app.router.add_post('/api/checkpoints/download', self.download_checkpoint)
|
||||||
@@ -499,6 +500,10 @@ class CheckpointsRoutes:
|
|||||||
async def delete_model(self, request: web.Request) -> web.Response:
|
async def delete_model(self, request: web.Request) -> web.Response:
|
||||||
"""Handle checkpoint model deletion request"""
|
"""Handle checkpoint model deletion request"""
|
||||||
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
return await ModelRouteUtils.handle_delete_model(request, self.scanner)
|
||||||
|
|
||||||
|
async def exclude_model(self, request: web.Request) -> web.Response:
|
||||||
|
"""Handle checkpoint model exclusion request"""
|
||||||
|
return await ModelRouteUtils.handle_exclude_model(request, self.scanner)
|
||||||
|
|
||||||
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
async def fetch_civitai(self, request: web.Request) -> web.Response:
|
||||||
"""Handle CivitAI metadata fetch request for checkpoints"""
|
"""Handle CivitAI metadata fetch request for checkpoints"""
|
||||||
@@ -653,7 +658,7 @@ class CheckpointsRoutes:
|
|||||||
model_type = response.get('type', '')
|
model_type = response.get('type', '')
|
||||||
|
|
||||||
# Check model type - should be Checkpoint
|
# Check model type - should be Checkpoint
|
||||||
if model_type.lower() != 'checkpoint':
|
if (model_type.lower() != 'checkpoint'):
|
||||||
return web.json_response({
|
return web.json_response({
|
||||||
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
'error': f"Model type mismatch. Expected Checkpoint, got {model_type}"
|
||||||
}, status=400)
|
}, status=400)
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ class ModelScanner:
|
|||||||
self._hash_index = hash_index or ModelHashIndex()
|
self._hash_index = hash_index or ModelHashIndex()
|
||||||
self._tags_count = {} # Dictionary to store tag counts
|
self._tags_count = {} # Dictionary to store tag counts
|
||||||
self._is_initializing = False # Flag to track initialization state
|
self._is_initializing = False # Flag to track initialization state
|
||||||
|
self._excluded_models = [] # List to track excluded models
|
||||||
|
|
||||||
# Register this service
|
# Register this service
|
||||||
asyncio.create_task(self._register_service())
|
asyncio.create_task(self._register_service())
|
||||||
@@ -394,6 +395,9 @@ class ModelScanner:
|
|||||||
if file_path in cached_paths:
|
if file_path in cached_paths:
|
||||||
found_paths.add(file_path)
|
found_paths.add(file_path)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
if file_path in self._excluded_models:
|
||||||
|
continue
|
||||||
|
|
||||||
# Try case-insensitive match on Windows
|
# Try case-insensitive match on Windows
|
||||||
if os.name == 'nt':
|
if os.name == 'nt':
|
||||||
@@ -406,7 +410,7 @@ class ModelScanner:
|
|||||||
break
|
break
|
||||||
if matched:
|
if matched:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# This is a new file to process
|
# This is a new file to process
|
||||||
new_files.append(file_path)
|
new_files.append(file_path)
|
||||||
|
|
||||||
@@ -586,6 +590,11 @@ class ModelScanner:
|
|||||||
|
|
||||||
model_data = metadata.to_dict()
|
model_data = metadata.to_dict()
|
||||||
|
|
||||||
|
# Skip excluded models
|
||||||
|
if model_data.get('exclude', False):
|
||||||
|
self._excluded_models.append(model_data['file_path'])
|
||||||
|
return None
|
||||||
|
|
||||||
await self._fetch_missing_metadata(file_path, model_data)
|
await self._fetch_missing_metadata(file_path, model_data)
|
||||||
rel_path = os.path.relpath(file_path, root_path)
|
rel_path = os.path.relpath(file_path, root_path)
|
||||||
folder = os.path.dirname(rel_path)
|
folder = os.path.dirname(rel_path)
|
||||||
@@ -905,6 +914,10 @@ class ModelScanner:
|
|||||||
logger.error(f"Error getting model info by name: {e}", exc_info=True)
|
logger.error(f"Error getting model info by name: {e}", exc_info=True)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_excluded_models(self) -> List[str]:
|
||||||
|
"""Get list of excluded model file paths"""
|
||||||
|
return self._excluded_models.copy()
|
||||||
|
|
||||||
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
async def update_preview_in_cache(self, file_path: str, preview_url: str) -> bool:
|
||||||
"""Update preview URL in cache for a specific lora
|
"""Update preview URL in cache for a specific lora
|
||||||
|
|
||||||
@@ -918,4 +931,4 @@ class ModelScanner:
|
|||||||
if self._cache is None:
|
if self._cache is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return await self._cache.update_preview_url(file_path, preview_url)
|
return await self._cache.update_preview_url(file_path, preview_url)
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ class BaseModelMetadata:
|
|||||||
modelDescription: str = "" # Full model description
|
modelDescription: str = "" # Full model description
|
||||||
civitai_deleted: bool = False # Whether deleted from Civitai
|
civitai_deleted: bool = False # Whether deleted from Civitai
|
||||||
favorite: bool = False # Whether the model is a favorite
|
favorite: bool = False # Whether the model is a favorite
|
||||||
|
exclude: bool = False # Whether to exclude this model from the cache
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# Initialize empty lists to avoid mutable default parameter issue
|
# Initialize empty lists to avoid mutable default parameter issue
|
||||||
|
|||||||
@@ -425,6 +425,65 @@ class ModelRouteUtils:
|
|||||||
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
logger.error(f"Error replacing preview: {e}", exc_info=True)
|
||||||
return web.Response(text=str(e), status=500)
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def handle_exclude_model(request: web.Request, scanner) -> web.Response:
|
||||||
|
"""Handle model exclusion request
|
||||||
|
|
||||||
|
Args:
|
||||||
|
request: The aiohttp request
|
||||||
|
scanner: The model scanner instance with cache management methods
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
web.Response: The HTTP response
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
data = await request.json()
|
||||||
|
file_path = data.get('file_path')
|
||||||
|
if not file_path:
|
||||||
|
return web.Response(text='Model path is required', status=400)
|
||||||
|
|
||||||
|
# Update metadata to mark as excluded
|
||||||
|
metadata_path = os.path.splitext(file_path)[0] + '.metadata.json'
|
||||||
|
metadata = await ModelRouteUtils.load_local_metadata(metadata_path)
|
||||||
|
metadata['exclude'] = True
|
||||||
|
|
||||||
|
# Save updated metadata
|
||||||
|
with open(metadata_path, 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(metadata, f, indent=2, ensure_ascii=False)
|
||||||
|
|
||||||
|
# Update cache
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
|
|
||||||
|
# Find and remove model from cache
|
||||||
|
model_to_remove = next((item for item in cache.raw_data if item['file_path'] == file_path), None)
|
||||||
|
if model_to_remove:
|
||||||
|
# Update tags count
|
||||||
|
for tag in model_to_remove.get('tags', []):
|
||||||
|
if tag in scanner._tags_count:
|
||||||
|
scanner._tags_count[tag] = max(0, scanner._tags_count[tag] - 1)
|
||||||
|
if scanner._tags_count[tag] == 0:
|
||||||
|
del scanner._tags_count[tag]
|
||||||
|
|
||||||
|
# Remove from hash index if available
|
||||||
|
if hasattr(scanner, '_hash_index') and scanner._hash_index:
|
||||||
|
scanner._hash_index.remove_by_path(file_path)
|
||||||
|
|
||||||
|
# Remove from cache data
|
||||||
|
cache.raw_data = [item for item in cache.raw_data if item['file_path'] != file_path]
|
||||||
|
await cache.resort()
|
||||||
|
|
||||||
|
# Add to excluded models list
|
||||||
|
scanner._excluded_models.append(file_path)
|
||||||
|
|
||||||
|
return web.json_response({
|
||||||
|
'success': True,
|
||||||
|
'message': f"Model {os.path.basename(file_path)} excluded"
|
||||||
|
})
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Error excluding model: {e}", exc_info=True)
|
||||||
|
return web.Response(text=str(e), status=500)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
|
async def handle_download_model(request: web.Request, download_manager: DownloadManager, model_type="lora") -> web.Response:
|
||||||
"""Handle model download request
|
"""Handle model download request
|
||||||
@@ -501,4 +560,4 @@ class ModelRouteUtils:
|
|||||||
)
|
)
|
||||||
|
|
||||||
logger.error(f"Error downloading {model_type}: {error_message}")
|
logger.error(f"Error downloading {model_type}: {error_message}")
|
||||||
return web.Response(status=500, text=error_message)
|
return web.Response(status=500, text=error_message)
|
||||||
|
|||||||
Reference in New Issue
Block a user