Add endpoints for finding duplicate loras and filename conflicts; track duplicates in ModelHashIndex and update ModelScanner to persist and load the new data structures.

Will Miao
2025-05-31 20:50:51 +08:00
parent e06d15f508
commit 0bd62eef3a
4 changed files with 255 additions and 4 deletions
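For orientation, the four new routes added below are plain GET endpoints that return JSON. The following is a minimal client-side sketch, not part of this commit; the base URL is an assumption and should be adjusted to wherever the server actually runs.

# Sketch only: queries the duplicate/conflict endpoints added in this commit.
# BASE_URL is an assumed address, not defined anywhere in this change set.
import asyncio
import aiohttp

BASE_URL = "http://127.0.0.1:8188"

async def main():
    async with aiohttp.ClientSession() as session:
        for path in (
            "/api/loras/find-duplicates",
            "/api/loras/find-filename-conflicts",
            "/api/checkpoints/find-duplicates",
            "/api/checkpoints/find-filename-conflicts",
        ):
            async with session.get(BASE_URL + path) as resp:
                data = await resp.json()
                # Handlers return {"success": ..., "duplicates"/"conflicts": [...], "count": ...}
                print(path, data.get("count"))

asyncio.run(main())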


@@ -80,6 +80,10 @@ class ApiRoutes:
        # Add update check routes
        UpdateRoutes.setup_routes(app)

+        # Add new endpoints for finding duplicates
+        app.router.add_get('/api/loras/find-duplicates', routes.find_duplicate_loras)
+        app.router.add_get('/api/loras/find-filename-conflicts', routes.find_filename_conflicts)

    async def delete_model(self, request: web.Request) -> web.Response:
        """Handle model deletion request"""
        if self.scanner is None:
@@ -1169,3 +1173,97 @@ class ApiRoutes:
                'success': False,
                'error': str(e)
            }, status=500)
+
+    async def find_duplicate_loras(self, request: web.Request) -> web.Response:
+        """Find loras with duplicate SHA256 hashes"""
+        try:
+            if self.scanner is None:
+                self.scanner = await ServiceRegistry.get_lora_scanner()
+
+            # Get duplicate hashes from hash index
+            duplicates = self.scanner._hash_index.get_duplicate_hashes()
+
+            # Format the response
+            result = []
+            cache = await self.scanner.get_cached_data()
+
+            for sha256, paths in duplicates.items():
+                group = {
+                    "hash": sha256,
+                    "models": []
+                }
+
+                # Find matching models for each duplicate path
+                for path in paths:
+                    model = next((m for m in cache.raw_data if m['file_path'] == path), None)
+                    if model:
+                        group["models"].append(self._format_lora_response(model))
+
+                # Add the primary model too
+                primary_path = self.scanner._hash_index.get_path(sha256)
+                if primary_path and primary_path not in paths:
+                    primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
+                    if primary_model:
+                        group["models"].insert(0, self._format_lora_response(primary_model))
+
+                if group["models"]:  # Only include if we found models
+                    result.append(group)
+
+            return web.json_response({
+                "success": True,
+                "duplicates": result,
+                "count": len(result)
+            })
+        except Exception as e:
+            logger.error(f"Error finding duplicate loras: {e}", exc_info=True)
+            return web.json_response({
+                "success": False,
+                "error": str(e)
+            }, status=500)
+
+    async def find_filename_conflicts(self, request: web.Request) -> web.Response:
+        """Find loras with conflicting filenames"""
+        try:
+            if self.scanner is None:
+                self.scanner = await ServiceRegistry.get_lora_scanner()
+
+            # Get duplicate filenames from hash index
+            duplicates = self.scanner._hash_index.get_duplicate_filenames()
+
+            # Format the response
+            result = []
+            cache = await self.scanner.get_cached_data()
+
+            for filename, paths in duplicates.items():
+                group = {
+                    "filename": filename,
+                    "models": []
+                }
+
+                # Find matching models for each path
+                for path in paths:
+                    model = next((m for m in cache.raw_data if m['file_path'] == path), None)
+                    if model:
+                        group["models"].append(self._format_lora_response(model))
+
+                # Find the model from the main index too
+                hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
+                if hash_val:
+                    main_path = self.scanner._hash_index.get_path(hash_val)
+                    if main_path and main_path not in paths:
+                        main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
+                        if main_model:
+                            group["models"].insert(0, self._format_lora_response(main_model))
+
+                if group["models"]:  # Only include if we found models
+                    result.append(group)
+
+            return web.json_response({
+                "success": True,
+                "conflicts": result,
+                "count": len(result)
+            })
+        except Exception as e:
+            logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
+            return web.json_response({
+                "success": False,
+                "error": str(e)
+            }, status=500)
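For reference, the handler above produces a payload shaped roughly as sketched here. The per-model dictionaries come from _format_lora_response, which is not part of this diff, so the inner fields and all values are illustrative only.

# Illustrative response shape for GET /api/loras/find-duplicates (values are made up;
# the model dicts hold whatever _format_lora_response returns, which this diff does not show).
example_response = {
    "success": True,
    "count": 1,
    "duplicates": [
        {
            "hash": "abc123...",  # shared SHA256
            "models": [
                {"file_path": "/loras/style_a.safetensors"},          # primary entry from the hash index
                {"file_path": "/loras/copies/style_a.safetensors"},   # duplicate path
            ],
        }
    ],
}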


@@ -58,6 +58,10 @@ class CheckpointsRoutes:
        # Add new WebSocket endpoint for checkpoint progress
        app.router.add_get('/ws/checkpoint-progress', ws_manager.handle_checkpoint_connection)

+        # Add new routes for finding duplicates and filename conflicts
+        app.router.add_get('/api/checkpoints/find-duplicates', self.find_duplicate_checkpoints)
+        app.router.add_get('/api/checkpoints/find-filename-conflicts', self.find_filename_conflicts)

    async def get_checkpoints(self, request):
        """Get paginated checkpoint data"""
        try:
@@ -695,3 +699,97 @@ class CheckpointsRoutes:
        except Exception as e:
            logger.error(f"Error fetching checkpoint model versions: {e}")
            return web.Response(status=500, text=str(e))
+
+    async def find_duplicate_checkpoints(self, request: web.Request) -> web.Response:
+        """Find checkpoints with duplicate SHA256 hashes"""
+        try:
+            if self.scanner is None:
+                self.scanner = await ServiceRegistry.get_checkpoint_scanner()
+
+            # Get duplicate hashes from hash index
+            duplicates = self.scanner._hash_index.get_duplicate_hashes()
+
+            # Format the response
+            result = []
+            cache = await self.scanner.get_cached_data()
+
+            for sha256, paths in duplicates.items():
+                group = {
+                    "hash": sha256,
+                    "models": []
+                }
+
+                # Find matching models for each path
+                for path in paths:
+                    model = next((m for m in cache.raw_data if m['file_path'] == path), None)
+                    if model:
+                        group["models"].append(self._format_checkpoint_response(model))
+
+                # Add the primary model too
+                primary_path = self.scanner._hash_index.get_path(sha256)
+                if primary_path and primary_path not in paths:
+                    primary_model = next((m for m in cache.raw_data if m['file_path'] == primary_path), None)
+                    if primary_model:
+                        group["models"].insert(0, self._format_checkpoint_response(primary_model))
+
+                if group["models"]:
+                    result.append(group)
+
+            return web.json_response({
+                "success": True,
+                "duplicates": result,
+                "count": len(result)
+            })
+        except Exception as e:
+            logger.error(f"Error finding duplicate checkpoints: {e}", exc_info=True)
+            return web.json_response({
+                "success": False,
+                "error": str(e)
+            }, status=500)
+
+    async def find_filename_conflicts(self, request: web.Request) -> web.Response:
+        """Find checkpoints with conflicting filenames"""
+        try:
+            if self.scanner is None:
+                self.scanner = await ServiceRegistry.get_checkpoint_scanner()
+
+            # Get duplicate filenames from hash index
+            duplicates = self.scanner._hash_index.get_duplicate_filenames()
+
+            # Format the response
+            result = []
+            cache = await self.scanner.get_cached_data()
+
+            for filename, paths in duplicates.items():
+                group = {
+                    "filename": filename,
+                    "models": []
+                }
+
+                # Find matching models for each path
+                for path in paths:
+                    model = next((m for m in cache.raw_data if m['file_path'] == path), None)
+                    if model:
+                        group["models"].append(self._format_checkpoint_response(model))
+
+                # Find the model from the main index too
+                hash_val = self.scanner._hash_index.get_hash_by_filename(filename)
+                if hash_val:
+                    main_path = self.scanner._hash_index.get_path(hash_val)
+                    if main_path and main_path not in paths:
+                        main_model = next((m for m in cache.raw_data if m['file_path'] == main_path), None)
+                        if main_model:
+                            group["models"].insert(0, self._format_checkpoint_response(main_model))
+
+                if group["models"]:
+                    result.append(group)
+
+            return web.json_response({
+                "success": True,
+                "conflicts": result,
+                "count": len(result)
+            })
+        except Exception as e:
+            logger.error(f"Error finding filename conflicts: {e}", exc_info=True)
+            return web.json_response({
+                "success": False,
+                "error": str(e)
+            }, status=500)


@@ -1,4 +1,4 @@
-from typing import Dict, Optional, Set
+from typing import Dict, Optional, Set, List
import os

class ModelHashIndex:
@@ -6,7 +6,10 @@ class ModelHashIndex:
    def __init__(self):
        self._hash_to_path: Dict[str, str] = {}
-        self._filename_to_hash: Dict[str, str] = {}  # Changed from path_to_hash to filename_to_hash
+        self._filename_to_hash: Dict[str, str] = {}
+        # New data structures for tracking duplicates
+        self._duplicate_hashes: Dict[str, List[str]] = {}  # sha256 -> list of paths
+        self._duplicate_filenames: Dict[str, List[str]] = {}  # filename -> list of paths

    def add_entry(self, sha256: str, file_path: str) -> None:
        """Add or update hash index entry"""
@@ -19,6 +22,26 @@
        # Extract filename without extension
        filename = self._get_filename_from_path(file_path)

+        # Track duplicates by hash
+        if sha256 in self._hash_to_path:
+            old_path = self._hash_to_path[sha256]
+            if old_path != file_path:  # Only record if it's actually a different path
+                if sha256 not in self._duplicate_hashes:
+                    self._duplicate_hashes[sha256] = [old_path]
+                if file_path not in self._duplicate_hashes.get(sha256, []):
+                    self._duplicate_hashes.setdefault(sha256, []).append(file_path)
+
+        # Track duplicates by filename
+        if filename in self._filename_to_hash:
+            old_hash = self._filename_to_hash[filename]
+            if old_hash != sha256:  # Different models with the same name
+                old_path = self._hash_to_path.get(old_hash)
+                if old_path:
+                    if filename not in self._duplicate_filenames:
+                        self._duplicate_filenames[filename] = [old_path]
+                if file_path not in self._duplicate_filenames.get(filename, []):
+                    self._duplicate_filenames.setdefault(filename, []).append(file_path)
+
        # Remove old path mapping if hash exists
        if sha256 in self._hash_to_path:
            old_path = self._hash_to_path[sha256]
@@ -48,6 +71,17 @@
        if hash_val in self._hash_to_path:
            del self._hash_to_path[hash_val]
        del self._filename_to_hash[filename]

+        # Also clean up from duplicates tracking
+        if filename in self._duplicate_filenames:
+            self._duplicate_filenames[filename] = [p for p in self._duplicate_filenames[filename] if p != file_path]
+            if not self._duplicate_filenames[filename]:
+                del self._duplicate_filenames[filename]
+
+        if hash_val in self._duplicate_hashes:
+            self._duplicate_hashes[hash_val] = [p for p in self._duplicate_hashes[hash_val] if p != file_path]
+            if not self._duplicate_hashes[hash_val]:
+                del self._duplicate_hashes[hash_val]

    def remove_by_hash(self, sha256: str) -> None:
        """Remove entry by hash"""
@@ -58,6 +92,10 @@
            if filename in self._filename_to_hash:
                del self._filename_to_hash[filename]
            del self._hash_to_path[sha256]

+            # Clean up from duplicates tracking
+            if sha256 in self._duplicate_hashes:
+                del self._duplicate_hashes[sha256]

    def has_hash(self, sha256: str) -> bool:
        """Check if hash exists in index"""
@@ -82,6 +120,8 @@
        """Clear all entries"""
        self._hash_to_path.clear()
        self._filename_to_hash.clear()
+        self._duplicate_hashes.clear()
+        self._duplicate_filenames.clear()

    def get_all_hashes(self) -> Set[str]:
        """Get all hashes in the index"""
@@ -91,6 +131,14 @@
        """Get all filenames in the index"""
        return set(self._filename_to_hash.keys())

+    def get_duplicate_hashes(self) -> Dict[str, List[str]]:
+        """Get dictionary of duplicate hashes and their paths"""
+        return self._duplicate_hashes
+
+    def get_duplicate_filenames(self) -> Dict[str, List[str]]:
+        """Get dictionary of duplicate filenames and their paths"""
+        return self._duplicate_filenames

    def __len__(self) -> int:
        """Get number of entries"""
        return len(self._hash_to_path)
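To illustrate how the new tracking is meant to behave, here is a small standalone sketch. The expected outputs in the comments are inferred from the add_entry changes above, and the import of ModelHashIndex from its module (path not shown in this diff) is assumed.

# Sketch of the duplicate tracking added above (illustration, not test code).
# Assumes ModelHashIndex has been imported from its module.
index = ModelHashIndex()

# Same SHA256 registered under two different paths -> recorded as a hash duplicate.
index.add_entry("aaa111", "/loras/style.safetensors")
index.add_entry("aaa111", "/backup/style.safetensors")
print(index.get_duplicate_hashes())
# Expected (per the logic above): {'aaa111': ['/loras/style.safetensors', '/backup/style.safetensors']}

# Same filename ("style") registered with a different hash -> filename conflict.
index.add_entry("bbb222", "/other/style.safetensors")
print(index.get_duplicate_filenames())
# Expected (per the logic above): {'style': ['/backup/style.safetensors', '/other/style.safetensors']}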


@@ -19,7 +19,10 @@ from .websocket_manager import ws_manager
logger = logging.getLogger(__name__)

# Define cache version to handle future format changes
-CACHE_VERSION = 1
+# Version history:
+# 1 - Initial version
+# 2 - Added duplicate_filenames and duplicate_hashes tracking
+CACHE_VERSION = 2

class ModelScanner:
    """Base service for scanning and managing model files"""
@@ -107,7 +110,9 @@
                "raw_data": self._cache.raw_data,
                "hash_index": {
                    "hash_to_path": self._hash_index._hash_to_path,
-                    "filename_to_hash": self._hash_index._filename_to_hash  # Fix: changed from path_to_hash to filename_to_hash
+                    "filename_to_hash": self._hash_index._filename_to_hash,  # Fix: changed from path_to_hash to filename_to_hash
+                    "duplicate_hashes": self._hash_index._duplicate_hashes,
+                    "duplicate_filenames": self._hash_index._duplicate_filenames
                },
                "tags_count": self._tags_count,
                "dirs_last_modified": self._get_dirs_last_modified()
@@ -205,6 +210,8 @@
            hash_index_data = cache_data.get("hash_index", {})
            self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
            self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {})  # Fix: changed from path_to_hash to filename_to_hash
+            self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
+            self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})

            # Load tags count
            self._tags_count = cache_data.get("tags_count", {})
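Putting the save and load sides together, the cached payload now carries the two duplicate maps alongside the existing hash index. A rough sketch of the serialized structure follows; only the keys visible in the hunks above are certain, while the "version" entry is assumed from the CACHE_VERSION bump and all values are illustrative.

# Rough sketch of the on-disk cache after this change (illustrative values only).
cache_data = {
    "version": 2,  # assumed to be stored, given the CACHE_VERSION bump; not visible in this hunk
    "raw_data": [{"file_path": "/loras/style.safetensors", "sha256": "aaa111"}],
    "hash_index": {
        "hash_to_path": {"aaa111": "/loras/style.safetensors"},
        "filename_to_hash": {"style": "aaa111"},
        "duplicate_hashes": {},       # sha256 -> [paths], new in version 2
        "duplicate_filenames": {},    # filename -> [paths], new in version 2
    },
    "tags_count": {},
    "dirs_last_modified": {},
}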