diff --git a/py/routes/handlers/model_handlers.py b/py/routes/handlers/model_handlers.py
index e6fc99b5..62aae575 100644
--- a/py/routes/handlers/model_handlers.py
+++ b/py/routes/handlers/model_handlers.py
@@ -6,6 +6,7 @@ import asyncio
 import json
 import logging
 import os
+import re
 import time
 from dataclasses import dataclass
 from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
@@ -755,19 +756,22 @@ class ModelQueryHandler:
 
     async def find_duplicate_models(self, request: web.Request) -> web.Response:
         try:
+            filters = self._parse_duplicate_filters(request)
             duplicates = self._service.find_duplicate_hashes()
             result = []
             cache = await self._service.scanner.get_cached_data()
+
             for sha256, paths in duplicates.items():
-                group = {"hash": sha256, "models": []}
+                # Collect all models in this group
+                all_models = []
                 for path in paths:
                     model = next(
                         (m for m in cache.raw_data if m["file_path"] == path), None
                     )
                     if model:
-                        group["models"].append(
-                            await self._service.format_response(model)
-                        )
+                        all_models.append(model)
+
+                # Include primary if not already in paths
                 primary_path = self._service.get_path_by_hash(sha256)
                 if primary_path and primary_path not in paths:
                     primary_model = next(
@@ -775,11 +779,25 @@ class ModelQueryHandler:
                         None,
                     )
                     if primary_model:
-                        group["models"].insert(
-                            0, await self._service.format_response(primary_model)
-                        )
+                        all_models.insert(0, primary_model)
+
+                # Apply filters
+                filtered = self._apply_duplicate_filters(all_models, filters)
+
+                # Sort: originals first, copies last
+                sorted_models = self._sort_duplicate_group(filtered)
+
+                # Format response
+                group = {"hash": sha256, "models": []}
+                for model in sorted_models:
+                    group["models"].append(
+                        await self._service.format_response(model)
+                    )
+
+                # Only include groups with 2+ models after filtering
                 if len(group["models"]) > 1:
                     result.append(group)
+
             return web.json_response(
                 {"success": True, "duplicates": result, "count": len(result)}
             )
@@ -792,6 +810,83 @@ class ModelQueryHandler:
             )
             return web.json_response({"success": False, "error": str(exc)}, status=500)
 
+    def _parse_duplicate_filters(self, request: web.Request) -> Dict[str, Any]:
+        """Parse filter parameters from the request for duplicate finding."""
+        return {
+            "base_models": request.query.getall("base_model", []),
+            "tag_include": request.query.getall("tag_include", []),
+            "tag_exclude": request.query.getall("tag_exclude", []),
+            "model_types": request.query.getall("model_type", []),
+            "folder": request.query.get("folder"),
+            "favorites_only": request.query.get("favorites_only", "").lower() == "true",
+        }
+
+    def _apply_duplicate_filters(self, models: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
+        """Apply filters to a list of models within a duplicate group."""
+        result = models
+
+        # Apply base model filter
+        if filters.get("base_models"):
+            base_set = set(filters["base_models"])
+            result = [m for m in result if m.get("base_model") in base_set]
+
+        # Apply tag filters (include)
+        for tag in filters.get("tag_include", []):
+            if tag == "__no_tags__":
+                result = [m for m in result if not m.get("tags")]
+            else:
+                result = [m for m in result if tag in (m.get("tags") or [])]
+
+        # Apply tag filters (exclude)
+        for tag in filters.get("tag_exclude", []):
+            if tag == "__no_tags__":
+                result = [m for m in result if m.get("tags")]
+            else:
+                result = [m for m in result if tag not in (m.get("tags") or [])]
+
+        # Apply model type filter
+        if filters.get("model_types"):
+            type_set = {t.lower() for t in filters["model_types"]}
+            result = [
+                m for m in result if (m.get("model_type") or "").lower() in type_set
+            ]
+
+        # Apply folder filter
+        if filters.get("folder"):
+            folder = filters["folder"]
+            result = [m for m in result if m.get("folder", "").startswith(folder)]
+
+        # Apply favorites filter
+        if filters.get("favorites_only"):
+            result = [m for m in result if m.get("favorite", False)]
+
+        return result
+
+    def _sort_duplicate_group(self, models: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+        """Sort models: originals first (left), copies (with -????. pattern) last (right)."""
+        if len(models) <= 1:
+            return models
+
+        min_len = min(len(m.get("file_name", "")) for m in models)
+
+        def copy_score(m):
+            fn = m.get("file_name", "")
+            score = 0
+            # Match -0001.safetensors, -1234.safetensors etc.
+            if re.search(r"-\d{4}\.", fn):
+                score += 100
+            # Match (1), (2) etc.
+            if re.search(r"\(\d+\)", fn):
+                score += 50
+            # Match 'copy' in filename
+            if "copy" in fn.lower():
+                score += 50
+            # Longer filenames are more likely copies
+            score += len(fn) - min_len
+            return (score, fn.lower())
+
+        return sorted(models, key=copy_score)
+
     async def find_filename_conflicts(self, request: web.Request) -> web.Response:
         try:
             duplicates = self._service.find_duplicate_filenames()
diff --git a/static/js/components/ModelDuplicatesManager.js b/static/js/components/ModelDuplicatesManager.js
index 33df3779..4a818729 100644
--- a/static/js/components/ModelDuplicatesManager.js
+++ b/static/js/components/ModelDuplicatesManager.js
@@ -48,15 +48,18 @@ export class ModelDuplicatesManager {
     // Method to check for duplicates count using existing endpoint
     async checkDuplicatesCount() {
         try {
+            const params = this._buildFilterQueryParams();
             const endpoint = `/api/lm/${this.modelType}/find-duplicates`;
-            const response = await fetch(endpoint);
-
+            const url = params.toString() ? `${endpoint}?${params}` : endpoint;
+
+            const response = await fetch(url);
+
             if (!response.ok) {
                 throw new Error(`Failed to get duplicates count: ${response.statusText}`);
             }
-
+
             const data = await response.json();
-
+
             if (data.success) {
                 const duplicatesCount = (data.duplicates || []).length;
                 this.updateDuplicatesBadge(duplicatesCount);
@@ -103,29 +106,30 @@ export class ModelDuplicatesManager {
 
     async findDuplicates() {
         try {
-            // Determine API endpoint based on model type
+            const params = this._buildFilterQueryParams();
             const endpoint = `/api/lm/${this.modelType}/find-duplicates`;
-
-            const response = await fetch(endpoint);
+            const url = params.toString() ? `${endpoint}?${params}` : endpoint;
+
+            const response = await fetch(url);
             if (!response.ok) {
                 throw new Error(`Failed to find duplicates: ${response.statusText}`);
             }
-
+
             const data = await response.json();
             if (!data.success) {
                 throw new Error(data.error || 'Unknown error finding duplicates');
             }
-
+
             this.duplicateGroups = data.duplicates || [];
-
+
             // Update the badge with the current count
             this.updateDuplicatesBadge(this.duplicateGroups.length);
-
+
             if (this.duplicateGroups.length === 0) {
                 showToast('toast.duplicates.noDuplicatesFound', { type: this.modelType }, 'info');
                 return false;
             }
-
+
             this.enterDuplicateMode();
             return true;
         } catch (error) {
@@ -134,6 +138,51 @@ export class ModelDuplicatesManager {
             return false;
         }
     }
+
+    /**
+     * Build query parameters from current filter state for duplicate finding.
+     * @returns {URLSearchParams} The query parameters to append to the API endpoint
+     */
+    _buildFilterQueryParams() {
+        const params = new URLSearchParams();
+        const pageState = getCurrentPageState();
+        const filters = pageState?.filters;
+
+        if (!filters) return params;
+
+        // Base model filters
+        if (filters.baseModel && Array.isArray(filters.baseModel)) {
+            filters.baseModel.forEach(m => params.append('base_model', m));
+        }
+
+        // Tag filters (tri-state: include/exclude)
+        if (filters.tags && typeof filters.tags === 'object') {
+            Object.entries(filters.tags).forEach(([tag, state]) => {
+                if (state === 'include') {
+                    params.append('tag_include', tag);
+                } else if (state === 'exclude') {
+                    params.append('tag_exclude', tag);
+                }
+            });
+        }
+
+        // Model type filters
+        if (filters.modelTypes && Array.isArray(filters.modelTypes)) {
+            filters.modelTypes.forEach(t => params.append('model_type', t));
+        }
+
+        // Folder filter (from active folder state)
+        if (pageState.activeFolder) {
+            params.append('folder', pageState.activeFolder);
+        }
+
+        // Favorites filter
+        if (pageState.showFavoritesOnly) {
+            params.append('favorites_only', 'true');
+        }
+
+        return params;
+    }
 
     enterDuplicateMode() {
         this.inDuplicateMode = true;
diff --git a/static/js/managers/FilterManager.js b/static/js/managers/FilterManager.js
index b8aef214..3cf61233 100644
--- a/static/js/managers/FilterManager.js
+++ b/static/js/managers/FilterManager.js
@@ -549,6 +549,17 @@ export class FilterManager {
                 showToast('toast.filters.cleared', {}, 'info');
             }
         }
+
+        // Refresh duplicates with new filters
+        if (window.modelDuplicatesManager) {
+            if (window.modelDuplicatesManager.inDuplicateMode) {
+                // In duplicate mode: refresh the duplicate list
+                await window.modelDuplicatesManager.findDuplicates();
+            } else {
+                // Not in duplicate mode: just update badge count
+                window.modelDuplicatesManager.checkDuplicatesCount();
+            }
+        }
     }
 
     async clearFilters() {