Mirror of https://github.com/jags111/efficiency-nodes-comfyui.git (synced 2026-03-21 21:22:13 -03:00)
Efficiency Nodes V2.0
py/bnk_adv_encode.py (new file, +317 lines)
@@ -0,0 +1,317 @@
import torch
import numpy as np
import itertools
#from math import gcd

from comfy import model_management
from comfy.sdxl_clip import SDXLClipModel

def _grouper(n, iterable):
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk

def _norm_mag(w, n):
    d = w - 1
    return 1 + np.sign(d) * np.sqrt(np.abs(d)**2 / n)
    #return np.sign(w) * np.sqrt(np.abs(w)**2 / n)

def divide_length(word_ids, weights):
    sums = dict(zip(*np.unique(word_ids, return_counts=True)))
    sums[0] = 1
    weights = [[_norm_mag(w, sums[id]) if id != 0 else 1.0
                for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
    return weights

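# Illustrative note (not part of the original file): with "length" normalization,
# a word's weight delta is spread over its token count via _norm_mag, e.g. a
# 2-token word weighted 1.3 gives each token 1 + 0.3/sqrt(2) ~= 1.21, so a long
# word is not emphasized more strongly than a short one. Hypothetical layout:
#   divide_length([[0, 3, 3]], [[1.0, 1.3, 1.3]])  ->  [[1.0, ~1.21, ~1.21]]
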
def shift_mean_weight(word_ids, weights):
    delta = 1 - np.mean([w for x, y in zip(weights, word_ids) for w, id in zip(x,y) if id != 0])
    weights = [[w if id == 0 else w+delta
                for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
    return weights

def scale_to_norm(weights, word_ids, w_max):
    top = np.max(weights)
    w_max = min(top, w_max)
    weights = [[w_max if id == 0 else (w/top) * w_max
                for w, id in zip(x, y)] for x, y in zip(weights, word_ids)]
    return weights

def from_zero(weights, base_emb):
    weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape)
    return base_emb * weight_tensor

def mask_word_id(tokens, word_ids, target_id, mask_token):
    new_tokens = [[mask_token if wid == target_id else t
                   for t, wid in zip(x,y)] for x,y in zip(tokens, word_ids)]
    mask = np.array(word_ids) == target_id
    return (new_tokens, mask)

def batched_clip_encode(tokens, length, encode_func, num_chunks):
    embs = []
    for e in _grouper(32, tokens):
        enc, pooled = encode_func(e)
        enc = enc.reshape((len(e), length, -1))
        embs.append(enc)
    embs = torch.cat(embs)
    embs = embs.reshape((len(tokens) // num_chunks, length * num_chunks, -1))
    return embs

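# Illustrative note (not part of the original file): the final reshape regroups
# the encoded rows into whole prompts. For example, a prompt spanning 2 chunks
# (num_chunks=2, length=77) with 3 masked variants yields 6 encoded rows, which
# come back as a tensor of shape (3, 154, hidden_dim).
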
#re-encode the prompt with each word whose weight differs from 1.0 masked out;
#the difference to the base embedding isolates that word's contribution, which
#is then scaled by (w - 1)
def from_masked(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    pooled_base = base_emb[0,length-1:length,:]
    wids, inds = np.unique(np.array(word_ids).reshape(-1), return_index=True)
    weight_dict = dict((id,w)
                       for id,w in zip(wids ,np.array(weights).reshape(-1)[inds])
                       if w != 1.0)

    if len(weight_dict) == 0:
        return torch.zeros_like(base_emb), base_emb[0,length-1:length,:]

    weight_tensor = torch.tensor(weights, dtype=base_emb.dtype, device=base_emb.device)
    weight_tensor = weight_tensor.reshape(1,-1,1).expand(base_emb.shape)

    #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    #TODO: find most suitable masking token here
    m_token = (m_token, 1.0)

    ws = []
    masked_tokens = []
    masks = []

    #create prompts
    for id, w in weight_dict.items():
        masked, m = mask_word_id(tokens, word_ids, id, m_token)
        masked_tokens.extend(masked)

        m = torch.tensor(m, dtype=base_emb.dtype, device=base_emb.device)
        m = m.reshape(1,-1,1).expand(base_emb.shape)
        masks.append(m)

        ws.append(w)

    #batch process prompts
    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    masks = torch.cat(masks)

    embs = (base_emb.expand(embs.shape) - embs)
    pooled = embs[0,length-1:length,:]

    embs *= masks
    embs = embs.sum(axis=0, keepdim=True)

    pooled_start = pooled_base.expand(len(ws), -1)
    ws = torch.tensor(ws).reshape(-1,1).expand(pooled_start.shape)
    pooled = (pooled - pooled_start) * (ws - 1)
    pooled = pooled.mean(axis=0, keepdim=True)

    return ((weight_tensor - 1) * embs), pooled_base + pooled

def mask_inds(tokens, inds, mask_token):
    clip_len = len(tokens[0])
    inds_set = set(inds)
    new_tokens = [[mask_token if i*clip_len + j in inds_set else t
                   for j, t in enumerate(x)] for i, x in enumerate(tokens)]
    return new_tokens

#handle weights below 1.0 by progressively masking the affected tokens and
#blending the successively masked encodings with the base embedding
def down_weight(tokens, weights, word_ids, base_emb, length, encode_func, m_token=266):
    w, w_inv = np.unique(weights,return_inverse=True)

    if np.sum(w < 1) == 0:
        return base_emb, tokens, base_emb[0,length-1:length,:]
    #m_token = (clip.tokenizer.end_token, 1.0) if clip.tokenizer.pad_with_end else (0,1.0)
    #using the comma token as a masking token seems to work better than aos tokens for SD 1.x
    m_token = (m_token, 1.0)

    masked_tokens = []

    masked_current = tokens
    for i in range(len(w)):
        if w[i] >= 1:
            continue
        masked_current = mask_inds(masked_current, np.where(w_inv == i)[0], m_token)
        masked_tokens.extend(masked_current)

    embs = batched_clip_encode(masked_tokens, length, encode_func, len(tokens))
    embs = torch.cat([base_emb, embs])
    w = w[w<=1.0]
    w_mix = np.diff([0] + w.tolist())
    w_mix = torch.tensor(w_mix, dtype=embs.dtype, device=embs.device).reshape((-1,1,1))

    weighted_emb = (w_mix * embs).sum(axis=0, keepdim=True)
    return weighted_emb, masked_current, weighted_emb[0,length-1:length,:]

def scale_emb_to_mag(base_emb, weighted_emb):
    norm_base = torch.linalg.norm(base_emb)
    norm_weighted = torch.linalg.norm(weighted_emb)
    embeddings_final = (norm_base / norm_weighted) * weighted_emb
    return embeddings_final

def recover_dist(base_emb, weighted_emb):
    fixed_std = (base_emb.std() / weighted_emb.std()) * (weighted_emb - weighted_emb.mean())
    embeddings_final = fixed_std + (base_emb.mean() - fixed_std.mean())
    return embeddings_final

def A1111_renorm(base_emb, weighted_emb):
    embeddings_final = (base_emb.mean() / weighted_emb.mean()) * weighted_emb
    return embeddings_final

def advanced_encode_from_tokens(tokenized, token_normalization, weight_interpretation, encode_func, m_token=266, length=77, w_max=1.0, return_pooled=False, apply_to_pooled=False):
    tokens = [[t for t,_,_ in x] for x in tokenized]
    weights = [[w for _,w,_ in x] for x in tokenized]
    word_ids = [[wid for _,_,wid in x] for x in tokenized]

    #weight normalization
    #====================

    #distribute down/up weights over word lengths
    if token_normalization.startswith("length"):
        weights = divide_length(word_ids, weights)

    #make mean of word tokens 1
    if token_normalization.endswith("mean"):
        weights = shift_mean_weight(word_ids, weights)

    #weight interpretation
    #=====================
    pooled = None

    if weight_interpretation == "comfy":
        weighted_tokens = [[(t,w) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
        weighted_emb, pooled_base = encode_func(weighted_tokens)
        pooled = pooled_base
    else:
        unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokenized]
        base_emb, pooled_base = encode_func(unweighted_tokens)

    if weight_interpretation == "A1111":
        weighted_emb = from_zero(weights, base_emb)
        weighted_emb = A1111_renorm(base_emb, weighted_emb)
        pooled = pooled_base

    if weight_interpretation == "compel":
        pos_tokens = [[(t,w) if w >= 1.0 else (t,1.0) for t, w in zip(x, y)] for x, y in zip(tokens, weights)]
        weighted_emb, _ = encode_func(pos_tokens)
        weighted_emb, _, pooled = down_weight(pos_tokens, weights, word_ids, weighted_emb, length, encode_func)

    if weight_interpretation == "comfy++":
        weighted_emb, tokens_down, _ = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
        weights = [[w if w > 1.0 else 1.0 for w in x] for x in weights]
        #unweighted_tokens = [[(t,1.0) for t, _,_ in x] for x in tokens_down]
        embs, pooled = from_masked(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)
        weighted_emb += embs

    if weight_interpretation == "down_weight":
        weights = scale_to_norm(weights, word_ids, w_max)
        weighted_emb, _, pooled = down_weight(unweighted_tokens, weights, word_ids, base_emb, length, encode_func)

    if return_pooled:
        if apply_to_pooled:
            return weighted_emb, pooled
        else:
            return weighted_emb, pooled_base
    return weighted_emb, None

def encode_token_weights_g(model, token_weight_pairs):
    return model.clip_g.encode_token_weights(token_weight_pairs)

def encode_token_weights_l(model, token_weight_pairs):
    l_out, _ = model.clip_l.encode_token_weights(token_weight_pairs)
    return l_out, None

def encode_token_weights(model, token_weight_pairs, encode_func):
    if model.layer_idx is not None:
        model.cond_stage_model.clip_layer(model.layer_idx)

    model_management.load_model_gpu(model.patcher)
    return encode_func(model.cond_stage_model, token_weight_pairs)

def prepareXL(embs_l, embs_g, pooled, clip_balance):
    l_w = 1 - max(0, clip_balance - .5) * 2
    g_w = 1 - max(0, .5 - clip_balance) * 2
    if embs_l is not None:
        return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled
    else:
        return embs_g, pooled

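# Illustrative note (not part of the original file): clip_balance maps to the
# per-encoder scales l_w = 1 - max(0, clip_balance - 0.5) * 2 and
# g_w = 1 - max(0, 0.5 - clip_balance) * 2, so 0.0 zeroes the CLIP-G half of the
# concatenation, 1.0 zeroes the CLIP-L half, and the default 0.5 keeps both at
# full strength.
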
#tokenize and encode; SDXL-style CLIP objects return a dict of 'l'/'g' token lists,
#single-encoder models return a plain token list
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
    tokenized = clip.tokenize(text, return_word_ids=True)
    if isinstance(tokenized, dict):
        embs_l = None
        embs_g = None
        pooled = None
        if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
            embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
                                                    token_normalization,
                                                    weight_interpretation,
                                                    lambda x: encode_token_weights(clip, x, encode_token_weights_l),
                                                    w_max=w_max,
                                                    return_pooled=False)
        if 'g' in tokenized:
            embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
                                                         token_normalization,
                                                         weight_interpretation,
                                                         lambda x: encode_token_weights(clip, x, encode_token_weights_g),
                                                         w_max=w_max,
                                                         return_pooled=True,
                                                         apply_to_pooled=apply_to_pooled)
        return prepareXL(embs_l, embs_g, pooled, clip_balance)
    else:
        return advanced_encode_from_tokens(tokenized,
                                           token_normalization,
                                           weight_interpretation,
                                           lambda x: (clip.encode_from_tokens(x), None),
                                           w_max=w_max)

########################################################################################################################
from nodes import MAX_RESOLUTION

class AdvancedCLIPTextEncode:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "text": ("STRING", {"multiline": True}),
            "clip": ("CLIP",),
            "token_normalization": (["none", "mean", "length", "length+mean"],),
            "weight_interpretation": (["comfy", "A1111", "compel", "comfy++", "down_weight"],),
            # "affect_pooled": (["disable", "enable"],),
            }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning/advanced"

    def encode(self, clip, text, token_normalization, weight_interpretation, affect_pooled='disable'):
        embeddings_final, pooled = advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0,
                                                   apply_to_pooled=affect_pooled == 'enable')
        return ([[embeddings_final, {"pooled_output": pooled}]],)


class AddCLIPSDXLRParams:
    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "conditioning": ("CONDITIONING",),
            "width": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
            "height": ("INT", {"default": 1024.0, "min": 0, "max": MAX_RESOLUTION}),
            "ascore": ("FLOAT", {"default": 6.0, "min": 0.0, "max": 1000.0, "step": 0.01}),
            }}

    RETURN_TYPES = ("CONDITIONING",)
    FUNCTION = "encode"

    CATEGORY = "conditioning/advanced"

    def encode(self, conditioning, width, height, ascore):
        c = []
        for t in conditioning:
            n = [t[0], t[1].copy()]
            n[1]['width'] = width
            n[1]['height'] = height
            n[1]['aesthetic_score'] = ascore
            c.append(n)
        return (c,)
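
# Illustrative usage sketch (not part of the original file), assuming `clip` is a
# loaded ComfyUI CLIP object obtained elsewhere (e.g. from a checkpoint loader),
# mirroring what AdvancedCLIPTextEncode.encode does:
#   cond, pooled = advanced_encode(clip, "a (red:1.3) car", "length+mean", "A1111",
#                                  w_max=1.0, clip_balance=0.5, apply_to_pooled=True)
#   conditioning = [[cond, {"pooled_output": pooled}]]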