mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
feat: Implement filename-based hash retrieval in LoraScanner and ModelScanner for improved compatibility
This commit is contained in:
@@ -54,8 +54,14 @@ class SaveImage:
|
|||||||
async def get_lora_hash(self, lora_name):
|
async def get_lora_hash(self, lora_name):
|
||||||
"""Get the lora hash from cache"""
|
"""Get the lora hash from cache"""
|
||||||
scanner = await LoraScanner.get_instance()
|
scanner = await LoraScanner.get_instance()
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
|
# Use the new direct filename lookup method
|
||||||
|
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||||
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
|
# Fallback to old method for compatibility
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
for item in cache.raw_data:
|
for item in cache.raw_data:
|
||||||
if item.get('file_name') == lora_name:
|
if item.get('file_name') == lora_name:
|
||||||
return item.get('sha256')
|
return item.get('sha256')
|
||||||
@@ -64,7 +70,6 @@ class SaveImage:
|
|||||||
async def get_checkpoint_hash(self, checkpoint_path):
|
async def get_checkpoint_hash(self, checkpoint_path):
|
||||||
"""Get the checkpoint hash from cache"""
|
"""Get the checkpoint hash from cache"""
|
||||||
scanner = await CheckpointScanner.get_instance()
|
scanner = await CheckpointScanner.get_instance()
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
|
|
||||||
if not checkpoint_path:
|
if not checkpoint_path:
|
||||||
return None
|
return None
|
||||||
@@ -73,7 +78,13 @@ class SaveImage:
|
|||||||
checkpoint_name = os.path.basename(checkpoint_path)
|
checkpoint_name = os.path.basename(checkpoint_path)
|
||||||
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
checkpoint_name = os.path.splitext(checkpoint_name)[0]
|
||||||
|
|
||||||
# Normalize path separators for comparison
|
# Try direct filename lookup first
|
||||||
|
hash_value = scanner.get_hash_by_filename(checkpoint_name)
|
||||||
|
if hash_value:
|
||||||
|
return hash_value
|
||||||
|
|
||||||
|
# Fallback to old method for compatibility
|
||||||
|
cache = await scanner.get_cached_data()
|
||||||
normalized_path = checkpoint_path.replace('\\', '/')
|
normalized_path = checkpoint_path.replace('\\', '/')
|
||||||
|
|
||||||
for item in cache.raw_data:
|
for item in cache.raw_data:
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set
|
|||||||
from ..utils.models import LoraMetadata
|
from ..utils.models import LoraMetadata
|
||||||
from ..config import config
|
from ..config import config
|
||||||
from .model_scanner import ModelScanner
|
from .model_scanner import ModelScanner
|
||||||
from .lora_hash_index import LoraHashIndex
|
from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex
|
||||||
from .settings_manager import settings
|
from .settings_manager import settings
|
||||||
from ..utils.constants import NSFW_LEVELS
|
from ..utils.constants import NSFW_LEVELS
|
||||||
from ..utils.utils import fuzzy_match
|
from ..utils.utils import fuzzy_match
|
||||||
@@ -35,12 +35,12 @@ class LoraScanner(ModelScanner):
|
|||||||
# Define supported file extensions
|
# Define supported file extensions
|
||||||
file_extensions = {'.safetensors'}
|
file_extensions = {'.safetensors'}
|
||||||
|
|
||||||
# Initialize parent class
|
# Initialize parent class with ModelHashIndex
|
||||||
super().__init__(
|
super().__init__(
|
||||||
model_type="lora",
|
model_type="lora",
|
||||||
model_class=LoraMetadata,
|
model_class=LoraMetadata,
|
||||||
file_extensions=file_extensions,
|
file_extensions=file_extensions,
|
||||||
hash_index=LoraHashIndex()
|
hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex
|
||||||
)
|
)
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
|
|||||||
@@ -1,11 +1,12 @@
|
|||||||
from typing import Dict, Optional, Set
|
from typing import Dict, Optional, Set
|
||||||
|
import os
|
||||||
|
|
||||||
class ModelHashIndex:
|
class ModelHashIndex:
|
||||||
"""Index for looking up models by hash or path"""
|
"""Index for looking up models by hash or path"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._hash_to_path: Dict[str, str] = {}
|
self._hash_to_path: Dict[str, str] = {}
|
||||||
self._path_to_hash: Dict[str, str] = {}
|
self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash
|
||||||
|
|
||||||
def add_entry(self, sha256: str, file_path: str) -> None:
|
def add_entry(self, sha256: str, file_path: str) -> None:
|
||||||
"""Add or update hash index entry"""
|
"""Add or update hash index entry"""
|
||||||
@@ -15,37 +16,47 @@ class ModelHashIndex:
|
|||||||
# Ensure hash is lowercase for consistency
|
# Ensure hash is lowercase for consistency
|
||||||
sha256 = sha256.lower()
|
sha256 = sha256.lower()
|
||||||
|
|
||||||
|
# Extract filename without extension
|
||||||
|
filename = self._get_filename_from_path(file_path)
|
||||||
|
|
||||||
# Remove old path mapping if hash exists
|
# Remove old path mapping if hash exists
|
||||||
if sha256 in self._hash_to_path:
|
if sha256 in self._hash_to_path:
|
||||||
old_path = self._hash_to_path[sha256]
|
old_path = self._hash_to_path[sha256]
|
||||||
if old_path in self._path_to_hash:
|
old_filename = self._get_filename_from_path(old_path)
|
||||||
del self._path_to_hash[old_path]
|
if old_filename in self._filename_to_hash:
|
||||||
|
del self._filename_to_hash[old_filename]
|
||||||
|
|
||||||
# Remove old hash mapping if path exists
|
# Remove old hash mapping if filename exists
|
||||||
if file_path in self._path_to_hash:
|
if filename in self._filename_to_hash:
|
||||||
old_hash = self._path_to_hash[file_path]
|
old_hash = self._filename_to_hash[filename]
|
||||||
if old_hash in self._hash_to_path:
|
if old_hash in self._hash_to_path:
|
||||||
del self._hash_to_path[old_hash]
|
del self._hash_to_path[old_hash]
|
||||||
|
|
||||||
# Add new mappings
|
# Add new mappings
|
||||||
self._hash_to_path[sha256] = file_path
|
self._hash_to_path[sha256] = file_path
|
||||||
self._path_to_hash[file_path] = sha256
|
self._filename_to_hash[filename] = sha256
|
||||||
|
|
||||||
|
def _get_filename_from_path(self, file_path: str) -> str:
|
||||||
|
"""Extract filename without extension from path"""
|
||||||
|
return os.path.splitext(os.path.basename(file_path))[0]
|
||||||
|
|
||||||
def remove_by_path(self, file_path: str) -> None:
|
def remove_by_path(self, file_path: str) -> None:
|
||||||
"""Remove entry by file path"""
|
"""Remove entry by file path"""
|
||||||
if file_path in self._path_to_hash:
|
filename = self._get_filename_from_path(file_path)
|
||||||
hash_val = self._path_to_hash[file_path]
|
if filename in self._filename_to_hash:
|
||||||
|
hash_val = self._filename_to_hash[filename]
|
||||||
if hash_val in self._hash_to_path:
|
if hash_val in self._hash_to_path:
|
||||||
del self._hash_to_path[hash_val]
|
del self._hash_to_path[hash_val]
|
||||||
del self._path_to_hash[file_path]
|
del self._filename_to_hash[filename]
|
||||||
|
|
||||||
def remove_by_hash(self, sha256: str) -> None:
|
def remove_by_hash(self, sha256: str) -> None:
|
||||||
"""Remove entry by hash"""
|
"""Remove entry by hash"""
|
||||||
sha256 = sha256.lower()
|
sha256 = sha256.lower()
|
||||||
if sha256 in self._hash_to_path:
|
if sha256 in self._hash_to_path:
|
||||||
path = self._hash_to_path[sha256]
|
path = self._hash_to_path[sha256]
|
||||||
if path in self._path_to_hash:
|
filename = self._get_filename_from_path(path)
|
||||||
del self._path_to_hash[path]
|
if filename in self._filename_to_hash:
|
||||||
|
del self._filename_to_hash[filename]
|
||||||
del self._hash_to_path[sha256]
|
del self._hash_to_path[sha256]
|
||||||
|
|
||||||
def has_hash(self, sha256: str) -> bool:
|
def has_hash(self, sha256: str) -> bool:
|
||||||
@@ -58,20 +69,27 @@ class ModelHashIndex:
|
|||||||
|
|
||||||
def get_hash(self, file_path: str) -> Optional[str]:
|
def get_hash(self, file_path: str) -> Optional[str]:
|
||||||
"""Get hash for a file path"""
|
"""Get hash for a file path"""
|
||||||
return self._path_to_hash.get(file_path)
|
filename = self._get_filename_from_path(file_path)
|
||||||
|
return self._filename_to_hash.get(filename)
|
||||||
|
|
||||||
|
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||||
|
"""Get hash for a filename without extension"""
|
||||||
|
# Strip extension if present to make the function more flexible
|
||||||
|
filename = os.path.splitext(filename)[0]
|
||||||
|
return self._filename_to_hash.get(filename)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear(self) -> None:
|
||||||
"""Clear all entries"""
|
"""Clear all entries"""
|
||||||
self._hash_to_path.clear()
|
self._hash_to_path.clear()
|
||||||
self._path_to_hash.clear()
|
self._filename_to_hash.clear()
|
||||||
|
|
||||||
def get_all_hashes(self) -> Set[str]:
|
def get_all_hashes(self) -> Set[str]:
|
||||||
"""Get all hashes in the index"""
|
"""Get all hashes in the index"""
|
||||||
return set(self._hash_to_path.keys())
|
return set(self._hash_to_path.keys())
|
||||||
|
|
||||||
def get_all_paths(self) -> Set[str]:
|
def get_all_filenames(self) -> Set[str]:
|
||||||
"""Get all file paths in the index"""
|
"""Get all filenames in the index"""
|
||||||
return set(self._path_to_hash.keys())
|
return set(self._filename_to_hash.keys())
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Get number of entries"""
|
"""Get number of entries"""
|
||||||
|
|||||||
@@ -832,6 +832,10 @@ class ModelScanner:
|
|||||||
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
def get_hash_by_path(self, file_path: str) -> Optional[str]:
|
||||||
"""Get hash for a model by its file path"""
|
"""Get hash for a model by its file path"""
|
||||||
return self._hash_index.get_hash(file_path)
|
return self._hash_index.get_hash(file_path)
|
||||||
|
|
||||||
|
def get_hash_by_filename(self, filename: str) -> Optional[str]:
|
||||||
|
"""Get hash for a model by its filename without path"""
|
||||||
|
return self._hash_index.get_hash_by_filename(filename)
|
||||||
|
|
||||||
# TODO: Adjust this method to use metadata instead of finding the file
|
# TODO: Adjust this method to use metadata instead of finding the file
|
||||||
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
def get_preview_url_by_hash(self, sha256: str) -> Optional[str]:
|
||||||
|
|||||||
Reference in New Issue
Block a user