From cc54d1d5dff22c65015cc27966f4ae9e24a8182d Mon Sep 17 00:00:00 2001
From: VALADI K JAGANATHAN
Date: Wed, 1 Nov 2023 11:44:02 +0530
Subject: [PATCH] Update bnk_adv_encode.py

---
 py/bnk_adv_encode.py | 32 ++++++++++++++++++++------------
 1 file changed, 20 insertions(+), 12 deletions(-)

diff --git a/py/bnk_adv_encode.py b/py/bnk_adv_encode.py
index 077c712..49b126d 100644
--- a/py/bnk_adv_encode.py
+++ b/py/bnk_adv_encode.py
@@ -238,17 +238,24 @@ def prepareXL(embs_l, embs_g, pooled, clip_balance):
 
 def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
     tokenized = clip.tokenize(text, return_word_ids=True)
-    if isinstance(clip.cond_stage_model, (SDXLClipModel, SDXLRefinerClipModel, SDXLClipG)):
+    if isinstance(tokenized, dict):
         embs_l = None
         embs_g = None
         pooled = None
-        if 'l' in tokenized and isinstance(clip.cond_stage_model, SDXLClipModel):
-            embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
-                                                    token_normalization,
-                                                    weight_interpretation,
-                                                    lambda x: encode_token_weights(clip, x, encode_token_weights_l),
-                                                    w_max=w_max,
-                                                    return_pooled=False)
+        if 'l' in tokenized:
+            if isinstance(clip.cond_stage_model, SDXLClipModel):
+                embs_l, _ = advanced_encode_from_tokens(tokenized['l'],
+                                                        token_normalization,
+                                                        weight_interpretation,
+                                                        lambda x: encode_token_weights(clip, x, encode_token_weights_l),
+                                                        w_max=w_max,
+                                                        return_pooled=False)
+            else:
+                return advanced_encode_from_tokens(tokenized['l'],
+                                                   token_normalization,
+                                                   weight_interpretation,
+                                                   lambda x: encode_token_weights(clip, x, encode_token_weights_l),
+                                                   w_max=w_max)
         if 'g' in tokenized:
             embs_g, pooled = advanced_encode_from_tokens(tokenized['g'],
                                                          token_normalization,
@@ -259,11 +266,12 @@ def advanced_encode(clip, text, token_normalization, weight_interpretation, w_ma
                                                          apply_to_pooled=apply_to_pooled)
         return prepareXL(embs_l, embs_g, pooled, clip_balance)
     else:
-        return advanced_encode_from_tokens(tokenized['l'],
-                                           token_normalization,
-                                           weight_interpretation,
-                                           lambda x: (clip.encode_from_tokens({'l': x}), None),
+        return advanced_encode_from_tokens(tokenized,
+                                           token_normalization,
+                                           weight_interpretation,
+                                           lambda x: (clip.encode_from_tokens(x), None),
                                            w_max=w_max)
+
 def advanced_encode_XL(clip, text1, text2, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True):
     tokenized1 = clip.tokenize(text1, return_word_ids=True)
     tokenized2 = clip.tokenize(text2, return_word_ids=True)
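
Not part of the patch: the change above moves the SDXL/non-SDXL dispatch from the type of clip.cond_stage_model to the shape of what clip.tokenize(text, return_word_ids=True) returns. Below is a minimal, hypothetical Python sketch of that branching, assuming the tokenizer may return either a per-encoder dict (with 'l' and/or 'g' keys) or a plain token list; describe_dispatch and its example inputs are invented for illustration only.

    # Hypothetical sketch of the control flow in the patched advanced_encode():
    #   dict + non-SDXL model   -> encode tokenized['l'] and return it directly
    #   dict + SDXL model       -> encode 'l' and/or 'g', then blend via prepareXL()
    #   plain token list        -> generic clip.encode_from_tokens() fallback
    def describe_dispatch(tokenized, is_sdxl_clip_model):
        if isinstance(tokenized, dict):
            if 'l' in tokenized and not is_sdxl_clip_model:
                return "single-encoder path: encode tokenized['l'] and return it"
            return "SDXL path: encode 'l'/'g' as available, combine with prepareXL()"
        return "fallback path: clip.encode_from_tokens(tokenized)"

    # Usage with made-up inputs:
    print(describe_dispatch({'l': [], 'g': []}, is_sdxl_clip_model=True))   # SDXL path
    print(describe_dispatch({'l': []}, is_sdxl_clip_model=False))           # single-encoder path
    print(describe_dispatch([[49406, 49407]], is_sdxl_clip_model=False))    # fallback path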