mirror of
https://github.com/willmiao/ComfyUI-Lora-Manager.git
synced 2026-03-25 07:05:43 -03:00
refactor: Replace asynchronous service calls with synchronous counterparts in SaveImage and ServiceRegistry. Fixes #282
This commit is contained in:
@@ -4,8 +4,7 @@ import asyncio
|
|||||||
import re
|
import re
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import folder_paths # type: ignore
|
import folder_paths # type: ignore
|
||||||
from ..services.lora_scanner import LoraScanner
|
from ..services.service_registry import ServiceRegistry
|
||||||
from ..services.checkpoint_scanner import CheckpointScanner
|
|
||||||
from ..metadata_collector.metadata_processor import MetadataProcessor
|
from ..metadata_collector.metadata_processor import MetadataProcessor
|
||||||
from ..metadata_collector import get_metadata
|
from ..metadata_collector import get_metadata
|
||||||
from PIL import Image, PngImagePlugin
|
from PIL import Image, PngImagePlugin
|
||||||
@@ -71,25 +70,20 @@ class SaveImage:
|
|||||||
FUNCTION = "process_image"
|
FUNCTION = "process_image"
|
||||||
OUTPUT_NODE = True
|
OUTPUT_NODE = True
|
||||||
|
|
||||||
async def get_lora_hash(self, lora_name):
|
def get_lora_hash(self, lora_name):
|
||||||
"""Get the lora hash from cache"""
|
"""Get the lora hash from cache"""
|
||||||
scanner = await LoraScanner.get_instance()
|
scanner = ServiceRegistry.get_service_sync("lora_scanner")
|
||||||
|
|
||||||
# Use the new direct filename lookup method
|
# Use the new direct filename lookup method
|
||||||
hash_value = scanner.get_hash_by_filename(lora_name)
|
hash_value = scanner.get_hash_by_filename(lora_name)
|
||||||
if hash_value:
|
if hash_value:
|
||||||
return hash_value
|
return hash_value
|
||||||
|
|
||||||
# Fallback to old method for compatibility
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == lora_name:
|
|
||||||
return item.get('sha256')
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def get_checkpoint_hash(self, checkpoint_path):
|
def get_checkpoint_hash(self, checkpoint_path):
|
||||||
"""Get the checkpoint hash from cache"""
|
"""Get the checkpoint hash from cache"""
|
||||||
scanner = await CheckpointScanner.get_instance()
|
scanner = ServiceRegistry.get_service_sync("checkpoint_scanner")
|
||||||
|
|
||||||
if not checkpoint_path:
|
if not checkpoint_path:
|
||||||
return None
|
return None
|
||||||
@@ -103,17 +97,9 @@ class SaveImage:
|
|||||||
if hash_value:
|
if hash_value:
|
||||||
return hash_value
|
return hash_value
|
||||||
|
|
||||||
# Fallback to old method for compatibility
|
|
||||||
cache = await scanner.get_cached_data()
|
|
||||||
normalized_path = checkpoint_path.replace('\\', '/')
|
|
||||||
|
|
||||||
for item in cache.raw_data:
|
|
||||||
if item.get('file_name') == checkpoint_name and item.get('file_path').endswith(normalized_path):
|
|
||||||
return item.get('sha256')
|
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def format_metadata(self, metadata_dict):
|
def format_metadata(self, metadata_dict):
|
||||||
"""Format metadata in the requested format similar to userComment example"""
|
"""Format metadata in the requested format similar to userComment example"""
|
||||||
if not metadata_dict:
|
if not metadata_dict:
|
||||||
return ""
|
return ""
|
||||||
@@ -140,7 +126,7 @@ class SaveImage:
|
|||||||
|
|
||||||
# Get hash for each lora
|
# Get hash for each lora
|
||||||
for lora_name, strength in lora_matches:
|
for lora_name, strength in lora_matches:
|
||||||
hash_value = await self.get_lora_hash(lora_name)
|
hash_value = self.get_lora_hash(lora_name)
|
||||||
if hash_value:
|
if hash_value:
|
||||||
lora_hashes[lora_name] = hash_value
|
lora_hashes[lora_name] = hash_value
|
||||||
else:
|
else:
|
||||||
@@ -226,7 +212,7 @@ class SaveImage:
|
|||||||
checkpoint = metadata_dict.get('checkpoint')
|
checkpoint = metadata_dict.get('checkpoint')
|
||||||
if checkpoint is not None:
|
if checkpoint is not None:
|
||||||
# Get model hash
|
# Get model hash
|
||||||
model_hash = await self.get_checkpoint_hash(checkpoint)
|
model_hash = self.get_checkpoint_hash(checkpoint)
|
||||||
|
|
||||||
# Extract basename without path
|
# Extract basename without path
|
||||||
checkpoint_name = os.path.basename(checkpoint)
|
checkpoint_name = os.path.basename(checkpoint)
|
||||||
@@ -329,8 +315,7 @@ class SaveImage:
|
|||||||
raw_metadata = get_metadata()
|
raw_metadata = get_metadata()
|
||||||
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
metadata_dict = MetadataProcessor.to_dict(raw_metadata, id)
|
||||||
|
|
||||||
# Get or create metadata asynchronously
|
metadata = self.format_metadata(metadata_dict)
|
||||||
metadata = asyncio.run(self.format_metadata(metadata_dict))
|
|
||||||
|
|
||||||
# Process filename_prefix with pattern substitution
|
# Process filename_prefix with pattern substitution
|
||||||
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
filename_prefix = self.format_filename(filename_prefix, metadata_dict)
|
||||||
|
|||||||
@@ -38,6 +38,15 @@ class ServiceRegistry:
|
|||||||
return None
|
return None
|
||||||
return registry._services[service_name]
|
return registry._services[service_name]
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_service_sync(cls, service_name: str) -> Any:
|
||||||
|
"""Get a service instance by name (synchronous version)"""
|
||||||
|
registry = cls.get_instance()
|
||||||
|
if service_name not in registry._services:
|
||||||
|
logger.debug(f"Service {service_name} not found in registry")
|
||||||
|
return None
|
||||||
|
return registry._services[service_name]
|
||||||
|
|
||||||
# Convenience methods for common services
|
# Convenience methods for common services
|
||||||
@classmethod
|
@classmethod
|
||||||
async def get_lora_scanner(cls):
|
async def get_lora_scanner(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user