refactor: Replace asynchronous service calls with synchronous counterparts in SaveImage and ServiceRegistry. Fixes #282

This commit is contained in:
Will Miao
2025-07-11 22:48:39 +08:00
parent 5de16a78c5
commit bd95e802ec
2 changed files with 18 additions and 24 deletions

View File

@@ -4,8 +4,7 @@ import asyncio
import re import re
import numpy as np import numpy as np
import folder_paths # type: ignore import folder_paths # type: ignore
from ..services.lora_scanner import LoraScanner from ..services.service_registry import ServiceRegistry
from ..services.checkpoint_scanner import CheckpointScanner
from ..metadata_collector.metadata_processor import MetadataProcessor from ..metadata_collector.metadata_processor import MetadataProcessor
from ..metadata_collector import get_metadata from ..metadata_collector import get_metadata
from PIL import Image, PngImagePlugin from PIL import Image, PngImagePlugin
@@ -71,25 +70,20 @@ class SaveImage:
FUNCTION = "process_image" FUNCTION = "process_image"
OUTPUT_NODE = True OUTPUT_NODE = True
async def get_lora_hash(self, lora_name): def get_lora_hash(self, lora_name):
"""Get the lora hash from cache""" """Get the lora hash from cache"""
scanner = await LoraScanner.get_instance() scanner = ServiceRegistry.get_service_sync("lora_scanner")
# Use the new direct filename lookup method # Use the new direct filename lookup method
hash_value = scanner.get_hash_by_filename(lora_name) hash_value = scanner.get_hash_by_filename(lora_name)
if hash_value: if hash_value:
return hash_value return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
for item in cache.raw_data:
if item.get('file_name') == lora_name:
return item.get('sha256')
return None return None
async def get_checkpoint_hash(self, checkpoint_path): def get_checkpoint_hash(self, checkpoint_path):
"""Get the checkpoint hash from cache""" """Get the checkpoint hash from cache"""
scanner = await CheckpointScanner.get_instance() scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
if not checkpoint_path: if not checkpoint_path:
return None return None
@@ -103,17 +97,9 @@ class SaveImage:
if hash_value: if hash_value:
return hash_value return hash_value
# Fallback to old method for compatibility
cache = await scanner.get_cached_data()
normalized_path = checkpoint_path.replace('\\', '/')
for item in cache.raw_data:
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
return item.get('sha256')
return None return None
async def format_metadata(self, metadata_dict): def format_metadata(self, metadata_dict):
"""Format metadata in the requested format similar to userComment example""" """Format metadata in the requested format similar to userComment example"""
if not metadata_dict: if not metadata_dict:
return "" return ""
@@ -140,7 +126,7 @@ class SaveImage:
# Get hash for each lora # Get hash for each lora
for lora_name, strength in lora_matches: for lora_name, strength in lora_matches:
hash_value = await self.get_lora_hash(lora_name) hash_value = self.get_lora_hash(lora_name)
if hash_value: if hash_value:
lora_hashes[lora_name] = hash_value lora_hashes[lora_name] = hash_value
else: else:
@@ -226,7 +212,7 @@ class SaveImage:
checkpoint = metadata_dict.get('checkpoint') checkpoint = metadata_dict.get('checkpoint')
if checkpoint is not None: if checkpoint is not None:
# Get model hash # Get model hash
model_hash = await self.get_checkpoint_hash(checkpoint) model_hash = self.get_checkpoint_hash(checkpoint)
# Extract basename without path # Extract basename without path
checkpoint_name = os.path.basename(checkpoint) checkpoint_name = os.path.basename(checkpoint)
@@ -329,8 +315,7 @@ class SaveImage:
raw_metadata = get_metadata() raw_metadata = get_metadata()
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id) metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
# Get or create metadata asynchronously metadata = self.format_metadata(metadata_dict)
metadata = asyncio.run(self.format_metadata(metadata_dict))
# Process filename_prefix with pattern substitution # Process filename_prefix with pattern substitution
filename_prefix = self.format_filename(filename_prefix, metadata_dict) filename_prefix = self.format_filename(filename_prefix, metadata_dict)

View File

@@ -38,6 +38,15 @@ class ServiceRegistry:
return None return None
return registry._services[service_name] return registry._services[service_name]
@classmethod
def get_service_sync(cls, service_name: str) -> Any:
"""Get a service instance by name (synchronous version)"""
registry = cls.get_instance()
if service_name not in registry._services:
logger.debug(f"Service {service_name} not found in registry")
return None
return registry._services[service_name]
# Convenience methods for common services # Convenience methods for common services
@classmethod @classmethod
async def get_lora_scanner(cls): async def get_lora_scanner(cls):