mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-24 22:52:12 -03:00
Add endpoints for finding duplicate loras and filename conflicts; implement tracking for duplicates in ModelHashIndex and update ModelScanner to handle new data structures.
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Dict, Optional, Set
|
||||
from typing import Dict, Optional, Set, List
|
||||
import os
|
||||
|
||||
class ModelHashIndex:
|
||||
@@ -6,7 +6,10 @@ class ModelHashIndex:
|
||||
|
||||
def __init__(self):
|
||||
self._hash_to_path: Dict[str, str] = {}
|
||||
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
|
||||
self._filename_to_hash: Dict[str, str] = {}
|
||||
# New data structures for tracking duplicates
|
||||
self._duplicate_hashes: Dict[str, List[str]] = {} # sha256 -> list of paths
|
||||
self._duplicate_filenames: Dict[str, List[str]] = {} # filename -> list of paths
|
||||
|
||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||
"""Add or update hash index entry"""
|
||||
@@ -19,6 +22,26 @@ class ModelHashIndex:
|
||||
# Extract filename without extension
|
||||
filename = self._get_filename_from_path(file_path)
|
||||
|
||||
# Track duplicates by hash
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
if old_path != file_path: # Only record if it's actually a different path
|
||||
if sha256 not in self._duplicate_hashes:
|
||||
self._duplicate_hashes[sha256] = [old_path]
|
||||
if file_path not in self._duplicate_hashes.get(sha256, []):
|
||||
self._duplicate_hashes.setdefault(sha256, []).append(file_path)
|
||||
|
||||
# Track duplicates by filename
|
||||
if filename in self._filename_to_hash:
|
||||
old_hash = self._filename_to_hash[filename]
|
||||
if old_hash != sha256: # Different models with the same name
|
||||
old_path = self._hash_to_path.get(old_hash)
|
||||
if old_path:
|
||||
if filename not in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [old_path]
|
||||
if file_path not in self._duplicate_filenames.get(filename, []):
|
||||
self._duplicate_filenames.setdefault(filename, []).append(file_path)
|
||||
|
||||
# Remove old path mapping if hash exists
|
||||
if sha256 in self._hash_to_path:
|
||||
old_path = self._hash_to_path[sha256]
|
||||
@@ -48,6 +71,17 @@ class ModelHashIndex:
|
||||
if hash_val in self._hash_to_path:
|
||||
del self._hash_to_path[hash_val]
|
||||
del self._filename_to_hash[filename]
|
||||
|
||||
# Also clean up from duplicates tracking
|
||||
if filename in self._duplicate_filenames:
|
||||
self._duplicate_filenames[filename] = [p for p in self._duplicate_filenames[filename] if p != file_path]
|
||||
if not self._duplicate_filenames[filename]:
|
||||
del self._duplicate_filenames[filename]
|
||||
|
||||
if hash_val in self._duplicate_hashes:
|
||||
self._duplicate_hashes[hash_val] = [p for p in self._duplicate_hashes[hash_val] if p != file_path]
|
||||
if not self._duplicate_hashes[hash_val]:
|
||||
del self._duplicate_hashes[hash_val]
|
||||
|
||||
def remove_by_hash(self, sha256: str) -> None:
|
||||
"""Remove entry by hash"""
|
||||
@@ -58,6 +92,10 @@ class ModelHashIndex:
|
||||
if filename in self._filename_to_hash:
|
||||
del self._filename_to_hash[filename]
|
||||
del self._hash_to_path[sha256]
|
||||
|
||||
# Clean up from duplicates tracking
|
||||
if sha256 in self._duplicate_hashes:
|
||||
del self._duplicate_hashes[sha256]
|
||||
|
||||
def has_hash(self, sha256: str) -> bool:
|
||||
"""Check if hash exists in index"""
|
||||
@@ -82,6 +120,8 @@ class ModelHashIndex:
|
||||
"""Clear all entries"""
|
||||
self._hash_to_path.clear()
|
||||
self._filename_to_hash.clear()
|
||||
self._duplicate_hashes.clear()
|
||||
self._duplicate_filenames.clear()
|
||||
|
||||
def get_all_hashes(self) -> Set[str]:
|
||||
"""Get all hashes in the index"""
|
||||
@@ -91,6 +131,14 @@ class ModelHashIndex:
|
||||
"""Get all filenames in the index"""
|
||||
return set(self._filename_to_hash.keys())
|
||||
|
||||
def get_duplicate_hashes(self) -> Dict[str, List[str]]:
|
||||
"""Get dictionary of duplicate hashes and their paths"""
|
||||
return self._duplicate_hashes
|
||||
|
||||
def get_duplicate_filenames(self) -> Dict[str, List[str]]:
|
||||
"""Get dictionary of duplicate filenames and their paths"""
|
||||
return self._duplicate_filenames
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Get number of entries"""
|
||||
return len(self._hash_to_path)
|
||||
@@ -19,7 +19,10 @@ from .websocket_manager import ws_manager
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Define cache version to handle future format changes
|
||||
CACHE_VERSION = 1
|
||||
# Version history:
|
||||
# 1 - Initial version
|
||||
# 2 - Added duplicate_filenames and duplicate_hashes tracking
|
||||
CACHE_VERSION = 2
|
||||
|
||||
class ModelScanner:
|
||||
"""Base service for scanning and managing model files"""
|
||||
@@ -107,7 +110,9 @@ class ModelScanner:
|
||||
"raw_data": self._cache.raw_data,
|
||||
"hash_index": {
|
||||
"hash_to_path": self._hash_index._hash_to_path,
|
||||
"filename_to_hash": self._hash_index._filename_to_hash # Fix: changed from path_to_hash to filename_to_hash
|
||||
"filename_to_hash": self._hash_index._filename_to_hash, # Fix: changed from path_to_hash to filename_to_hash
|
||||
"duplicate_hashes": self._hash_index._duplicate_hashes,
|
||||
"duplicate_filenames": self._hash_index._duplicate_filenames
|
||||
},
|
||||
"tags_count": self._tags_count,
|
||||
"dirs_last_modified": self._get_dirs_last_modified()
|
||||
@@ -205,6 +210,8 @@ class ModelScanner:
|
||||
hash_index_data = cache_data.get("hash_index", {})
|
||||
self._hash_index._hash_to_path = hash_index_data.get("hash_to_path", {})
|
||||
self._hash_index._filename_to_hash = hash_index_data.get("filename_to_hash", {}) # Fix: changed from path_to_hash to filename_to_hash
|
||||
self._hash_index._duplicate_hashes = hash_index_data.get("duplicate_hashes", {})
|
||||
self._hash_index._duplicate_filenames = hash_index_data.get("duplicate_filenames", {})
|
||||
|
||||
# Load tags count
|
||||
self._tags_count = cache_data.get("tags_count", {})
|
||||
|
||||
Reference in New Issue
Block a user