diff --git a/py/bnk_adv_encode.py b/py/bnk_adv_encode.py
index 0e30a05..077c712 100644
--- a/py/bnk_adv_encode.py
+++ b/py/bnk_adv_encode.py
@@ -1,10 +1,10 @@
 import torch
 import numpy as np
 import itertools
-#from math import gcd
+from math import gcd
 
 from comfy import model_management
-from comfy.sdxl_clip import SDXLClipModel
+from comfy.sdxl_clip import SDXLClipModel, SDXLRefinerClipModel, SDXLClipG
 
 def _grouper(n, iterable):
     it = iter(iterable)
@@ -238,24 +238,17 @@ def prepareXL(embs_l, embs_g, pooled, clip_balance):
 
 def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
     tokenized = clip.tokenize(text, return_word_ids=True)
-    if isinstance(tokenized, dict):
+    if isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
         embs_l = None
         embs_g = None
         pooled = None
-        if 'l' in tokenized:
-            if isinstance(clip.cond_stage_model, SDXLClipModel):
-                embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
-                                                        token_normalization,
-                                                        weight_interpretation,
-                                                        lambda x: encode_token_weights(clip, x, encode_token_weights_l),
-                                                        w_max=w_max,
-                                                        return_pooled=False)
-            else:
-                return advanced_encode_from_tokens(tokenized['l'],
-                                                   token_normalization,
-                                                   weight_interpretation,
-                                                   lambda x: encode_token_weights(clip, x, encode_token_weights_l),
-                                                   w_max=w_max)
+        if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
+            embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
+                                                    token_normalization,
+                                                    weight_interpretation,
+                                                    lambda x: encode_token_weights(clip, x, encode_token_weights_l),
+                                                    w_max=w_max,
+                                                    return_pooled=False)
         if 'g' in tokenized:
             embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
                                                          token_normalization,
@@ -266,11 +259,46 @@ def advanced_encode(clip, text, token_normalization, weight_interpretation, w_ma
                                                          apply_to_pooled=apply_to_pooled)
         return prepareXL(embs_l, embs_g, pooled, clip_balance)
     else:
-        return advanced_encode_from_tokens(tokenized,
-                                           token_normalization,
-                                           weight_interpretation,
-                                           lambda x: (clip.encode_from_tokens(x), None),
+        return advanced_encode_from_tokens(tokenized['l'],
+                                           token_normalization,
+                                           weight_interpretation,
+                                           lambda x: (clip.encode_from_tokens({'l': x}), None),
                                            w_max=w_max)
+def advanced_encode_XL(clip, text1, text2, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
+    """Encode text1 through the 'l' path and text2 through the 'g' path, then merge them.
+
+    text1's 'l' tokens are encoded via encode_token_weights_l (no pooled
+    output); text2's 'g' tokens via encode_token_weights_g, which also yields
+    the pooled output. apply_to_pooled is forwarded to the 'g' encoding only.
+    Both embeddings and the pooled output are then passed to prepareXL.
+    """
+    tokenized1 = clip.tokenize(text1, return_word_ids=True)
+    tokenized2 = clip.tokenize(text2, return_word_ids=True)
+
+    embs_l, _ = advanced_encode_from_tokens(tokenized1['l'],
+                                            token_normalization,
+                                            weight_interpretation,
+                                            lambda x: encode_token_weights(clip, x, encode_token_weights_l),
+                                            w_max=w_max,
+                                            return_pooled=False)
+
+    embs_g, pooled = advanced_encode_from_tokens(tokenized2['g'],
+                                                 token_normalization,
+                                                 weight_interpretation,
+                                                 lambda x: encode_token_weights(clip, x, encode_token_weights_g),
+                                                 w_max=w_max,
+                                                 return_pooled=True,
+                                                 apply_to_pooled=apply_to_pooled)
+
+    # The two prompts may yield different lengths along dim 1. Tile each
+    # embedding up to the least common multiple of both lengths so they end
+    # up the same size. Tensor.expand only broadcasts size-1 dimensions and
+    # raises on any other size change, so Tensor.repeat is used instead.
+    gcd_num = gcd(embs_l.shape[1], embs_g.shape[1])
+    repeat_l = embs_g.shape[1] // gcd_num
+    repeat_g = embs_l.shape[1] // gcd_num
+
+    return prepareXL(embs_l.repeat(1, repeat_l, 1), embs_g.repeat(1, repeat_g, 1), pooled, clip_balance)
 
 ########################################################################################################################
 from nodes import MAX_RESOLUTION