diff --git a/py/bnk_adv_encode.py b/py/bnk_adv_encode.py index 0aa2c3a..f44eecf 100644 --- a/py/bnk_adv_encode.py +++ b/py/bnk_adv_encode.py @@ -235,6 +235,7 @@ def prepareXL(embs_l, embs_g, pooled, clip_balance): return torch.cat([embs_l * l_w, embs_g * g_w], dim=-1), pooled else: return embs_g, pooled + #===================== def advanced_encode(clip, text, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True): tokenized = clip.tokenize(text, return_word_ids=True) @@ -271,7 +272,8 @@ def advanced_encode(clip, text, token_normalization, weight_interpretation, w_ma weight_interpretation, lambda x: (clip.encode_from_tokens(x), None), w_max=w_max) - + + #===================== def advanced_encode_XL(clip, text1, text2, token_normalization, weight_interpretation, w_max=1.0, clip_balance=.5, apply_to_pooled=True): tokenized1 = clip.tokenize(text1, return_word_ids=True) tokenized2 = clip.tokenize(text2, return_word_ids=True)