diff --git a/py/nodes/save_image.py b/py/nodes/save_image.py index 2d9dc084..f2d7a119 100644 --- a/py/nodes/save_image.py +++ b/py/nodes/save_image.py @@ -54,8 +54,14 @@ class SaveImage: async def get_lora_hash(self, lora_name): """Get the lora hash from cache""" scanner = await LoraScanner.get_instance() - cache = await scanner.get_cached_data() + # Use the new direct filename lookup method + hash_value = scanner.get_hash_by_filename(lora_name) + if hash_value: + return hash_value + + # Fallback to old method for compatibility + cache = await scanner.get_cached_data() for item in cache.raw_data: if item.get('file_name') == lora_name: return item.get('sha256') @@ -64,7 +70,6 @@ class SaveImage: async def get_checkpoint_hash(self, checkpoint_path): """Get the checkpoint hash from cache""" scanner = await CheckpointScanner.get_instance() - cache = await scanner.get_cached_data() if not checkpoint_path: return None @@ -73,7 +78,13 @@ class SaveImage: checkpoint_name = os.path.basename(checkpoint_path) checkpoint_name = os.path.splitext(checkpoint_name)[0] - # Normalize path separators for comparison + # Try direct filename lookup first + hash_value = scanner.get_hash_by_filename(checkpoint_name) + if hash_value: + return hash_value + + # Fallback to old method for compatibility + cache = await scanner.get_cached_data() normalized_path = checkpoint_path.replace('\\', '/') for item in cache.raw_data: diff --git a/py/services/lora_scanner.py b/py/services/lora_scanner.py index 94dee45d..9756e8f1 100644 --- a/py/services/lora_scanner.py +++ b/py/services/lora_scanner.py @@ -9,7 +9,7 @@ from typing import List, Dict, Optional, Set from ..utils.models import LoraMetadata from ..config import config from .model_scanner import ModelScanner -from .lora_hash_index import LoraHashIndex +from .model_hash_index import ModelHashIndex # Changed from LoraHashIndex to ModelHashIndex from .settings_manager import settings from ..utils.constants import NSFW_LEVELS from ..utils.utils import fuzzy_match @@ -35,12 +35,12 @@ class LoraScanner(ModelScanner): # Define supported file extensions file_extensions = {'.safetensors'} - # Initialize parent class + # Initialize parent class with ModelHashIndex super().__init__( model_type="lora", model_class=LoraMetadata, file_extensions=file_extensions, - hash_index=LoraHashIndex() + hash_index=ModelHashIndex() # Changed from LoraHashIndex to ModelHashIndex ) self._initialized = True diff --git a/py/services/model_hash_index.py b/py/services/model_hash_index.py index 2f8ef0eb..caab788a 100644 --- a/py/services/model_hash_index.py +++ b/py/services/model_hash_index.py @@ -1,11 +1,12 @@ from typing import Dict, Optional, Set +import os class ModelHashIndex: """Index for looking up models by hash or path""" def __init__(self): self._hash_to_path: Dict[str, str] = {} - self._path_to_hash: Dict[str, str] = {} + self._filename_to_hash: Dict[str, str] = {} # Changed from path_to_hash to filename_to_hash def add_entry(self, sha256: str, file_path: str) -> None: """Add or update hash index entry""" @@ -15,37 +16,47 @@ class ModelHashIndex: # Ensure hash is lowercase for consistency sha256 = sha256.lower() + # Extract filename without extension + filename = self._get_filename_from_path(file_path) + # Remove old path mapping if hash exists if sha256 in self._hash_to_path: old_path = self._hash_to_path[sha256] - if old_path in self._path_to_hash: - del self._path_to_hash[old_path] + old_filename = self._get_filename_from_path(old_path) + if old_filename in self._filename_to_hash: + del self._filename_to_hash[old_filename] - # Remove old hash mapping if path exists - if file_path in self._path_to_hash: - old_hash = self._path_to_hash[file_path] + # Remove old hash mapping if filename exists + if filename in self._filename_to_hash: + old_hash = self._filename_to_hash[filename] if old_hash in self._hash_to_path: del self._hash_to_path[old_hash] # Add new mappings self._hash_to_path[sha256] = file_path - self._path_to_hash[file_path] = sha256 + self._filename_to_hash[filename] = sha256 + + def _get_filename_from_path(self, file_path: str) -> str: + """Extract filename without extension from path""" + return os.path.splitext(os.path.basename(file_path))[0] def remove_by_path(self, file_path: str) -> None: """Remove entry by file path""" - if file_path in self._path_to_hash: - hash_val = self._path_to_hash[file_path] + filename = self._get_filename_from_path(file_path) + if filename in self._filename_to_hash: + hash_val = self._filename_to_hash[filename] if hash_val in self._hash_to_path: del self._hash_to_path[hash_val] - del self._path_to_hash[file_path] + del self._filename_to_hash[filename] def remove_by_hash(self, sha256: str) -> None: """Remove entry by hash""" sha256 = sha256.lower() if sha256 in self._hash_to_path: path = self._hash_to_path[sha256] - if path in self._path_to_hash: - del self._path_to_hash[path] + filename = self._get_filename_from_path(path) + if filename in self._filename_to_hash: + del self._filename_to_hash[filename] del self._hash_to_path[sha256] def has_hash(self, sha256: str) -> bool: @@ -58,20 +69,27 @@ class ModelHashIndex: def get_hash(self, file_path: str) -> Optional[str]: """Get hash for a file path""" - return self._path_to_hash.get(file_path) + filename = self._get_filename_from_path(file_path) + return self._filename_to_hash.get(filename) + + def get_hash_by_filename(self, filename: str) -> Optional[str]: + """Get hash for a filename without extension""" + # Strip extension if present to make the function more flexible + filename = os.path.splitext(filename)[0] + return self._filename_to_hash.get(filename) def clear(self) -> None: """Clear all entries""" self._hash_to_path.clear() - self._path_to_hash.clear() + self._filename_to_hash.clear() def get_all_hashes(self) -> Set[str]: """Get all hashes in the index""" return set(self._hash_to_path.keys()) - def get_all_paths(self) -> Set[str]: - """Get all file paths in the index""" - return set(self._path_to_hash.keys()) + def get_all_filenames(self) -> Set[str]: + """Get all filenames in the index""" + return set(self._filename_to_hash.keys()) def __len__(self) -> int: """Get number of entries""" diff --git a/py/services/model_scanner.py b/py/services/model_scanner.py index 6febd33a..93f622b9 100644 --- a/py/services/model_scanner.py +++ b/py/services/model_scanner.py @@ -832,6 +832,10 @@ class ModelScanner: def get_hash_by_path(self, file_path: str) -> Optional[str]: """Get hash for a model by its file path""" return self._hash_index.get_hash(file_path) + + def get_hash_by_filename(self, filename: str) -> Optional[str]: + """Get hash for a model by its filename without path""" + return self._hash_index.get_hash_by_filename(filename) # TODO: Adjust this method to use metadata instead of finding the file def get_preview_url_by_hash(self, sha256: str) -> Optional[str]: