feat(duplicates): add filter support for duplicate model finding, #783

This commit is contained in:
Will Miao
2026-02-04 20:46:16 +08:00
parent 36e3e62e70
commit b7e0821f66
3 changed files with 174 additions and 19 deletions

View File

@@ -6,6 +6,7 @@ import asyncio
import json import json
import logging import logging
import os import os
import re
import time import time
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional from typing import Any, Awaitable, Callable, Dict, Iterable, List, Mapping, Optional
@@ -755,19 +756,22 @@ class ModelQueryHandler:
async def find_duplicate_models(self, request: web.Request) -> web.Response: async def find_duplicate_models(self, request: web.Request) -> web.Response:
try: try:
filters = self._parse_duplicate_filters(request)
duplicates = self._service.find_duplicate_hashes() duplicates = self._service.find_duplicate_hashes()
result = [] result = []
cache = await self._service.scanner.get_cached_data() cache = await self._service.scanner.get_cached_data()
for sha256, paths in duplicates.items(): for sha256, paths in duplicates.items():
group = {"hash": sha256, "models": []} # Collect all models in this group
all_models = []
for path in paths: for path in paths:
model = next( model = next(
(m for m in cache.raw_data if m["file_path"] == path), None (m for m in cache.raw_data if m["file_path"] == path), None
) )
if model: if model:
group["models"].append( all_models.append(model)
await self._service.format_response(model)
) # Include primary if not already in paths
primary_path = self._service.get_path_by_hash(sha256) primary_path = self._service.get_path_by_hash(sha256)
if primary_path and primary_path not in paths: if primary_path and primary_path not in paths:
primary_model = next( primary_model = next(
@@ -775,11 +779,25 @@ class ModelQueryHandler:
None, None,
) )
if primary_model: if primary_model:
group["models"].insert( all_models.insert(0, primary_model)
0, await self._service.format_response(primary_model)
) # Apply filters
filtered = self._apply_duplicate_filters(all_models, filters)
# Sort: originals first, copies last
sorted_models = self._sort_duplicate_group(filtered)
# Format response
group = {"hash": sha256, "models": []}
for model in sorted_models:
group["models"].append(
await self._service.format_response(model)
)
# Only include groups with 2+ models after filtering
if len(group["models"]) > 1: if len(group["models"]) > 1:
result.append(group) result.append(group)
return web.json_response( return web.json_response(
{"success": True, "duplicates": result, "count": len(result)} {"success": True, "duplicates": result, "count": len(result)}
) )
@@ -792,6 +810,83 @@ class ModelQueryHandler:
) )
return web.json_response({"success": False, "error": str(exc)}, status=500) return web.json_response({"success": False, "error": str(exc)}, status=500)
def _parse_duplicate_filters(self, request: web.Request) -> Dict[str, Any]:
    """Extract duplicate-finding filter options from the request query string.

    Multi-valued parameters (``base_model``, ``tag_include``, ``tag_exclude``,
    ``model_type``) are collected with ``getall``; ``folder`` is a single
    optional string, and ``favorites_only`` is a boolean flag that is enabled
    only by the literal string "true" (case-insensitive).
    """
    query = request.query
    favorites_flag = query.get("favorites_only", "").lower() == "true"
    return {
        "base_models": query.getall("base_model", []),
        "tag_include": query.getall("tag_include", []),
        "tag_exclude": query.getall("tag_exclude", []),
        "model_types": query.getall("model_type", []),
        "folder": query.get("folder"),
        "favorites_only": favorites_flag,
    }
def _apply_duplicate_filters(self, models: List[Dict[str, Any]], filters: Dict[str, Any]) -> List[Dict[str, Any]]:
"""Apply filters to a list of models within a duplicate group."""
result = models
# Apply base model filter
if filters.get("base_models"):
base_set = set(filters["base_models"])
result = [m for m in result if m.get("base_model") in base_set]
# Apply tag filters (include)
for tag in filters.get("tag_include", []):
if tag == "__no_tags__":
result = [m for m in result if not m.get("tags")]
else:
result = [m for m in result if tag in (m.get("tags") or [])]
# Apply tag filters (exclude)
for tag in filters.get("tag_exclude", []):
if tag == "__no_tags__":
result = [m for m in result if m.get("tags")]
else:
result = [m for m in result if tag not in (m.get("tags") or [])]
# Apply model type filter
if filters.get("model_types"):
type_set = {t.lower() for t in filters["model_types"]}
result = [
m for m in result if (m.get("model_type") or "").lower() in type_set
]
# Apply folder filter
if filters.get("folder"):
folder = filters["folder"]
result = [m for m in result if m.get("folder", "").startswith(folder)]
# Apply favorites filter
if filters.get("favorites_only"):
result = [m for m in result if m.get("favorite", False)]
return result
def _sort_duplicate_group(self, models: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
"""Sort models: originals first (left), copies (with -????. pattern) last (right)."""
if len(models) <= 1:
return models
min_len = min(len(m.get("file_name", "")) for m in models)
def copy_score(m):
fn = m.get("file_name", "")
score = 0
# Match -0001.safetensors, -1234.safetensors etc.
if re.search(r"-\d{4}\.", fn):
score += 100
# Match (1), (2) etc.
if re.search(r"\(\d+\)", fn):
score += 50
# Match 'copy' in filename
if "copy" in fn.lower():
score += 50
# Longer filenames are more likely copies
score += len(fn) - min_len
return (score, fn.lower())
return sorted(models, key=copy_score)
async def find_filename_conflicts(self, request: web.Request) -> web.Response: async def find_filename_conflicts(self, request: web.Request) -> web.Response:
try: try:
duplicates = self._service.find_duplicate_filenames() duplicates = self._service.find_duplicate_filenames()

View File

@@ -48,15 +48,18 @@ export class ModelDuplicatesManager {
// Method to check for duplicates count using existing endpoint // Method to check for duplicates count using existing endpoint
async checkDuplicatesCount() { async checkDuplicatesCount() {
try { try {
const params = this._buildFilterQueryParams();
const endpoint = `/api/lm/${this.modelType}/find-duplicates`; const endpoint = `/api/lm/${this.modelType}/find-duplicates`;
const response = await fetch(endpoint); const url = params.toString() ? `${endpoint}?${params}` : endpoint;
const response = await fetch(url);
if (!response.ok) { if (!response.ok) {
throw new Error(`Failed to get duplicates count: ${response.statusText}`); throw new Error(`Failed to get duplicates count: ${response.statusText}`);
} }
const data = await response.json(); const data = await response.json();
if (data.success) { if (data.success) {
const duplicatesCount = (data.duplicates || []).length; const duplicatesCount = (data.duplicates || []).length;
this.updateDuplicatesBadge(duplicatesCount); this.updateDuplicatesBadge(duplicatesCount);
@@ -103,29 +106,30 @@ export class ModelDuplicatesManager {
async findDuplicates() { async findDuplicates() {
try { try {
// Determine API endpoint based on model type const params = this._buildFilterQueryParams();
const endpoint = `/api/lm/${this.modelType}/find-duplicates`; const endpoint = `/api/lm/${this.modelType}/find-duplicates`;
const url = params.toString() ? `${endpoint}?${params}` : endpoint;
const response = await fetch(endpoint);
const response = await fetch(url);
if (!response.ok) { if (!response.ok) {
throw new Error(`Failed to find duplicates: ${response.statusText}`); throw new Error(`Failed to find duplicates: ${response.statusText}`);
} }
const data = await response.json(); const data = await response.json();
if (!data.success) { if (!data.success) {
throw new Error(data.error || 'Unknown error finding duplicates'); throw new Error(data.error || 'Unknown error finding duplicates');
} }
this.duplicateGroups = data.duplicates || []; this.duplicateGroups = data.duplicates || [];
// Update the badge with the current count // Update the badge with the current count
this.updateDuplicatesBadge(this.duplicateGroups.length); this.updateDuplicatesBadge(this.duplicateGroups.length);
if (this.duplicateGroups.length === 0) { if (this.duplicateGroups.length === 0) {
showToast('toast.duplicates.noDuplicatesFound', { type: this.modelType }, 'info'); showToast('toast.duplicates.noDuplicatesFound', { type: this.modelType }, 'info');
return false; return false;
} }
this.enterDuplicateMode(); this.enterDuplicateMode();
return true; return true;
} catch (error) { } catch (error) {
@@ -134,6 +138,51 @@ export class ModelDuplicatesManager {
return false; return false;
} }
} }
/**
* Build query parameters from current filter state for duplicate finding.
* @returns {URLSearchParams} The query parameters to append to the API endpoint
*/
_buildFilterQueryParams() {
const params = new URLSearchParams();
const pageState = getCurrentPageState();
const filters = pageState?.filters;
if (!filters) return params;
// Base model filters
if (filters.baseModel && Array.isArray(filters.baseModel)) {
filters.baseModel.forEach(m => params.append('base_model', m));
}
// Tag filters (tri-state: include/exclude)
if (filters.tags && typeof filters.tags === 'object') {
Object.entries(filters.tags).forEach(([tag, state]) => {
if (state === 'include') {
params.append('tag_include', tag);
} else if (state === 'exclude') {
params.append('tag_exclude', tag);
}
});
}
// Model type filters
if (filters.modelTypes && Array.isArray(filters.modelTypes)) {
filters.modelTypes.forEach(t => params.append('model_type', t));
}
// Folder filter (from active folder state)
if (pageState.activeFolder) {
params.append('folder', pageState.activeFolder);
}
// Favorites filter
if (pageState.showFavoritesOnly) {
params.append('favorites_only', 'true');
}
return params;
}
enterDuplicateMode() { enterDuplicateMode() {
this.inDuplicateMode = true; this.inDuplicateMode = true;

View File

@@ -549,6 +549,17 @@ export class FilterManager {
showToast('toast.filters.cleared', {}, 'info'); showToast('toast.filters.cleared', {}, 'info');
} }
} }
// Refresh duplicates with new filters
if (window.modelDuplicatesManager) {
if (window.modelDuplicatesManager.inDuplicateMode) {
// In duplicate mode: refresh the duplicate list
await window.modelDuplicatesManager.findDuplicates();
} else {
// Not in duplicate mode: just update badge count
window.modelDuplicatesManager.checkDuplicatesCount();
}
}
} }
async clearFilters() { async clearFilters() {