From f38b1ca451ea435fea2d179b2b88dfff9c7fbd63 Mon Sep 17 00:00:00 2001 From: xujinling Date: Wed, 22 Jun 2022 18:18:51 +0800 Subject: [PATCH] l Signed-off-by: xujinling --- bridge_former.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/bridge_former.py b/bridge_former.py index 0bdc7aa..2109cb3 100644 --- a/bridge_former.py +++ b/bridge_former.py @@ -10,6 +10,7 @@ from towhee.operator.base import NNOperator from towhee.types.video_frame import VideoFrame from towhee.models.utils.video_transforms import transform_video, get_configs from towhee.models.bridgeformer import bridge_former +from towhee.models.clip.clip_utils import tokenize from transformers import AutoTokenizer from .get_configs import configs @@ -50,8 +51,6 @@ class BridgeFormer(NNOperator): self.model = bridge_former.create_model(pretrained=True, weights_path=weights_path, model_name=self.model_name) - self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False) - self.transform_cfgs = configs(self.model_name) self.model.eval() @@ -84,8 +83,7 @@ class BridgeFormer(NNOperator): def _inference_from_text(self, text: List[str]): if self.model_name == "clip_initialized_model": - text_data = self.tokenizer(text, return_tensors='pt', padding='max_length', max_length=77, - truncation=True).to(self.device) + text_data = tokenize(text).to(self.device) text_features = self.model.encode_text(text_data["input_ids"]) else: text_data = self.tokenizer(text, return_tensors='pt').to(self.device)