diff --git a/py/bnk_adv_encode.py b/py/bnk_adv_encode.py index f44eecf..0e30a05 100644 --- a/py/bnk_adv_encode.py +++ b/py/bnk_adv_encode.py @@ -1,10 +1,10 @@ import torch import numpy as np import itertools -from math import gcd +#from math import gcd from comfy import model_management -from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG +from comfy.sdxl_clip import SDXLClipModel def _grouper(n, iterable): it = iter(iterable) @@ -235,11 +235,10 @@ def prepareXL(embs_l, embs_g, pooled, clip_balance): return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled else: return embs_g, pooled - #===================== def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True): tokenized = clip.tokenize(text, return_word_ids=True) - if isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)): + if isinstance(tokenized, dict): embs_l = None embs_g = None pooled = None @@ -273,32 +272,6 @@ def advanced_encode(clip, text, token_normalization, weight_interpretation, w_ma lambda x: (clip.encode_from_tokens(x), None), w_max=w_max) - #===================== -def advanced_encode_XL(clip, text1, text2, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True): - tokenized1 = clip.tokenize(text1, return_word_ids=True) - tokenized2 = clip.tokenize(text2, return_word_ids=True) - - embs_l, _ = advanced_encode_from_tokens(tokenized1['l'], - token_normalization, - weight_interpretation, - lambda x: encode_token_weights(clip, x, encode_token_weights_l), - w_max=w_max, - return_pooled=False) - - embs_g, pooled = advanced_encode_from_tokens(tokenized2['g'], - token_normalization, - weight_interpretation, - lambda x: encode_token_weights(clip, x, encode_token_weights_g), - w_max=w_max, - return_pooled=True, - apply_to_pooled=apply_to_pooled) - - gcd_num = gcd(embs_l.shape[1], embs_g.shape[1]) - repeat_l = int((embs_g.shape[1] / gcd_num) * embs_l.shape[1]) - repeat_g = int((embs_l.shape[1] / gcd_num) * embs_g.shape[1]) - - return prepareXL(embs_l.expand((-1,repeat_l,-1)), embs_g.expand((-1,repeat_g,-1)), pooled, clip_balance) - ######################################################################################################################## from nodes import MAX_RESOLUTION