mirror of
https://github.com/jags111/efficiency-nodes-comfyui.git
synced 2026-03-21 21:22:13 -03:00
Efficiency Nodes V2.0
This commit is contained in:
1
py/__init__.py
Normal file
1
py/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
#
|
||||
317
py/bnk_adv_encode.py
Normal file
317
py/bnk_adv_encode.py
Normal file
@@ -0,0 +1,317 @@
|
||||
import torch
|
||||
import numpy as np
|
||||
import itertools
|
||||
#from math import gcd
|
||||
|
||||
from comfy import model_management
|
||||
from comfy.sdxl_clip import SDXLClipModel
|
||||
|
||||
def _grouper(n, iterable):
|
||||
it = iter(iterable)
|
||||
while True:
|
||||
chunk = list(itertools.islice(it, n))
|
||||
if not chunk:
|
||||
return
|
||||
yield chunk
|
||||
|
||||
def _norm_mag(w, n):
|
||||
d = w - 1
|
||||
return 1 + np.sign(d) * np.sqrt(np.abs(d)**2 / n)
|
||||
#return np.sign(w) * np.sqrt(np.abs(w)**2 / n)
|
||||
|
||||
def divide_length(word_ids, weights):
    """Rescale each token's weight by the length of the word it belongs to.

    Tokens with word id 0 (padding/special tokens) keep weight 1.0.
    Returns a new nested list; the input is not mutated.
    """
    counts = dict(zip(*np.unique(word_ids, return_counts=True)))
    counts[0] = 1  # id 0 is never length-normalized
    return [
        [_norm_mag(w, counts[wid]) if wid != 0 else 1.0
         for w, wid in zip(ws, wids)]
        for ws, wids in zip(weights, word_ids)
    ]
|
||||
|
||||
def shift_mean_weight(word_ids, weights):
    """Shift all real-token weights by a constant so their mean becomes 1.0.

    Tokens with word id 0 (padding/special tokens) are excluded from the
    mean and left unchanged.
    """
    real_weights = [w for ws, wids in zip(weights, word_ids)
                    for w, wid in zip(ws, wids) if wid != 0]
    delta = 1 - np.mean(real_weights)
    return [
        [w + delta if wid != 0 else w for w, wid in zip(ws, wids)]
        for ws, wids in zip(weights, word_ids)
    ]
|
||||
|
||||
def scale_to_norm(weights, word_ids, w_max):
    """Proportionally scale weights so the maximum becomes min(max(weights), w_max).

    Tokens with word id 0 are pinned to the ceiling value itself.
    """
    peak = np.max(weights)
    ceiling = min(peak, w_max)
    return [
        [ceiling if wid == 0 else (w / peak) * ceiling
         for w, wid in zip(ws, wids)]
        for ws, wids in zip(weights, word_ids)
    ]
|
||||
|
||||
def from_zero(weights, base_emb):
    """Multiply each token embedding in *base_emb* by its weight.

    *weights* is a nested list whose flattened length equals the token
    dimension of *base_emb* (shape (1, tokens, dim)).
    """
    scale = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    scale = scale.view(1, -1, 1).expand(base_emb.shape)
    return base_emb * scale
|
||||
|
||||
def mask_word_id(tokens, word_ids, target_id, mask_token):
    """Replace every token belonging to word *target_id* with *mask_token*.

    Returns (masked_tokens, boolean_mask) where the mask marks the
    replaced positions.
    """
    masked = [
        [mask_token if wid == target_id else tok
         for tok, wid in zip(toks, wids)]
        for toks, wids in zip(tokens, word_ids)
    ]
    hit_mask = np.array(word_ids) == target_id
    return (masked, hit_mask)
|
||||
|
||||
def batched_clip_encode(tokens, length, encode_func, num_chunks):
    """Encode token sequences through *encode_func* in batches of 32.

    Each original prompt was split into *num_chunks* sequences of
    *length* tokens; the result is reshaped so those chunks are joined
    back into one row per prompt.
    """
    encoded = []
    for batch in _grouper(32, tokens):
        enc, _pooled = encode_func(batch)  # pooled output is not needed here
        encoded.append(enc.reshape((len(batch), length, -1)))
    stacked = torch.cat(encoded)
    # re-join the per-prompt chunks along the token axis
    return stacked.reshape((len(tokens) // num_chunks, length * num_chunks, -1))
|
||||
|
||||
def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    """Compute an additive embedding correction for up-weighted words.

    For each distinct word weight != 1.0, the word is replaced with a
    masking token and re-encoded; the difference to the base embedding,
    scaled by (weight - 1), is accumulated.  Returns (correction, pooled)
    where correction is added to the base embedding by the caller.
    """
    # pooled vector is taken from the last token position of the first chunk
    pooled_base = base_emb[0,length-1:length,:]
    # one representative weight per word id (first occurrence wins)
    wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True)
    weight_dict = dict((id,w)
                       for id,w in zip(wids ,np.array(weights).reshape(-1)[inds])
                       if w != 1.0)

    # nothing to adjust: zero correction, unchanged pooled vector
    if len(weight_dict) == 0:
        return torch.zeros_like(base_emb), base_emb[0,length-1:length,:]

    weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape)

    #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    #TODO: find most suitable masking token here
    m_token = (m_token, 1.0)

    ws = []
    masked_tokens = []
    masks = []

    #create prompts: one masked variant per adjusted word id
    for id, w in weight_dict.items():
        masked, m = mask_word_id(tokens, word_ids, id, m_token)
        masked_tokens.extend(masked)

        m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device)
        m = m.reshape(1,-1,1).expand(base_emb.shape)
        masks.append(m)

        ws.append(w)

    #batch process prompts
    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    masks = torch.cat(masks)

    # difference each masked encoding makes relative to the base
    embs = (base_emb.expand(embs.shape) - embs)
    pooled = embs[0,length-1:length,:]

    # keep only the positions belonging to the masked word, then sum
    embs *= masks
    embs = embs.sum(axis=0, keepdim=True)

    pooled_start = pooled_base.expand(len(ws), -1)
    ws = torch.tensor(ws).reshape(-1,1).expand(pooled_start.shape)
    # pooled correction scaled by (weight - 1), averaged over words
    pooled = (pooled - pooled_start) * (ws - 1)
    pooled = pooled.mean(axis=0, keepdim=True)

    return ((weight_tensor - 1) * embs), pooled_base + pooled
|
||||
|
||||
def mask_inds(tokens, inds, mask_token):
    """Replace tokens at the given flat indices with *mask_token*.

    *inds* are indices into the flattened (row * chunk_length + column)
    token grid.  Returns a new nested list.
    """
    clip_len = len(tokens[0])
    targets = set(inds)
    return [
        [mask_token if row * clip_len + col in targets else tok
         for col, tok in enumerate(seq)]
        for row, seq in enumerate(tokens)
    ]
|
||||
|
||||
def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    """Apply weights below 1.0 by progressively masking tokens and blending.

    Tokens are masked in ascending weight order; each intermediate masked
    prompt is encoded and the encodings are mixed with coefficients that
    are the successive differences of the sorted weights.  Returns
    (weighted_emb, tokens_with_all_downweighted_masked, pooled).
    """
    # w: sorted unique weights; w_inv maps each flat token to its index in w
    w, w_inv = np.unique(weights,return_inverse=True)

    # no weight below 1.0 -> nothing to do
    if np.sum(w < 1) == 0:
        return base_emb, tokens, base_emb[0,length-1:length,:]
    #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    #using the comma token as a masking token seems to work better than aos tokens for SD 1.x
    m_token = (m_token, 1.0)

    masked_tokens = []

    # mask weights in ascending order, accumulating the masked variants
    masked_current = tokens
    for i in range(len(w)):
        if w[i] >= 1:
            continue
        masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token)
        masked_tokens.extend(masked_current)

    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    embs = torch.cat([base_emb, embs])
    w = w[w<=1.0]
    # mixing coefficients: successive differences so they sum to max(w)
    w_mix = np.diff([0] + w.tolist())
    w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1,1,1))

    weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True)
    return weighted_emb, masked_current, weighted_emb[0,length-1:length,:]
|
||||
|
||||
def scale_emb_to_mag(base_emb, weighted_emb):
    """Rescale *weighted_emb* so its L2 norm equals that of *base_emb*."""
    ratio = torch.linalg.norm(base_emb) / torch.linalg.norm(weighted_emb)
    return ratio * weighted_emb
|
||||
|
||||
def recover_dist(base_emb, weighted_emb):
    """Shift and scale *weighted_emb* so its mean and std match *base_emb*."""
    centered = weighted_emb - weighted_emb.mean()
    rescaled = (base_emb.std() / weighted_emb.std()) * centered
    return rescaled + (base_emb.mean() - rescaled.mean())
|
||||
|
||||
def A1111_renorm(base_emb, weighted_emb):
    """Rescale *weighted_emb* so its mean matches *base_emb*'s (A1111 style)."""
    ratio = base_emb.mean() / weighted_emb.mean()
    return ratio * weighted_emb
|
||||
|
||||
def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266, length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False):
    """Encode weighted tokens using one of several weight-interpretation schemes.

    tokenized: list of chunks, each a list of (token, weight, word_id) triples.
    token_normalization: "none", "mean", "length", or "length+mean".
    weight_interpretation: "comfy", "A1111", "compel", "comfy++", or "down_weight".
    encode_func: callable(token_weight_pairs) -> (embedding, pooled).
    Returns (weighted_emb, pooled-or-None).
    """
    # split the (token, weight, word_id) triples into parallel nested lists
    tokens = [[t for t,_,_ in x] for x in tokenized]
    weights = [[w for _,w,_ in x] for x in tokenized]
    word_ids = [[wid for _,_,wid in x] for x in tokenized]

    #weight normalization
    #====================

    #distribute down/up weights over word lengths
    if token_normalization.startswith("length"):
        weights = divide_length(word_ids, weights)

    #make mean of word tokens 1
    if token_normalization.endswith("mean"):
        weights = shift_mean_weight(word_ids, weights)

    #weight interpretation
    #=====================
    pooled = None

    if weight_interpretation == "comfy":
        # native ComfyUI behavior: pass weights straight to the encoder
        weighted_tokens = [[(t,w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
        weighted_emb, pooled_base = encode_func(weighted_tokens)
        pooled = pooled_base
    else:
        # all other schemes start from an unweighted base encoding;
        # unweighted_tokens/base_emb are reused by the branches below
        unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokenized]
        base_emb, pooled_base = encode_func(unweighted_tokens)

    if weight_interpretation == "A1111":
        # multiply embeddings by weights, then renormalize the mean
        weighted_emb = from_zero(weights, base_emb)
        weighted_emb = A1111_renorm(base_emb, weighted_emb)
        pooled = pooled_base

    if weight_interpretation == "compel":
        # up-weights are applied directly; down-weights via masking
        pos_tokens = [[(t,w) if w >= 1.0 else (t,1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
        weighted_emb, _ = encode_func(pos_tokens)
        weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func)

    if weight_interpretation == "comfy++":
        # down-weights via masking, up-weights via masked-difference correction
        weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
        weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights]
        #unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down]
        embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
        weighted_emb += embs

    if weight_interpretation == "down_weight":
        # rescale everything to at most w_max, then treat as down-weighting
        weights = scale_to_norm(weights, word_ids, w_max)
        weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)

    if return_pooled:
        if apply_to_pooled:
            return weighted_emb, pooled
        else:
            return weighted_emb, pooled_base
    return weighted_emb, None
|
||||
|
||||
def encode_token_weights_g(model, token_weight_pairs):
    """Encode with the SDXL CLIP-G text model; returns (embeddings, pooled)."""
    result = model.clip_g.encode_token_weights(token_weight_pairs)
    return result
|
||||
|
||||
def encode_token_weights_l(model, token_weight_pairs):
    """Encode with the CLIP-L text model; the pooled output is discarded."""
    embeddings, _ = model.clip_l.encode_token_weights(token_weight_pairs)
    return embeddings, None
|
||||
|
||||
def encode_token_weights(model, token_weight_pairs, encode_func):
    """Select the clip layer, load the model onto the GPU, then encode.

    *model* is a comfy CLIP wrapper; *encode_func* is one of the
    encode_token_weights_* helpers taking (cond_stage_model, pairs).
    """
    # honor a clip-skip setting if one was applied to this CLIP object
    if model.layer_idx is not None:
        model.cond_stage_model.clip_layer(model.layer_idx)

    model_management.load_model_gpu(model.patcher)
    return encode_func(model.cond_stage_model, token_weight_pairs)
|
||||
|
||||
def prepareXL(embs_l, embs_g, pooled, clip_balance):
    """Blend SDXL CLIP-L and CLIP-G embeddings according to *clip_balance*.

    clip_balance 0.5 keeps both at full strength; moving toward 0 or 1
    attenuates one side.  If embs_l is None only embs_g is returned.
    """
    l_w = 1 - max(0, clip_balance - .5) * 2
    g_w = 1 - max(0, .5 - clip_balance) * 2
    if embs_l is None:
        return embs_g, pooled
    return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled
|
||||
|
||||
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
    """Tokenize *text* and encode it with advanced weight handling.

    Dispatches between single-encoder models (tokenize returns a list)
    and SDXL dual-encoder models (tokenize returns a dict with 'l'/'g'
    keys).  Returns (embedding, pooled-or-None).
    """
    tokenized = clip.tokenize(text, return_word_ids=True)
    if isinstance(tokenized, dict):
        # SDXL: encode CLIP-L and CLIP-G separately, then blend
        embs_l = None
        embs_g = None
        pooled = None
        if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
            embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
                                                    token_normalization,
                                                    weight_interpretation,
                                                    lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                    w_max=w_max,
                                                    return_pooled=False)
        if 'g' in tokenized:
            # pooled output comes from the G encoder only
            embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
                                                         token_normalization,
                                                         weight_interpretation,
                                                         lambda x: encode_token_weights(clip, x, encode_token_weights_g),
                                                         w_max=w_max,
                                                         return_pooled=True,
                                                         apply_to_pooled=apply_to_pooled)
        return prepareXL(embs_l, embs_g, pooled, clip_balance)
    else:
        # SD 1.x / 2.x: single encoder, no pooled output
        return advanced_encode_from_tokens(tokenized,
                                           token_normalization,
                                           weight_interpretation,
                                           lambda x: (clip.encode_from_tokens(x), None),
                                           w_max=w_max)
|
||||
|
||||
########################################################################################################################
|
||||
from nodes import MAX_RESOLUTION
|
||||
|
||||
class AdvancedCLIPTextEncode:
    """ComfyUI node: CLIP text encoding with normalization and weight modes."""

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "text": ("STRING", {"multiline": True}),
            "clip": ("CLIP",),
            "token_normalization": (["none", "mean", "length", "length+mean"],),
            "weight_interpretation": (["comfy", "A1111", "compel", "comfy++", "down_weight"],),
            # "affect_pooled": (["disable", "enable"],),
        }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning/advanced"

    def encode(self, clip, text, token_normalization, weight_interpretation, affect_pooled='disable'):
        """Encode *text* and wrap the result in ComfyUI conditioning format.

        affect_pooled is not exposed as an input above (commented out),
        so it currently always defaults to 'disable'.
        """
        embeddings_final, pooled = advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0,
                                                   apply_to_pooled=affect_pooled == 'enable')
        return ([[embeddings_final, {"pooled_output": pooled}]],)
|
||||
|
||||
|
||||
class AddCLIPSDXLRParams:
    """ComfyUI node: attach SDXL refiner parameters to a conditioning.

    Adds width, height and aesthetic_score to each conditioning entry's
    options dict without mutating the input conditioning.
    """

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "conditioning": ("CONDITIONING",),
            # defaults were floats (1024.0) for INT widgets; use ints
            "width": ("INT", {"default": 1024, "min": 0, "max": MAX_RESOLUTION}),
            "height": ("INT", {"default": 1024, "min": 0, "max": MAX_RESOLUTION}),
            "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
        }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning/advanced"

    def encode(self, conditioning, width, height, ascore):
        """Return a copy of *conditioning* with refiner params set on each entry."""
        c = []
        for t in conditioning:
            # copy the options dict so the original conditioning is untouched
            n = [t[0], t[1].copy()]
            n[1]['width'] = width
            n[1]['height'] = height
            n[1]['aesthetic_score'] = ascore
            c.append(n)
        return (c,)
|
||||
|
||||
523
py/bnk_tiled_samplers.py
Normal file
523
py/bnk_tiled_samplers.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# https://github.com/BlenderNeko/ComfyUI_TiledKSampler
|
||||
import sys
|
||||
import os
|
||||
import itertools
|
||||
import numpy as np
|
||||
from tqdm.auto import tqdm
|
||||
import torch
|
||||
|
||||
#sys.path.insert(0, os.path.join(os.path.dirname(os.path.realpath(__file__)), "comfy"))
|
||||
import comfy.sd
|
||||
import comfy.controlnet
|
||||
import comfy.model_management
|
||||
import comfy.sample
|
||||
#from . import tiling
|
||||
import latent_preview
|
||||
#import torch
|
||||
#import itertools
|
||||
#import numpy as np
|
||||
MAX_RESOLUTION=8192
|
||||
|
||||
def grouper(n, iterable):
    """Yield lists of at most *n* consecutive items from *iterable*."""
    stream = iter(iterable)
    while batch := list(itertools.islice(stream, n)):
        yield batch
|
||||
|
||||
|
||||
def create_batches(n, iterable):
    """Group tiles by their (h_len, w_len) dimensions, then yield batches of up to *n*.

    Tiles with equal height/width can be stacked into one sampler batch;
    groupby requires the input to already be ordered by that key.
    """
    same_size = itertools.groupby(iterable, key=lambda tile: (tile[1], tile[3]))
    for _, group in same_size:
        yield from grouper(n, group)
|
||||
|
||||
|
||||
def get_slice(tensor, h, h_len, w, w_len):
    """Return a view of *tensor* cropped to [h:h+h_len, w:w+w_len] on the last two dims."""
    return tensor.narrow(-2, h, h_len).narrow(-1, w, w_len)
|
||||
|
||||
|
||||
def set_slice(tensor1, tensor2, h, h_len, w, w_len, mask=None):
    """Write *tensor2* into the [h:h+h_len, w:w+w_len] region of *tensor1* in place.

    With a mask, the region is alpha-blended: mask=1 takes tensor2,
    mask=0 keeps tensor1.
    """
    if mask is None:
        tensor1[:, :, h:h + h_len, w:w + w_len] = tensor2
    else:
        blended = tensor1[:, :, h:h + h_len, w:w + w_len] * (1 - mask) + tensor2 * mask
        tensor1[:, :, h:h + h_len, w:w + w_len] = blended
|
||||
|
||||
|
||||
def get_tiles_and_masks_simple(steps, latent_shape, tile_height, tile_width):
    """Build a single non-overlapping grid of tiles covering the latent.

    Each tile is (h, h_len, w, w_len, steps, mask); edge tiles are
    clipped to the latent size and no blend masks are used.
    Returns one pass containing one tile list.
    """
    latent_size_h = latent_shape[-2]
    latent_size_w = latent_shape[-1]
    # pixel sizes -> latent sizes (factor 8)
    tile_size_h = int(tile_height // 8)
    tile_size_w = int(tile_width // 8)

    hs = np.arange(0, latent_size_h, tile_size_h)
    ws = np.arange(0, latent_size_w, tile_size_w)

    def tile_at(i, j):
        h = int(hs[i])
        w = int(ws[j])
        h_len = min(tile_size_h, latent_size_h - h)
        w_len = min(tile_size_w, latent_size_w - w)
        return (h, h_len, w, w_len, steps, None)

    return [
        [[tile_at(i, j) for i in range(len(hs)) for j in range(len(ws))]],
    ]
|
||||
|
||||
|
||||
def get_tiles_and_masks_padded(steps, latent_shape, tile_height, tile_width):
    """Build four overlapping tile passes with feathered blend masks.

    The second/third/fourth passes are shifted by half a tile in h, w,
    and both directions respectively; masks extend toward latent borders
    so every pixel is covered with total blend weight 1.
    """
    batch_size = latent_shape[0]
    latent_size_h = latent_shape[-2]
    latent_size_w = latent_shape[-1]

    # tile sizes in latent units, rounded down to a multiple of 4 so the
    # quarter-tile mask bands divide evenly
    tile_size_h = int(tile_height // 8)
    tile_size_h = int((tile_size_h // 4) * 4)
    tile_size_w = int(tile_width // 8)
    tile_size_w = int((tile_size_w // 4) * 4)

    # masks: 3x3 grid of band masks (corner/edge/center) over one tile
    mask_h = [0, tile_size_h // 4, tile_size_h - tile_size_h // 4, tile_size_h]
    mask_w = [0, tile_size_w // 4, tile_size_w - tile_size_w // 4, tile_size_w]
    masks = [[] for _ in range(3)]
    for i in range(3):
        for j in range(3):
            mask = torch.zeros((batch_size, 1, tile_size_h, tile_size_w), dtype=torch.float32, device='cpu')
            mask[:, :, mask_h[i]:mask_h[i + 1], mask_w[j]:mask_w[j + 1]] = 1.0
            masks[i].append(mask)

    def create_mask(h_ind, w_ind, h_ind_max, w_ind_max, mask_h, mask_w, h_len, w_len):
        # interior tile: plain center band
        mask = masks[1][1]
        if not (h_ind == 0 or h_ind == h_ind_max or w_ind == 0 or w_ind == w_ind_max):
            return get_slice(mask, 0, h_len, 0, w_len)
        # border tile: add the edge/corner bands facing the latent boundary
        mask = mask.clone()
        if h_ind == 0 and mask_h:
            mask += masks[0][1]
        if h_ind == h_ind_max and mask_h:
            mask += masks[2][1]
        if w_ind == 0 and mask_w:
            mask += masks[1][0]
        if w_ind == w_ind_max and mask_w:
            mask += masks[1][2]
        if h_ind == 0 and w_ind == 0 and mask_h and mask_w:
            mask += masks[0][0]
        if h_ind == 0 and w_ind == w_ind_max and mask_h and mask_w:
            mask += masks[0][2]
        if h_ind == h_ind_max and w_ind == 0 and mask_h and mask_w:
            mask += masks[2][0]
        if h_ind == h_ind_max and w_ind == w_ind_max and mask_h and mask_w:
            mask += masks[2][2]
        return get_slice(mask, 0, h_len, 0, w_len)

    h = np.arange(0, latent_size_h, tile_size_h)
    h_shift = np.arange(tile_size_h // 2, latent_size_h - tile_size_h // 2, tile_size_h)
    w = np.arange(0, latent_size_w, tile_size_w)
    # NOTE(review): tile_size_h appears in the w_shift bound — looks like a
    # copy/paste slip (expected tile_size_w); preserved as-is, verify upstream
    w_shift = np.arange(tile_size_w // 2, latent_size_w - tile_size_h // 2, tile_size_w)

    def create_tile(hs, ws, mask_h, mask_w, i, j):
        h = int(hs[i])
        w = int(ws[j])
        h_len = min(tile_size_h, latent_size_h - h)
        w_len = min(tile_size_w, latent_size_w - w)
        mask = create_mask(i, j, len(hs) - 1, len(ws) - 1, mask_h, mask_w, h_len, w_len)
        return (h, h_len, w, w_len, steps, mask)

    # pass order: aligned grid, h-shifted, w-shifted, both-shifted
    passes = [
        [[create_tile(h, w, True, True, i, j) for i in range(len(h)) for j in range(len(w))]],
        [[create_tile(h_shift, w, False, True, i, j) for i in range(len(h_shift)) for j in range(len(w))]],
        [[create_tile(h, w_shift, True, False, i, j) for i in range(len(h)) for j in range(len(w_shift))]],
        [[create_tile(h_shift, w_shift, False, False, i, j) for i in range(len(h_shift)) for j in range(len(w_shift))]],
    ]

    return passes
|
||||
|
||||
|
||||
def mask_at_boundary(h, h_len, w, w_len, tile_size_h, tile_size_w, latent_size_h, latent_size_w, mask, device='cpu'):
    """Expand an undersized edge tile back to full tile size.

    If the tile was clipped at the latent border, shift its origin so a
    full-size tile fits, and build a mask that is nonzero only over the
    originally requested region.  Returns (h, h_len, w, w_len, mask).
    """
    # tile sizes arrive in pixel units; convert to latent units
    tile_size_h = int(tile_size_h // 8)
    tile_size_w = int(tile_size_w // 8)

    # full-size tile (or tile spanning the whole latent): nothing to fix
    if (h_len == tile_size_h or h_len == latent_size_h) and (w_len == tile_size_w or w_len == latent_size_w):
        return h, h_len, w, w_len, mask
    # negative offsets pull the tile origin back inside the latent
    h_offset = min(0, latent_size_h - (h + tile_size_h))
    w_offset = min(0, latent_size_w - (w + tile_size_w))
    new_mask = torch.zeros((1, 1, tile_size_h, tile_size_w), dtype=torch.float32, device=device)
    # mark (or copy) only the originally covered region of the enlarged tile
    new_mask[:, :, -h_offset:h_len if h_offset == 0 else tile_size_h,
             -w_offset:w_len if w_offset == 0 else tile_size_w] = 1.0 if mask is None else mask
    return h + h_offset, tile_size_h, w + w_offset, tile_size_w, new_mask
|
||||
|
||||
|
||||
def get_tiles_and_masks_rgrid(steps, latent_shape, tile_height, tile_width, generator):
    """Build a randomly jittered tile grid, re-randomized every step.

    Each step gets its own grid whose origin is offset by a random
    jitter (two half-offset variants alternate per row/column), hiding
    seams across steps.  Each tile runs for exactly 1 step.
    """
    def calc_coords(latent_size, tile_size, jitter):
        # number of tiles needed to cover the jittered extent
        tile_coords = int((latent_size + jitter - 1) // tile_size + 1)
        tile_coords = [np.clip(tile_size * c - jitter, 0, latent_size) for c in range(tile_coords + 1)]
        # (start, length) pairs from consecutive boundaries
        tile_coords = [(c1, c2 - c1) for c1, c2 in zip(tile_coords, tile_coords[1:])]
        return tile_coords

    # calc stuff
    batch_size = latent_shape[0]
    latent_size_h = latent_shape[-2]
    latent_size_w = latent_shape[-1]
    tile_size_h = int(tile_height // 8)
    tile_size_w = int(tile_width // 8)

    tiles_all = []

    for s in range(steps):
        # reproducible per-step jitter from the provided torch generator
        rands = torch.rand((2,), dtype=torch.float32, generator=generator, device='cpu').numpy()

        jitter_w1 = int(rands[0] * tile_size_w)
        jitter_w2 = int(((rands[0] + .5) % 1.0) * tile_size_w)
        jitter_h1 = int(rands[1] * tile_size_h)
        jitter_h2 = int(((rands[1] + .5) % 1.0) * tile_size_h)

        # calc number of tiles
        tiles_h = [
            calc_coords(latent_size_h, tile_size_h, jitter_h1),
            calc_coords(latent_size_h, tile_size_h, jitter_h2)
        ]
        tiles_w = [
            calc_coords(latent_size_w, tile_size_w, jitter_w1),
            calc_coords(latent_size_w, tile_size_w, jitter_w2)
        ]

        tiles = []
        # alternate which axis gets the half-shifted variant each step
        if s % 2 == 0:
            for i, h in enumerate(tiles_h[0]):
                for w in tiles_w[i % 2]:
                    tiles.append((int(h[0]), int(h[1]), int(w[0]), int(w[1]), 1, None))
        else:
            for i, w in enumerate(tiles_w[0]):
                for h in tiles_h[i % 2]:
                    tiles.append((int(h[0]), int(h[1]), int(w[0]), int(w[1]), 1, None))
        tiles_all.append(tiles)
    return [tiles_all]
|
||||
|
||||
#######################
|
||||
|
||||
def recursion_to_list(obj, attr):
    """Yield *obj*, then every object reached by following the *attr* link chain.

    Stops when the attribute is missing or None.  *obj* itself is always
    yielded, even if it is None.
    """
    node = obj
    yield node
    while (node := getattr(node, attr, None)) is not None:
        yield node
|
||||
|
||||
def copy_cond(cond):
    """Shallow-copy each conditioning entry's options dict so per-tile edits don't leak."""
    duplicated = []
    for embedding, options in cond:
        duplicated.append([embedding, options.copy()])
    return duplicated
|
||||
|
||||
def slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, cond, area):
    """Clip a conditioning's area/mask to one tile, mutating cond[1] in place.

    *area* is a (coords, mask) pair captured from the original cond.
    Returns (cond, ignore) where ignore=True means the cond does not
    intersect this tile and should be dropped for this tile.
    """
    tile_h_end = tile_h + tile_h_len
    tile_w_end = tile_w + tile_w_len
    coords = area[0] #h_len, w_len, h, w,
    mask = area[1]
    if coords is not None:
        h_len, w_len, h, w = coords
        h_end = h + h_len
        w_end = w + w_len
        # does the cond area overlap this tile at all?
        if h < tile_h_end and h_end > tile_h and w < tile_w_end and w_end > tile_w:
            # re-express the overlap in tile-local coordinates
            new_h = max(0, h - tile_h)
            new_w = max(0, w - tile_w)
            new_h_end = min(tile_h_end, h_end - tile_h)
            new_w_end = min(tile_w_end, w_end - tile_w)
            cond[1]['area'] = (new_h_end - new_h, new_w_end - new_w, new_h, new_w)
        else:
            return (cond, True)
    if mask is not None:
        new_mask = get_slice(mask, tile_h,tile_h_len,tile_w,tile_w_len)
        # fully-zero mask inside this tile -> cond has no effect here
        if new_mask.sum().cpu() == 0.0 and 'mask' in cond[1]:
            return (cond, True)
        else:
            cond[1]['mask'] = new_mask
    return (cond, False)
|
||||
|
||||
def slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen):
    """Clip GLIGEN box areas to one tile, mutating the *cond* dict in place.

    Boxes that do not intersect the tile are dropped; if none remain,
    the 'gligen' key is removed entirely.
    """
    tile_h_end = tile_h + tile_h_len
    tile_w_end = tile_w + tile_w_len
    if gligen is None:
        return
    gligen_type = gligen[0]
    gligen_model = gligen[1]
    gligen_areas = gligen[2]

    gligen_areas_new = []
    for emb, h_len, w_len, h, w in gligen_areas:
        h_end = h + h_len
        w_end = w + w_len
        # keep only boxes overlapping the tile, in tile-local coordinates
        if h < tile_h_end and h_end > tile_h and w < tile_w_end and w_end > tile_w:
            new_h = max(0, h - tile_h)
            new_w = max(0, w - tile_w)
            new_h_end = min(tile_h_end, h_end - tile_h)
            new_w_end = min(tile_w_end, w_end - tile_w)
            gligen_areas_new.append((emb, new_h_end - new_h, new_w_end - new_w, new_h, new_w))

    if len(gligen_areas_new) == 0:
        del cond['gligen']
    else:
        cond['gligen'] = (gligen_type, gligen_model, gligen_areas_new)
|
||||
|
||||
def slice_cnet(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
    """Set a ControlNet's hint to the pixel-space crop matching a latent tile.

    *img* is a pre-resized full hint image, or None to use the model's
    original hint.  Latent coordinates are scaled by 8 to pixel space.
    Mutates model.cond_hint in place.
    """
    if img is None:
        img = model.cond_hint_original
    model.cond_hint = get_slice(img, h*8, h_len*8, w*8, w_len*8).to(model.control_model.dtype).to(model.device)
|
||||
|
||||
def slices_T2I(h, h_len, w, w_len, model:comfy.controlnet.ControlBase, img):
    """Set a T2I-Adapter's hint to the pixel-space crop matching a latent tile.

    Clears any cached control_input so the adapter recomputes features
    for the new crop.  Mutates the model in place.
    """
    model.control_input = None
    if img is None:
        img = model.cond_hint_original
    model.cond_hint = get_slice(img, h*8, h_len*8, w*8, w_len*8).float().to(model.device)
|
||||
|
||||
# TODO: refactor some of the mess
|
||||
|
||||
from PIL import Image
|
||||
|
||||
def sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, preview=False):
    """Run a KSampler over a latent tile-by-tile.

    Supports 'simple', 'padded', 'random' and 'random strict' tiling
    strategies.  Conditioning areas, masks, GLIGEN boxes and ControlNet /
    T2I-Adapter hints are all clipped per tile.  Returns ({"samples": ...},).
    """
    end_at_step = min(end_at_step, steps)
    device = comfy.model_management.get_torch_device()
    samples = latent_image["samples"]
    noise_mask = latent_image["noise_mask"] if "noise_mask" in latent_image else None
    force_full_denoise = return_with_leftover_noise == "enable"
    if add_noise == "disable":
        noise = torch.zeros(samples.size(), dtype=samples.dtype, layout=samples.layout, device="cpu")
    else:
        skip = latent_image["batch_index"] if "batch_index" in latent_image else None
        noise = comfy.sample.prepare_noise(samples, noise_seed, skip)

    if noise_mask is not None:
        noise_mask = comfy.sample.prepare_mask(noise_mask, noise.shape, device='cpu')

    shape = samples.shape
    samples = samples.clone()

    # clamp tile size to the latent's pixel size (latent units * 8)
    tile_width = min(shape[-1] * 8, tile_width)
    tile_height = min(shape[2] * 8, tile_height)

    real_model = None
    modelPatches, inference_memory = comfy.sample.get_additional_models(positive, negative, model.model_dtype())
    comfy.model_management.load_models_gpu([model] + modelPatches, comfy.model_management.batch_area_memory(noise.shape[0] * noise.shape[2] * noise.shape[3]) + inference_memory)
    real_model = model.model

    sampler = comfy.samplers.KSampler(real_model, steps=steps, device=device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)

    # non-padded strategies noise the whole latent up-front; padded adds
    # noise per tile inside the loop instead
    if tiling_strategy != 'padded':
        if noise_mask is not None:
            samples += sampler.sigmas[start_at_step].cpu() * noise_mask * model.model.process_latent_out(noise).cpu()
        else:
            samples += sampler.sigmas[start_at_step].cpu() * model.model.process_latent_out(noise).cpu()

    #cnets: collect every ControlNet in both cond chains and pre-resize
    # hints that don't match the latent's pixel size
    cnets = comfy.sample.get_models_from_cond(positive, 'control') + comfy.sample.get_models_from_cond(negative, 'control')
    cnets = [m for m in cnets if isinstance(m, comfy.controlnet.ControlNet)]
    cnets = list(set([x for m in cnets for x in recursion_to_list(m, "previous_controlnet")]))
    cnet_imgs = [
        torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
        if m.cond_hint_original.shape[-2] != shape[-2] * 8 or m.cond_hint_original.shape[-1] != shape[-1] * 8 else None
        for m in cnets]

    #T2I: same collection for T2I adapters, plus channel reduction for
    # single-channel adapters fed multi-channel hints
    T2Is = comfy.sample.get_models_from_cond(positive, 'control') + comfy.sample.get_models_from_cond(negative, 'control')
    T2Is = [m for m in T2Is if isinstance(m, comfy.controlnet.T2IAdapter)]
    T2Is = [x for m in T2Is for x in recursion_to_list(m, "previous_controlnet")]
    T2I_imgs = [
        torch.nn.functional.interpolate(m.cond_hint_original, (shape[-2] * 8, shape[-1] * 8), mode='nearest-exact').to('cpu')
        if m.cond_hint_original.shape[-2] != shape[-2] * 8 or m.cond_hint_original.shape[-1] != shape[-1] * 8 or (m.channels_in == 1 and m.cond_hint_original.shape[1] != 1) else None
        for m in T2Is
    ]
    T2I_imgs = [
        torch.mean(img, 1, keepdim=True) if img is not None and m.channels_in == 1 and m.cond_hint_original.shape[1] else img
        for m, img in zip(T2Is, T2I_imgs)
    ]

    #cond area and mask: snapshot originals so each tile can re-clip them
    spatial_conds_pos = [
        (c[1]['area'] if 'area' in c[1] else None,
         comfy.sample.prepare_mask(c[1]['mask'], shape, device) if 'mask' in c[1] else None)
        for c in positive
    ]
    spatial_conds_neg = [
        (c[1]['area'] if 'area' in c[1] else None,
         comfy.sample.prepare_mask(c[1]['mask'], shape, device) if 'mask' in c[1] else None)
        for c in negative
    ]

    #gligen
    gligen_pos = [
        c[1]['gligen'] if 'gligen' in c[1] else None
        for c in positive
    ]
    gligen_neg = [
        c[1]['gligen'] if 'gligen' in c[1] else None
        for c in negative
    ]

    positive_copy = comfy.sample.broadcast_cond(positive, shape[0], device)
    negative_copy = comfy.sample.broadcast_cond(negative, shape[0], device)

    # seeded generator makes the random tilings reproducible
    gen = torch.manual_seed(noise_seed)
    if tiling_strategy == 'random' or tiling_strategy == 'random strict':
        tiles = get_tiles_and_masks_rgrid(end_at_step - start_at_step, samples.shape, tile_height, tile_width, gen)
    elif tiling_strategy == 'padded':
        tiles = get_tiles_and_masks_padded(end_at_step - start_at_step, samples.shape, tile_height, tile_width)
    else:
        tiles = get_tiles_and_masks_simple(end_at_step - start_at_step, samples.shape, tile_height, tile_width)

    total_steps = sum([num_steps for img_pass in tiles for steps_list in img_pass for _,_,_,_,num_steps,_ in steps_list])
    # list so the nested callback can mutate it (closure over mutable cell)
    current_step = [0]

    preview_format = "JPEG"
    if preview_format not in ["JPEG", "PNG"]:
        preview_format = "JPEG"
    previewer = None
    if preview:
        previewer = latent_preview.get_previewer(device, model.model.latent_format)

    with tqdm(total=total_steps) as pbar_tqdm:
        pbar = comfy.utils.ProgressBar(total_steps)

        def callback(step, x0, x, total_steps):
            # advances both the UI progress bar and the terminal bar
            current_step[0] += 1
            preview_bytes = None
            if previewer:
                preview_bytes = previewer.decode_latent_to_preview_image(preview_format, x0)
            pbar.update_absolute(current_step[0], preview=preview_bytes)
            pbar_tqdm.update(1)

        if tiling_strategy == "random strict":
            # strict mode writes into a copy so tiles within one pass
            # never read each other's fresh output
            samples_next = samples.clone()
        for img_pass in tiles:
            for i in range(len(img_pass)):
                for tile_h, tile_h_len, tile_w, tile_w_len, tile_steps, tile_mask in img_pass[i]:
                    tiled_mask = None
                    if noise_mask is not None:
                        tiled_mask = get_slice(noise_mask, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
                    if tile_mask is not None:
                        if tiled_mask is not None:
                            tiled_mask *= tile_mask.to(device)
                        else:
                            tiled_mask = tile_mask.to(device)

                    if tiling_strategy == 'padded' or tiling_strategy == 'random strict':
                        # expand clipped border tiles back to full size
                        tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask = mask_at_boundary(tile_h, tile_h_len, tile_w, tile_w_len,
                                                                                              tile_height, tile_width, samples.shape[-2], samples.shape[-1],
                                                                                              tiled_mask, device)

                    # tile fully masked out -> skip the sampler call
                    if tiled_mask is not None and tiled_mask.sum().cpu() == 0.0:
                        continue

                    tiled_latent = get_slice(samples, tile_h, tile_h_len, tile_w, tile_w_len).to(device)

                    if tiling_strategy == 'padded':
                        tiled_noise = get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device)
                    else:
                        # noise was already applied globally; only masked
                        # regions still need their share
                        if tiled_mask is None or noise_mask is None:
                            tiled_noise = torch.zeros_like(tiled_latent)
                        else:
                            tiled_noise = get_slice(noise, tile_h, tile_h_len, tile_w, tile_w_len).to(device) * (1 - tiled_mask)

                    #TODO: all other condition based stuff like area sets and GLIGEN should also happen here

                    #cnets
                    for m, img in zip(cnets, cnet_imgs):
                        slice_cnet(tile_h, tile_h_len, tile_w, tile_w_len, m, img)

                    #T2I
                    for m, img in zip(T2Is, T2I_imgs):
                        slices_T2I(tile_h, tile_h_len, tile_w, tile_w_len, m, img)

                    pos = copy_cond(positive_copy)
                    neg = copy_cond(negative_copy)

                    #cond areas
                    pos = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(pos, spatial_conds_pos)]
                    pos = [c for c, ignore in pos if not ignore]
                    neg = [slice_cond(tile_h, tile_h_len, tile_w, tile_w_len, c, area) for c, area in zip(neg, spatial_conds_neg)]
                    neg = [c for c, ignore in neg if not ignore]

                    #gligen
                    # NOTE(review): pos/neg were filtered above but
                    # gligen_pos/gligen_neg were not, so indices can
                    # misalign when conds were dropped — verify upstream
                    for (_, cond), gligen in zip(pos, gligen_pos):
                        slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)
                    for (_, cond), gligen in zip(neg, gligen_neg):
                        slice_gligen(tile_h, tile_h_len, tile_w, tile_w_len, cond, gligen)

                    tile_result = sampler.sample(tiled_noise, pos, neg, cfg=cfg, latent_image=tiled_latent, start_step=start_at_step + i * tile_steps, last_step=start_at_step + i*tile_steps + tile_steps, force_full_denoise=force_full_denoise and i+1 == end_at_step - start_at_step, denoise_mask=tiled_mask, callback=callback, disable_pbar=True, seed=noise_seed)
                    tile_result = tile_result.cpu()
                    if tiled_mask is not None:
                        tiled_mask = tiled_mask.cpu()
                    if tiling_strategy == "random strict":
                        set_slice(samples_next, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
                    else:
                        set_slice(samples, tile_result, tile_h, tile_h_len, tile_w, tile_w_len, tiled_mask)
                if tiling_strategy == "random strict":
                    # commit the whole pass at once
                    samples = samples_next.clone()

    comfy.sample.cleanup_additional_models(modelPatches)

    out = latent_image.copy()
    out["samples"] = samples.cpu()
    return (out, )
|
||||
|
||||
class TiledKSampler:
    """ComfyUI node: tiled KSampler with a simple ``denoise`` control.

    The denoise fraction is translated into an equivalent advanced schedule
    (total steps = steps / denoise, starting part-way through) and the work
    is delegated to ``sample_common``.
    """

    @classmethod
    def INPUT_TYPES(s):
        inputs = {
            "model": ("MODEL",),
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
            "tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
            "tiling_strategy": (["random", "random strict", "padded", 'simple'],),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "latent_image": ("LATENT",),
            "denoise": ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01}),
        }
        return {"required": inputs}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise):
        """Denoise the last ``steps`` of a ``steps / denoise``-step schedule,
        always adding noise and fully denoising at the end."""
        # NOTE(review): denoise == 0 (the slider minimum) raises
        # ZeroDivisionError here — confirm whether the UI should allow 0.
        steps_total = int(steps / denoise)
        return sample_common(model, 'enable', seed, tile_width, tile_height, tiling_strategy, steps_total, cfg, sampler_name, scheduler, positive, negative, latent_image, steps_total-steps, steps_total, 'disable', denoise=1.0, preview=True)
|
||||
|
||||
class TiledKSamplerAdvanced:
    """ComfyUI node: advanced tiled KSampler exposing the raw step window
    (start/end step, leftover noise) instead of a denoise fraction."""

    @classmethod
    def INPUT_TYPES(s):
        inputs = {
            "model": ("MODEL",),
            "add_noise": (["enable", "disable"],),
            "noise_seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "tile_width": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
            "tile_height": ("INT", {"default": 512, "min": 256, "max": MAX_RESOLUTION, "step": 64}),
            "tiling_strategy": (["random", "random strict", "padded", 'simple'],),
            "steps": ("INT", {"default": 20, "min": 1, "max": 10000}),
            "cfg": ("FLOAT", {"default": 8.0, "min": 0.0, "max": 100.0}),
            "sampler_name": (comfy.samplers.KSampler.SAMPLERS,),
            "scheduler": (comfy.samplers.KSampler.SCHEDULERS,),
            "positive": ("CONDITIONING",),
            "negative": ("CONDITIONING",),
            "latent_image": ("LATENT",),
            "start_at_step": ("INT", {"default": 0, "min": 0, "max": 10000}),
            "end_at_step": ("INT", {"default": 10000, "min": 0, "max": 10000}),
            "return_with_leftover_noise": (["disable", "enable"],),
            "preview": (["disable", "enable"],),
        }
        return {"required": inputs}

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "sample"

    CATEGORY = "sampling"

    def sample(self, model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, preview, denoise=1.0):
        """Delegate to ``sample_common``.

        NOTE(review): the ``denoise`` argument is accepted but ignored —
        ``sample_common`` is always called with denoise=1.0 (this matches
        the original behaviour).
        """
        return sample_common(model, add_noise, noise_seed, tile_width, tile_height, tiling_strategy, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, start_at_step, end_at_step, return_with_leftover_noise, denoise=1.0, preview=(preview == 'enable'))
|
||||
16
py/cg_mixed_seed_noise.py
Normal file
16
py/cg_mixed_seed_noise.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# https://github.com/chrisgoringe/cg-noise
|
||||
import torch
|
||||
|
||||
def get_mixed_noise_function(original_noise_function, variation_seed, variation_weight):
    """Wrap a comfy ``prepare_noise``-style function so that its output
    blends noise from two seeds.

    For a single-image batch the result is a fixed linear blend
    ``(1 - w) * noise(seed) + w * noise(variation_seed)``. For larger
    batches the variation strength ramps linearly across the batch:
    item ``i`` uses weight ``w * i`` (item 0 is pure base noise).
    """
    def prepare_mixed_noise(latent_image: torch.Tensor, seed, batch_inds):
        # Both noise fields are generated from the first latent only, so
        # every batch item blends the same two noise tensors.
        base_latent = latent_image[0].unsqueeze(0)
        variation_noise = original_noise_function(base_latent, variation_seed, batch_inds)
        base_noise = original_noise_function(base_latent, seed, batch_inds)

        if latent_image.shape[0] == 1:
            return base_noise * (1.0 - variation_weight) + variation_noise * variation_weight

        blended = torch.empty_like(latent_image)
        for idx in range(latent_image.shape[0]):
            w = variation_weight * idx
            blended[idx] = base_noise * (1.0 - w) + variation_noise * w
        return blended

    return prepare_mixed_noise
|
||||
82
py/city96_latent_upscaler.py
Normal file
82
py/city96_latent_upscaler.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# https://github.com/city96/SD-Latent-Upscaler
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from safetensors.torch import load_file
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
class Upscaler(nn.Module):
    """
    Basic NN layout, ported from:
    https://github.com/city96/SD-Latent-Upscaler/blob/main/upscaler.py

    Layer ordering must stay exactly as in the reference implementation so
    that pretrained ``sequential.N.*`` state-dict keys line up.
    """
    version = 2.1  # network revision

    def head(self):
        # Input projection, then nearest-neighbour upsampling by `fac`.
        return [
            nn.Conv2d(self.chan, self.size, kernel_size=self.krn, padding=self.pad),
            nn.ReLU(),
            nn.Upsample(scale_factor=self.fac, mode="nearest"),
            nn.ReLU(),
        ]

    def core(self):
        # `depth` repeated conv+ReLU stages at constant width; each
        # iteration builds fresh modules (no weight sharing).
        stages = []
        for _ in range(self.depth):
            stages.append(nn.Conv2d(self.size, self.size, kernel_size=self.krn, padding=self.pad))
            stages.append(nn.ReLU())
        return stages

    def tail(self):
        # Output projection back to latent channels.
        return [
            nn.Conv2d(self.size, self.chan, kernel_size=self.krn, padding=self.pad),
        ]

    def __init__(self, fac, depth=16):
        super().__init__()
        self.size = 64     # Conv2d size
        self.chan = 4      # in/out channels
        self.depth = depth # no. of layers
        self.fac = fac     # scale factor
        self.krn = 3       # kernel size
        self.pad = 1       # padding

        self.sequential = nn.Sequential(
            *self.head(),
            *self.core(),
            *self.tail(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.sequential(x)
|
||||
|
||||
|
||||
class LatentUpscaler:
    """ComfyUI node that upscales a latent with city96's SD-Latent-Upscaler.

    Downloads the matching pretrained weights from the Hugging Face Hub,
    runs the small conv net over the latent and returns a new LATENT dict.
    """

    def __init__(self):
        pass

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "samples": ("LATENT", ),
                "latent_ver": (["v1", "xl"],),
                "scale_factor": (["1.25", "1.5", "2.0"],),
            }
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    def upscale(self, samples, latent_ver, scale_factor):
        """Upscale ``samples['samples']`` by ``scale_factor``.

        `scale_factor` arrives as a string from the dropdown; nn.Upsample
        converts it to float internally.
        """
        model = Upscaler(scale_factor)
        weights = str(hf_hub_download(
            repo_id="city96/SD-Latent-Upscaler",
            filename=f"latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors")
        )
        # weights = f"./latent-upscaler-v{model.version}_SD{latent_ver}-x{scale_factor}.safetensors"

        model.load_state_dict(load_file(weights))
        # BUGFIX: inference previously ran in train mode with autograd
        # enabled, tracking gradients for no reason. Switch to eval mode
        # and disable grad tracking for the forward pass.
        model.eval()
        lt = samples["samples"]
        # NOTE(review): assumes `lt` is float32 on the same device as the
        # freshly loaded (CPU) model — confirm against callers.
        with torch.no_grad():
            lt = model(lt)
        del model
        return ({"samples": lt},)
|
||||
BIN
py/sd15_resizer.pt
Normal file
BIN
py/sd15_resizer.pt
Normal file
Binary file not shown.
BIN
py/sdxl_resizer.pt
Normal file
BIN
py/sdxl_resizer.pt
Normal file
Binary file not shown.
321
py/smZ_cfg_denoiser.py
Normal file
321
py/smZ_cfg_denoiser.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
import comfy
|
||||
import torch
|
||||
from typing import List
|
||||
import comfy.sample
|
||||
from comfy import model_base, model_management
|
||||
from comfy.samplers import KSampler, CompVisVDenoiser, KSamplerX0Inpaint
|
||||
from comfy.k_diffusion.external import CompVisDenoiser
|
||||
import nodes
|
||||
import inspect
|
||||
import functools
|
||||
import importlib
|
||||
import os
|
||||
import re
|
||||
import itertools
|
||||
|
||||
import torch
|
||||
from comfy import model_management
|
||||
|
||||
def catenate_conds(conds):
    """Concatenate a list of conds along the batch dimension.

    Accepts either raw tensors or dicts of tensors (e.g. SDXL conds with
    several entries); dicts are concatenated key-by-key.
    """
    if isinstance(conds[0], dict):
        return {k: torch.cat([entry[k] for entry in conds]) for k in conds[0]}
    return torch.cat(conds)
|
||||
|
||||
|
||||
def subscript_cond(cond, a, b):
    """Slice rows ``[a, b)`` from a cond; supports tensor and dict conds."""
    if isinstance(cond, dict):
        return {key: value[a:b] for key, value in cond.items()}
    return cond[a:b]
|
||||
|
||||
|
||||
def pad_cond(tensor, repeats, empty):
    """Pad a cond along dim 1 with ``repeats`` copies of the ``empty``
    embedding (used to equalize cond/uncond token lengths).

    Dict conds (SDXL-style) have only their ``crossattn`` entry padded,
    in place.
    """
    if isinstance(tensor, dict):
        tensor['crossattn'] = pad_cond(tensor['crossattn'], repeats, empty)
        return tensor
    filler = empty.repeat((tensor.shape[0], repeats, 1)).to(device=tensor.device)
    return torch.cat([tensor, filler], axis=1)
|
||||
|
||||
|
||||
class CFGDenoiser(torch.nn.Module):
    """
    Classifier free guidance denoiser. A wrapper for stable diffusion model (specifically for unet)
    that can take a noisy picture and produce a noise-free picture using two guidances (prompts)
    instead of one. Originally, the second prompt is just an empty string, but we use non-empty
    negative prompt.
    """

    def __init__(self, model):
        super().__init__()
        # `model` is the callable that predicts noise (here, apply_model).
        self.inner_model = model
        self.model_wrap = None
        # Inpainting mask and its complement; None when not inpainting.
        self.mask = None
        self.nmask = None
        self.init_latent = None
        self.steps = None
        """number of steps as specified by user in UI"""

        self.total_steps = None
        """expected number of calls to denoiser calculated from self.steps and specifics of the selected sampler"""

        # Running count of forward() calls (used for alternating-uncond skip).
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.sampler = None
        self.model_wrap = None
        self.p = None
        # When True the mask is applied to x before denoising instead of
        # to the denoised result afterwards.
        self.mask_before_denoising = False

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        """Classic CFG combine: start from the uncond prediction and add the
        weighted (cond - uncond) delta for every cond of every image.

        `conds_list[i]` is a list of (cond_index, weight) pairs for image i;
        the last `uncond.shape[0]` rows of `x_out` are the uncond outputs.
        """
        denoised_uncond = x_out[-uncond.shape[0]:]
        denoised = torch.clone(denoised_uncond)

        for i, conds in enumerate(conds_list):
            for cond_index, weight in conds:
                denoised[i] += (x_out[cond_index] - denoised_uncond[i]) * (weight * cond_scale)

        return denoised

    def combine_denoised_for_edit_model(self, x_out, cond_scale):
        """InstructPix2Pix-style combine with separate text and image CFG
        scales; expects x_out stacked as [cond, img_cond, uncond]."""
        out_cond, out_img_cond, out_uncond = x_out.chunk(3)
        denoised = out_uncond + cond_scale * (out_cond - out_img_cond) + self.image_cfg_scale * (out_img_cond - out_uncond)

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        # No x0 reconstruction here; the raw model output is returned as-is.
        return x_out

    def update_inner_model(self):
        """Drop the cached wrap and refresh cond/uncond from self.p."""
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        """One CFG denoising step.

        `cond` is a (conds_list, tensor) pair; `uncond` a tensor (or dict
        for SDXL-style conds — TODO confirm, the dict path only affects
        make_condition_dict below).
        """
        model_management.throw_exception_if_processing_interrupted()

        # Edit-model (InstructPix2Pix) support is disabled in this port.
        is_edit_model = False

        conds_list, tensor = cond
        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        if self.mask_before_denoising and self.mask is not None:
            x = self.init_latent * self.mask + self.nmask * x

        batch_size = len(conds_list)
        # Number of conds per image (for prompt AND / composable prompts).
        repeats = [len(conds_list[i]) for i in range(batch_size)]

        if False:
            # Dead branch kept from upstream (A1111) for reference.
            image_uncond = torch.zeros_like(image_cond)
            make_condition_dict = lambda c_crossattn, c_adm: {"c_crossattn": c_crossattn, "c_adm": c_adm, 'transformer_options': {'from_smZ': True}} # pylint: disable=C3001
        else:
            image_uncond = image_cond
            if isinstance(uncond, dict):
                make_condition_dict = lambda c_crossattn, c_concat: {**c_crossattn, "c_concat": None, "c_adm": x.c_adm, 'transformer_options': {'from_smZ': True}}
            else:
                make_condition_dict = lambda c_crossattn, c_concat: {"c_crossattn": c_crossattn, "c_concat": None, "c_adm": x.c_adm, 'transformer_options': {'from_smZ': True}}

        # Build one big batch: each image repeated once per cond, followed
        # by the uncond batch at the end.
        if not is_edit_model:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond])
        else:
            x_in = torch.cat([torch.stack([x[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [x] + [x])
            sigma_in = torch.cat([torch.stack([sigma[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [sigma] + [sigma])
            image_cond_in = torch.cat([torch.stack([image_cond[i] for _ in range(n)]) for i, n in enumerate(repeats)] + [image_uncond] + [torch.zeros_like(self.init_latent)])

        skip_uncond = False

        # alternating uncond allows for higher thresholds without the quality loss normally expected from raising it
        if self.step % 2 and s_min_uncond > 0 and sigma[0] < s_min_uncond and not is_edit_model:
            skip_uncond = True
            # Drop the trailing uncond batch from the model input.
            x_in = x_in[:-batch_size]
            sigma_in = sigma_in[:-batch_size]

        if tensor.shape[1] == uncond.shape[1] or skip_uncond:
            # Cond and uncond have the same token length, so they can be
            # concatenated and run through the model together.
            if is_edit_model:
                cond_in = catenate_conds([tensor, uncond, uncond])
            elif skip_uncond:
                cond_in = tensor
            else:
                cond_in = catenate_conds([tensor, uncond])

            x_out = torch.zeros_like(x_in)
            for batch_offset in range(0, x_out.shape[0], batch_size):
                a = batch_offset
                b = a + batch_size
                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(subscript_cond(cond_in, a, b), image_cond_in[a:b]))
        else:
            # Token lengths differ: run conds in chunks, then the uncond
            # batch separately at the end.
            x_out = torch.zeros_like(x_in)
            for batch_offset in range(0, tensor.shape[0], batch_size):
                a = batch_offset
                b = min(a + batch_size, tensor.shape[0])

                if not is_edit_model:
                    c_crossattn = subscript_cond(tensor, a, b)
                else:
                    c_crossattn = torch.cat([tensor[a:b]], uncond)

                x_out[a:b] = self.inner_model(x_in[a:b], sigma_in[a:b], **make_condition_dict(c_crossattn, image_cond_in[a:b]))

            if not skip_uncond:
                x_out[-uncond.shape[0]:] = self.inner_model(x_in[-uncond.shape[0]:], sigma_in[-uncond.shape[0]:], **make_condition_dict(uncond, image_cond_in[-uncond.shape[0]:]))

        # First cond index of each image = the row whose output stands in
        # for that image when uncond was skipped.
        denoised_image_indexes = [x[0][0] for x in conds_list]
        if skip_uncond:
            fake_uncond = torch.cat([x_out[i:i+1] for i in denoised_image_indexes])
            x_out = torch.cat([x_out, fake_uncond])  # we skipped uncond denoising, so we put cond-denoised image to where the uncond-denoised image should be

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
            # cond_scale 1.0: the "uncond" rows are cond outputs already.
            denoised = self.combine_denoised(x_out, conds_list, uncond, 1.0)
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        if not self.mask_before_denoising and self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised

        self.step += 1
        del x_out
        return denoised
|
||||
|
||||
# ========================================================================
|
||||
|
||||
def expand(tensor1, tensor2):
    """Make two (B, T, C) cond tensors the same length along dim 1.

    The shorter tensor is tiled along dim 1 (ceiling of the length ratio)
    and truncated so both tensors end up with the longer one's dim-1 size.
    Returns the (possibly adjusted) pair.
    """
    def _tile_to_match(short, long):
        # -(-a // b) is ceiling division without math.ceil.
        factor = -(-long.size(1) // short.size(1))
        tiled = short.repeat(1, factor, 1)
        return tiled[:, :long.size(1), :]

    len1, len2 = tensor1.size(1), tensor2.size(1)
    if len1 < len2:
        tensor1 = _tile_to_match(tensor1, tensor2)
    elif len2 < len1:
        tensor2 = _tile_to_match(tensor2, tensor1)
    return (tensor1, tensor2)
|
||||
|
||||
def _find_outer_instance(target, target_type):
|
||||
import inspect
|
||||
frame = inspect.currentframe()
|
||||
while frame:
|
||||
if target in frame.f_locals:
|
||||
found = frame.f_locals[target]
|
||||
if isinstance(found, target_type) and found != 1: # steps == 1
|
||||
return found
|
||||
frame = frame.f_back
|
||||
return None
|
||||
|
||||
# ========================================================================
|
||||
def bounded_modulo(number, modulo_value):
    """Clamp ``number`` to at most ``modulo_value`` (despite the name this
    is a ceiling clamp, not a modulo)."""
    return modulo_value if number >= modulo_value else number
|
||||
|
||||
def calc_cond(c, current_step):
    """Group by smZ conds that may do prompt-editing / regular conds / comfy conds.

    Each entry of ``c`` is a ``[payload, options_dict]`` pair. Consecutive
    entries sharing an ``smZid`` form a prompt-editing schedule, from which
    only the entry for ``current_step`` (clamped to the last one) is kept;
    all other entries pass through unchanged, preserving order.
    """
    selected = []
    # First split into runs of smZ-produced vs. ordinary conds.
    is_smz = lambda entry: entry[1].get("from_smZ", None) is not None
    for _, run in itertools.groupby(c, is_smz):
        run = list(run)
        # Within a run, group prompt-editing schedules by their smZ id.
        for smz_id, sched in itertools.groupby(run, lambda entry: entry[1].get("smZid", None)):
            sched = list(sched)
            if smz_id is not None:
                # Prompt editing: pick the schedule entry for this step,
                # clamped to the final entry of the schedule.
                orig_len = sched[0][1].get('orig_len', 1)
                idx = bounded_modulo(current_step, orig_len - 1)
                selected.append(sched[idx])
            else:
                selected.extend(sched)
    return selected
|
||||
|
||||
class CFGNoisePredictor(torch.nn.Module):
    """Noise predictor that routes comfy's cond/uncond application through
    the smZ ``CFGDenoiser`` so A1111-style features (prompt editing,
    weighted AND conds, SDXL length padding) apply."""

    def __init__(self, model):
        super().__init__()
        # There is no direct reference to the enclosing KSampler at
        # construction time, so fish it off the call stack.
        self.ksampler = _find_outer_instance('self', comfy.samplers.KSampler)
        self.step = 0
        self.orig = comfy.samplers.CFGNoisePredictor(model) #CFGNoisePredictorOrig(model)
        self.inner_model = model
        self.inner_model2 = CFGDenoiser(model.apply_model)
        self.inner_model2.num_timesteps = model.num_timesteps
        self.inner_model2.device = self.ksampler.device if hasattr(self.ksampler, "device") else None
        self.s_min_uncond = 0.0
        self.alphas_cumprod = model.alphas_cumprod
        self.c_adm = None
        self.init_cond = None
        self.init_uncond = None
        self.is_prompt_editing_u = False
        self.is_prompt_editing_c = False

    def apply_model(self, x, timestep, cond, uncond, cond_scale, cond_concat=None, model_options=None, seed=None):
        """Select the conds for the current step and run the CFG denoiser.

        Returns the denoised latent from ``CFGDenoiser.forward``.
        """
        # BUGFIX: the original signature used a mutable default
        # (`model_options={}`) and then mutated it below, leaking
        # 'transformer_options' state across unrelated calls that omitted
        # the argument. Create a fresh dict per call instead.
        if model_options is None:
            model_options = {}

        # Resolve prompt-editing schedules for this step.
        cc=calc_cond(cond, self.step)
        uu=calc_cond(uncond, self.step)
        self.step += 1

        if (any([p[1].get('from_smZ', False) for p in cc]) or
            any([p[1].get('from_smZ', False) for p in uu])):
            if model_options.get('transformer_options',None) is None:
                model_options['transformer_options'] = {}
            model_options['transformer_options']['from_smZ'] = True

        # Only supports one cond
        for ix in range(len(cc)):
            if cc[ix][1].get('from_smZ', False):
                cc = [cc[ix]]
                break
        for ix in range(len(uu)):
            if uu[ix][1].get('from_smZ', False):
                uu = [uu[ix]]
                break
        c=cc[0][1]
        u=uu[0][1]
        _cc = cc[0][0]
        _uu = uu[0][0]
        if c.get("adm_encoded", None) is not None:
            self.c_adm = torch.cat([c['adm_encoded'], u['adm_encoded']])
            # SDXL. Need to pad with repeats
            _cc, _uu = expand(_cc, _uu)
            _uu, _cc = expand(_uu, _cc)
        # Piggy-back the pooled/adm conditioning on the latent tensor so
        # CFGDenoiser's make_condition_dict can read it as x.c_adm.
        x.c_adm = self.c_adm
        conds_list = c.get('conds_list', [[(0, 1.0)]])
        image_cond = txt2img_image_conditioning(None, x)
        out = self.inner_model2(x, timestep, cond=(conds_list, _cc), uncond=_uu, cond_scale=cond_scale, s_min_uncond=self.s_min_uncond, image_cond=image_cond)
        return out
|
||||
|
||||
def txt2img_image_conditioning(sd_model, x, width=None, height=None):
    """Return a dummy (B, 5, 1, 1) zero image-conditioning tensor for plain
    txt2img, matching ``x``'s dtype and device.

    ``sd_model``, ``width`` and ``height`` are unused; they are kept for
    signature compatibility with the img2img variant.
    """
    batch = x.shape[0]
    return x.new_zeros(batch, 5, 1, 1, dtype=x.dtype, device=x.device)
|
||||
|
||||
# =======================================================================================
|
||||
|
||||
def set_model_k(self: KSampler):
    """Swap a KSampler's denoise chain to use the smZ CFGNoisePredictor.

    Rebuilds model_denoise -> model_wrap -> model_k, choosing the
    v-prediction or eps-prediction CompVis wrapper to match the model.
    """
    self.model_denoise = CFGNoisePredictor(self.model) # main change
    is_v_prediction = (
        (getattr(self.model, "parameterization", "") == "v")
        or (getattr(self.model, "model_type", -1) == model_base.ModelType.V_PREDICTION)
    )
    if is_v_prediction:
        self.model_wrap = CompVisVDenoiser(self.model_denoise, quantize=True)
        self.model_wrap.parameterization = getattr(self.model, "parameterization", "v")
    else:
        self.model_wrap = CompVisDenoiser(self.model_denoise, quantize=True)
        self.model_wrap.parameterization = getattr(self.model, "parameterization", "eps")
    self.model_k = KSamplerX0Inpaint(self.model_wrap)
|
||||
|
||||
class SDKSampler(comfy.samplers.KSampler):
    """KSampler whose noise predictor is replaced by the smZ CFG chain."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Install CFGNoisePredictor / CompVis wrappers on this instance.
        set_model_k(self)
|
||||
140
py/smZ_rng_source.py
Normal file
140
py/smZ_rng_source.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
import numpy as np
|
||||
|
||||
# Multiplier and key-increment ("Weyl sequence") constants of the Philox
# 4x32 counter-based RNG, as in the Random123 reference implementation.
philox_m = [0xD2511F53, 0xCD9E8D57]
philox_w = [0x9E3779B9, 0xBB67AE85]

# 1/2^32 and 2*pi/2^32 as float32 — scale factors that map uint32 samples
# onto [0, 1) and [0, 2*pi) for the Box-Muller transform below.
two_pow32_inv = np.array([2.3283064e-10], dtype=np.float32)
two_pow32_inv_2pi = np.array([2.3283064e-10 * 6.2831855], dtype=np.float32)
|
||||
|
||||
|
||||
def uint32(x):
    """Convert an (N,) np.uint64 array into a (2, N) np.uint32 array.

    Each 64-bit value is reinterpreted as two 32-bit halves; on
    little-endian hosts row 0 holds the low words and row 1 the high words.
    """
    halves = x.view(np.uint32)
    return halves.reshape(-1, 2).T
|
||||
|
||||
|
||||
def philox4_round(counter, key):
    """A single round of the Philox 4x32 random number generator.

    Mutates ``counter`` (shape (4, N), uint32) in place; returns None.
    """
    # 32x32 -> 64-bit multiplies of lanes 0 and 2; uint32() splits each
    # product into its two 32-bit halves.
    v1 = uint32(counter[0].astype(np.uint64) * philox_m[0])
    v2 = uint32(counter[2].astype(np.uint64) * philox_m[1])

    # Mix one half of each product with the other lanes and the key,
    # and pass the other half through (standard Philox lane shuffle).
    counter[0] = v2[1] ^ counter[1] ^ key[0]
    counter[1] = v2[0]
    counter[2] = v1[1] ^ counter[3] ^ key[1]
    counter[3] = v1[0]
|
||||
|
||||
|
||||
def philox4_32(counter, key, rounds=10):
    """Generates 32-bit random numbers using the Philox 4x32 random number generator.

    Parameters:
        counter (numpy.ndarray): A 4xN array of 32-bit integers representing the counter values (offset into generation).
        key (numpy.ndarray): A 2xN array of 32-bit integers representing the key values (seed).
        rounds (int): The number of rounds to perform.

    Returns:
        numpy.ndarray: A 4xN array of 32-bit integers containing the generated random numbers.

    Note: ``counter`` and ``key`` are mutated in place; the returned array
    is the same ``counter`` object.
    """

    for _ in range(rounds - 1):
        philox4_round(counter, key)

        # Bump the key between rounds (Philox "Weyl sequence" increment).
        key[0] = key[0] + philox_w[0]
        key[1] = key[1] + philox_w[1]

    # Final round without a trailing key bump.
    philox4_round(counter, key)
    return counter
|
||||
|
||||
|
||||
def box_muller(x, y):
    """Returns just the first out of two numbers generated by Box–Muller transform algorithm."""
    # Map uint32 samples to u in (0, 1] and v in (0, 2*pi]; the half-step
    # offset keeps u strictly positive so log(u) is finite.
    u = x * two_pow32_inv + two_pow32_inv / 2
    v = y * two_pow32_inv_2pi + two_pow32_inv_2pi / 2

    s = np.sqrt(-2.0 * np.log(u))

    # Only the sin branch is used; the cos partner sample is discarded.
    r1 = s * np.sin(v)
    return r1.astype(np.float32)
|
||||
|
||||
|
||||
class Generator:
    """RNG that produces same outputs as torch.randn(..., device='cuda') on CPU"""

    def __init__(self, seed):
        # 64-bit seed, used as the Philox key for every draw.
        self.seed = seed
        # Draw counter: incremented once per randn() call so successive
        # calls yield different streams from the same seed.
        self.offset = 0

    def randn(self, shape):
        """Generate a sequence of n standard normal random variables using the Philox 4x32 random number generator and the Box-Muller transform."""

        # Total number of samples needed for the requested shape.
        n = 1
        for x in shape:
            n *= x

        # Lane 0 carries the per-call offset, lane 2 the element index.
        counter = np.zeros((4, n), dtype=np.uint32)
        counter[0] = self.offset
        counter[2] = np.arange(n, dtype=np.uint32)  # up to 2^32 numbers can be generated - if you want more you'd need to spill into counter[3]
        self.offset += 1

        # Same seed for every element, split into two 32-bit key words.
        key = np.empty(n, dtype=np.uint64)
        key.fill(self.seed)
        key = uint32(key)

        g = philox4_32(counter, key)

        # One normal sample per element via Box-Muller on lanes 0 and 1.
        return box_muller(g[0], g[1]).reshape(shape)  # discard g[2] and g[3]
|
||||
|
||||
#=======================================================================================================================
|
||||
# Monkey Patch "prepare_noise" function
|
||||
# https://github.com/shiimizu/ComfyUI_smZNodes
|
||||
import torch
|
||||
import functools
|
||||
from comfy.sample import np
|
||||
import comfy.model_management
|
||||
|
||||
def rng_rand_source(rand_source='cpu'):
    """Monkey-patch comfy.sample.prepare_noise to draw noise from the
    requested source: 'cpu' restores the stock implementation, 'nv' uses
    the CUDA-matching Philox Generator above, anything else uses
    torch.randn on `device`.

    NOTE(review): `device` comes from text_encoder_device() — confirm this
    is the intended device for noise generation.
    """
    device = comfy.model_management.text_encoder_device()

    def prepare_noise(latent_image, seed, noise_inds=None, device='cpu'):
        """
        creates random noise given a latent image and a seed.
        optional arg skip can be used to skip and discard x number of noise generations for a given seed
        """
        generator = torch.Generator(device).manual_seed(seed)
        if rand_source == 'nv':
            rng = Generator(seed)
        if noise_inds is None:
            # Simple case: one noise tensor covering the whole batch.
            shape = latent_image.size()
            if rand_source == 'nv':
                return torch.asarray(rng.randn(shape), device=device)
            else:
                return torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, generator=generator,
                                   device=device)

        # Batched case: generate sequentially up to the largest requested
        # index so identical noise_inds reproduce identical noise, keeping
        # only the tensors actually asked for.
        unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
        noises = []
        for i in range(unique_inds[-1] + 1):
            shape = [1] + list(latent_image.size())[1:]
            if rand_source == 'nv':
                noise = torch.asarray(rng.randn(shape), device=device)
            else:
                noise = torch.randn(shape, dtype=latent_image.dtype, layout=latent_image.layout, generator=generator,
                                    device=device)
            if i in unique_inds:
                noises.append(noise)
        noises = [noises[i] for i in inverse]
        noises = torch.cat(noises, axis=0)
        return noises

    if rand_source == 'cpu':
        # Restore comfy's original prepare_noise if we previously patched it.
        if hasattr(comfy.sample, 'prepare_noise_orig'):
            comfy.sample.prepare_noise = comfy.sample.prepare_noise_orig
    else:
        # Stash the original once, then install our replacement with the
        # chosen device bound in.
        if not hasattr(comfy.sample, 'prepare_noise_orig'):
            comfy.sample.prepare_noise_orig = comfy.sample.prepare_noise
        _prepare_noise = functools.partial(prepare_noise, device=device)
        comfy.sample.prepare_noise = _prepare_noise
|
||||
|
||||
|
||||
|
||||
313
py/ttl_nn_latent_upscaler.py
Normal file
313
py/ttl_nn_latent_upscaler.py
Normal file
@@ -0,0 +1,313 @@
|
||||
import torch
|
||||
#from .latent_resizer import LatentResizer
|
||||
from comfy import model_management
|
||||
import os
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
|
||||
def normalization(channels):
    """Standard 32-group GroupNorm used throughout the resizer network."""
    return nn.GroupNorm(num_groups=32, num_channels=channels)
|
||||
|
||||
|
||||
def zero_module(module):
    """Zero every parameter of ``module`` in place and return it.

    Used to make the final conv of a residual branch start as an identity
    contribution; requires_grad flags are left untouched.
    """
    with torch.no_grad():
        for param in module.parameters():
            param.zero_()
    return module
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a 2D
    feature map, with a residual connection (VAE/LDM-style attn block)."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalization(in_channels)
        # 1x1 convs act as per-pixel linear projections for Q, K, V.
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        """Pre-norm single-head attention over the h*w spatial tokens."""
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        b, c, h, w = q.shape
        # Flatten spatial dims into a token sequence: (b, 1 head, h*w, c).
        q, k, v = map(
            lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
        )
        h_ = nn.functional.scaled_dot_product_attention(
            q, k, v
        )  # scale is dim ** -0.5 per default

        # Restore the (b, c, h, w) layout.
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x, **kwargs):
        # Residual: x + proj(attention(x)); extra kwargs are ignored.
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        return x + h_
|
||||
|
||||
|
||||
def make_attn(in_channels, attn_kwargs=None):
    """Build an attention block for ``in_channels`` channels.

    ``attn_kwargs`` is accepted for interface compatibility but unused.
    """
    return AttnBlock(in_channels)
|
||||
|
||||
|
||||
class ResBlockEmb(nn.Module):
    """Residual conv block conditioned on a timestep-style embedding
    (UNet ResBlock layout; layer names/order match pretrained weights).

    Args:
        channels: input channel count.
        emb_channels: dimensionality of the conditioning embedding.
        dropout: dropout probability in the output stack.
        out_channels: output channels (defaults to ``channels``).
        use_conv: use a full conv (vs 1x1) for the skip when channels change.
        use_scale_shift_norm: apply the embedding as FiLM-style scale/shift
            around the output norm instead of a plain additive bias.
        kernel_size: conv kernel size (padding keeps spatial size).
        exchange_temb_dims: transpose dims 1 and 2 of the embedding before
            adding (sequence-style embeddings).
        skip_t_emb: ignore the embedding entirely (emb_layers is None).
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout=0,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        # "Same" padding for the chosen kernel size.
        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.skip_t_emb = skip_t_emb
        # Scale/shift mode needs two vectors (scale and shift) per channel.
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        if self.skip_t_emb:
            print(f"Skipping timestep embedding in {self.__class__.__name__}")
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Final conv starts zeroed so the block is initially an identity
            # (residual contribution is zero at init).
            zero_module(
                nn.Conv2d(
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )
            ),
        )

        # Skip path: identity when channels match, otherwise a projection.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

    def forward(self, x, emb):
        """Apply the block: skip(x) + out_layers(in_layers(x) ⊕ emb)."""
        h = self.in_layers(x)

        if self.skip_t_emb:
            # No conditioning: a zero embedding leaves h unchanged below.
            emb_out = torch.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the (B, C[, T]) embedding over the spatial dims of h.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style: norm, then scale/shift from the embedding, then
            # the rest of the output stack.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
|
||||
|
||||
|
||||
class LatentResizer(nn.Module):
    """Small U-Net-like network that rescales a 4-channel latent by a continuous factor.

    The target scale is embedded through a small MLP and injected into every
    ResBlockEmb; the spatial resize itself is a bilinear interpolation applied
    between the `in_blocks` and `out_blocks` stacks.
    """

    def __init__(self, in_blocks=10, out_blocks=10, channels=128, dropout=0, attn=True):
        super().__init__()
        self.conv_in = nn.Conv2d(4, channels, 3, padding=1)
        self.channels = channels

        # Embeds the scalar (scale - 1) into a conditioning vector.
        embed_dim = 32
        self.embed = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

        # Attention is inserted after the first ResBlock and before the last one.
        self.in_blocks = nn.ModuleList([])
        for b in range(in_blocks):
            if attn and b in (1, in_blocks - 1):
                self.in_blocks.append(make_attn(channels))
            self.in_blocks.append(ResBlockEmb(channels, embed_dim, dropout))

        self.out_blocks = nn.ModuleList([])
        for b in range(out_blocks):
            if attn and b in (1, out_blocks - 1):
                self.out_blocks.append(make_attn(channels))
            self.out_blocks.append(ResBlockEmb(channels, embed_dim, dropout))

        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 4, 3, padding=1)

    @classmethod
    def load_model(cls, filename, device="cpu", dtype=torch.float32, dropout=0):
        """Instantiate a resizer whose architecture is inferred from a checkpoint.

        Block counts, attention usage and channel width are recovered by
        inspecting the state-dict keys, so checkpoints of different sizes load
        transparently.
        """
        # Older torch versions don't accept weights_only; feature-detect it.
        if 'weights_only' not in torch.load.__code__.co_varnames:
            weights = torch.load(filename, map_location=torch.device("cpu"))
        else:
            weights = torch.load(filename, map_location=torch.device("cpu"), weights_only=True)

        in_blocks = 0
        out_blocks = 0
        in_tfs = 0
        out_tfs = 0
        channels = weights["conv_in.bias"].shape[0]
        for key in weights.keys():
            parts = key.split(".")
            if parts[0] == "in_blocks":
                in_blocks = max(in_blocks, int(parts[1]))
                # Attention blocks are identified by their "q" projection weight.
                if parts[2] == "q" and parts[3] == "weight":
                    in_tfs += 1
            if parts[0] == "out_blocks":
                out_blocks = max(out_blocks, int(parts[1]))
                if parts[2] == "q" and parts[3] == "weight":
                    out_tfs += 1
        # ModuleList indices count attention entries too; subtract them to get
        # the number of ResBlockEmb modules per stack.
        in_blocks = in_blocks + 1 - in_tfs
        out_blocks = out_blocks + 1 - out_tfs

        resizer = cls(
            in_blocks=in_blocks,
            out_blocks=out_blocks,
            channels=channels,
            dropout=dropout,
            attn=(out_tfs != 0),
        )
        resizer.load_state_dict(weights)
        resizer.eval()
        resizer.to(device, dtype=dtype)
        return resizer

    def forward(self, x, scale=None, size=None):
        """Resize latent `x` either by `scale` or to an explicit `size` (h, w).

        Exactly one of `scale`/`size` must be given; raises ValueError
        otherwise. Returns `x` unchanged when the target size equals the input
        size.
        """
        if scale is None and size is None:
            raise ValueError("Either scale or size needs to be not None")
        if scale is not None and size is not None:
            raise ValueError("Both scale or size can't be not None")

        if scale is not None:
            size = (x.shape[-2] * scale, x.shape[-1] * scale)
            size = tuple(int(round(i)) for i in size)
        else:
            # Normalize to a tuple of ints so the same-size early-return below
            # also triggers when callers pass `size` as a list.
            size = tuple(int(i) for i in size)
            scale = size[-1] / x.shape[-1]

        # Output is the same size as input: nothing to do.
        if size == tuple(x.shape[-2:]):
            return x

        # Condition on (scale - 1) so the identity transform sits at zero.
        scale_t = torch.tensor([scale - 1], dtype=x.dtype).to(x.device).unsqueeze(0)
        emb = self.embed(scale_t)

        x = self.conv_in(x)
        for block in self.in_blocks:
            x = block(x, emb) if isinstance(block, ResBlockEmb) else block(x)
        x = F.interpolate(x, size=size, mode="bilinear")
        for block in self.out_blocks:
            x = block(x, emb) if isinstance(block, ResBlockEmb) else block(x)

        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
|
||||
|
||||
########################################################
|
||||
class NNLatentUpscale:
    """
    Upscales SDXL latent using neural network
    """

    def __init__(self):
        self.local_dir = os.path.dirname(os.path.realpath(__file__))
        self.scale_factor = 0.13025
        # Prefer half precision when the backend supports it.
        self.dtype = torch.float16 if model_management.should_use_fp16() else torch.float32
        self.weight_path = {
            "SDXL": os.path.join(self.local_dir, "sdxl_resizer.pt"),
            "SD 1.x": os.path.join(self.local_dir, "sd15_resizer.pt"),
        }
        # Sentinel meaning "no resizer model loaded yet".
        self.version = "none"

    @classmethod
    def INPUT_TYPES(s):
        upscale_opts = {
            "default": 1.5,
            "min": 1.0,
            "max": 2.0,
            "step": 0.01,
            "display": "number",
        }
        return {
            "required": {
                "latent": ("LATENT",),
                "version": (["SDXL", "SD 1.x"],),
                "upscale": ("FLOAT", upscale_opts),
            },
        }

    RETURN_TYPES = ("LATENT",)

    FUNCTION = "upscale"

    CATEGORY = "latent"

    def upscale(self, latent, version, upscale):
        """Run the latent through the resizer network at the requested scale."""
        device = model_management.get_torch_device()
        samples = latent["samples"].to(device=device, dtype=self.dtype)

        # (Re)load the resizer weights whenever the requested version changes.
        if version != self.version:
            self.model = LatentResizer.load_model(self.weight_path[version], device, self.dtype)
            self.version = version

        self.model.to(device=device)
        # Work in the VAE-scaled latent space, then undo the scaling.
        latent_out = self.model(self.scale_factor * samples, scale=upscale) / self.scale_factor

        if self.dtype != torch.float32:
            latent_out = latent_out.to(dtype=torch.float32)
        latent_out = latent_out.to(device="cpu")

        # Free GPU memory by parking the resizer on the offload device.
        self.model.to(device=model_management.vae_offload_device())
        return ({"samples": latent_out},)
|
||||
Reference in New Issue
Block a user