Files
Bjornulf_custom_nodes/API_falAI.py
justumen 50ee5779c0 0.62
2024-12-17 14:38:57 +01:00

269 lines
9.6 KiB
Python

import os
import time
import requests
from PIL import Image
import numpy as np
import torch
import fal_client
from io import BytesIO
import json
import threading
import asyncio
class APIGenerateFalAI:
@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"api_token": ("STRING", {
"multiline": False,
"default": "",
"display": "Fal.ai API Token"
}),
"model": (["fal-ai/flux-pro/v1.1-ultra", "fal-ai/recraft-v3", "fal-ai/flux-general/image-to-image"], {
"default": "fal-ai/flux-pro/v1.1-ultra"
}),
"prompt": ("STRING", {
"multiline": True,
"default": "A blackhole in space"
}),
"number_of_images": ("INT", {
"default": 1,
"min": 1,
"max": 10,
"step": 1
}),
"seed": ("INT", {
"default": -1,
"min": -1,
"max": 2147483647
}),
"timeout": ("INT", {
"default": 300,
"min": 60,
"max": 1800,
"step": 60,
"display": "Timeout (seconds)"
}),
}
}
RETURN_TYPES = ("IMAGE", "STRING",)
RETURN_NAMES = ("image", "generation_info",)
FUNCTION = "generate"
CATEGORY = "FalAI"
def __init__(self):
self.output_dir = "output/API/FalAI"
self.metadata_dir = "output/API/FalAI/metadata"
os.makedirs(self.output_dir, exist_ok=True)
os.makedirs(self.metadata_dir, exist_ok=True)
self._interrupt_event = threading.Event()
def get_next_number(self):
save_dir = "output/API/FalAI"
os.makedirs(save_dir, exist_ok=True)
files = [f for f in os.listdir(save_dir) if f.endswith('.png')]
if not files:
return 1
numbers = [int(f.split('.')[0]) for f in files]
return max(numbers) + 1
def create_filename(self, number):
# Simply format the number with leading zeros
return f"{number:03d}.png"
def save_image_and_metadata(self, img, generation_info, number):
# Create simple filename
filename = self.create_filename(number)
filepath = os.path.join(self.output_dir, filename)
# Save image
img.save(filepath)
# Create metadata filename based on the image filename
metadata_filename = f"{number:03d}_metadata.json"
metadata_filepath = os.path.join(self.metadata_dir, metadata_filename)
# Save metadata
with open(metadata_filepath, 'w', encoding='utf-8') as f:
json.dump(generation_info, f, indent=4, ensure_ascii=False)
return filepath, metadata_filepath
async def generate_single_image_async(self, input_data, api_token, model):
try:
# Set the environment variable for the API token
os.environ['FAL_KEY'] = api_token
# Submit request and get request ID
handler = await fal_client.submit_async(
model,
arguments=input_data
)
request_id = handler.request_id
print(f"Request ID: {request_id}")
# Wait for the result
result = await fal_client.result_async(model, request_id)
if not result or 'images' not in result or not result['images']:
raise ValueError(f"No valid result received. Result: {result}")
# Get image URL and download image
image_url = result['images'][0]['url']
image_response = requests.get(image_url)
if image_response.status_code != 200:
raise ConnectionError(f"Failed to download image: Status code {image_response.status_code}")
# Process image
img = Image.open(BytesIO(image_response.content))
if img.mode != 'RGB':
img = img.convert('RGB')
# Save metadata and image
number = self.get_next_number()
generation_info = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"parameters": input_data,
"result": result,
"request_id": request_id
}
image_path, metadata_path = self.save_image_and_metadata(img, generation_info, number)
print(f"Saved image to: {image_path}")
print(f"Saved metadata to: {metadata_path}")
img_tensor = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
img_tensor = img_tensor.unsqueeze(0)
return img_tensor, generation_info
except Exception as e:
print(f"Generation error: {str(e)}")
raise Exception(f"Error generating image: {str(e)}")
def generate(self, api_token, model, prompt, number_of_images=1, seed=-1, timeout=300):
if not api_token:
raise ValueError("API token is required")
self._interrupt_event.clear()
empty_image = torch.zeros((1, 1024, 1024, 3)) # Default size
try:
images = []
infos = []
failed_jobs = []
# Create new event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
async def process_all_images():
tasks = []
for i in range(number_of_images):
if self._interrupt_event.is_set():
break
# Create input data for each image
input_data = {"prompt": prompt}
# If seed is provided, increment it for each image
# If seed is -1, generate a random seed for each image
if seed != -1:
current_seed = seed + i
else:
current_seed = np.random.randint(0, 2147483647)
input_data["seed"] = current_seed
tasks.append(self.generate_single_image_async(input_data, api_token, model))
return await asyncio.gather(*tasks, return_exceptions=True)
try:
results = loop.run_until_complete(process_all_images())
finally:
loop.close()
for result in results:
if isinstance(result, Exception):
failed_jobs.append({
'error': str(result)
})
else:
img_tensor, generation_info = result
images.append(img_tensor)
infos.append(generation_info)
if not images:
generation_info = {
"error": "All generation jobs failed",
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"failed_jobs": failed_jobs
}
return (empty_image, json.dumps(generation_info, indent=2))
combined_tensor = torch.cat(images, dim=0)
combined_info = {
"successful_generations": len(images),
"total_requested": number_of_images,
"generation_parameters": {
"prompt": prompt,
"initial_seed": seed
},
"individual_results": infos,
"failed_jobs": failed_jobs if failed_jobs else None
}
return (combined_tensor, json.dumps(combined_info, indent=2))
except Exception as e:
generation_info = {
"error": f"Fal.ai generation failed: {str(e)}",
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
return (empty_image, json.dumps(generation_info, indent=2))
def recover_image_by_request_id(self, request_id, api_token, model):
try:
# Set the environment variable for the API token
os.environ['FAL_KEY'] = api_token
result = fal_client.result(model, request_id)
if not result or 'images' not in result or not result['images']:
raise ValueError(f"No valid result for request ID {request_id}")
image_url = result['images'][0]['url']
image_response = requests.get(image_url)
if image_response.status_code != 200:
raise ConnectionError(f"Failed to download image: Status code {image_response.status_code}")
img = Image.open(BytesIO(image_response.content))
if img.mode != 'RGB':
img = img.convert('RGB')
number = self.get_next_number()
generation_info = {
"timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
"result": result,
"request_id": request_id
}
image_path, metadata_path = self.save_image_and_metadata(img, generation_info, number)
img_tensor = torch.from_numpy(np.array(img).astype(np.float32) / 255.0)
img_tensor = img_tensor.unsqueeze(0)
return img_tensor, generation_info
except Exception as e:
raise Exception(f"Error recovering image: {str(e)}")
@classmethod
def IS_CHANGED(cls, **kwargs):
return float("NaN")
def interrupt(self):
print("Interrupting Fal.ai generation...")
self._interrupt_event.set()