Update bnk_adv_encode.py

This commit is contained in:
VALADI K JAGANATHAN
2023-11-01 12:14:43 +05:30
committed by GitHub
parent d431b9caca
commit 25a7bb2539

View File

@@ -1,10 +1,10 @@
import torch
import numpy as np
import itertools
from math import gcd
#from math import gcd
from comfy import model_management
from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG
from comfy.sdxl_clip import SDXLClipModel
def _grouper(n, iterable):
it = iter(iterable)
@@ -235,11 +235,10 @@ def prepareXL(embs_l, embs_g, pooled, clip_balance):
return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled
else:
return embs_g, pooled
#=====================
def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
tokenized = clip.tokenize(text, return_word_ids=True)
if isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
if isinstance(tokenized, dict):
embs_l = None
embs_g = None
pooled = None
@@ -273,32 +272,6 @@ def advanced_encode(clip, text, token_normalization, weight_interpretation, w_ma
lambda x: (clip.encode_from_tokens(x), None),
w_max=w_max)
#=====================
def advanced_encode_XL(clip, text1, text2, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
tokenized1 = clip.tokenize(text1, return_word_ids=True)
tokenized2 = clip.tokenize(text2, return_word_ids=True)
embs_l, _ = advanced_encode_from_tokens(tokenized1['l'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_l),
w_max=w_max,
return_pooled=False)
embs_g, pooled = advanced_encode_from_tokens(tokenized2['g'],
token_normalization,
weight_interpretation,
lambda x: encode_token_weights(clip, x, encode_token_weights_g),
w_max=w_max,
return_pooled=True,
apply_to_pooled=apply_to_pooled)
gcd_num = gcd(embs_l.shape[1], embs_g.shape[1])
repeat_l = int((embs_g.shape[1] / gcd_num) * embs_l.shape[1])
repeat_g = int((embs_l.shape[1] / gcd_num) * embs_g.shape[1])
return prepareXL(embs_l.expand((-1,repeat_l,-1)), embs_g.expand((-1,repeat_g,-1)), pooled, clip_balance)
########################################################################################################################
from nodes import MAX_RESOLUTION