logo
Browse Source

l

Signed-off-by: xujinling <jinling.xu@zilliz.com>
main
xujinling 3 years ago
parent
commit
f38b1ca451
  1. 6
      bridge_former.py

6
bridge_former.py

@ -10,6 +10,7 @@ from towhee.operator.base import NNOperator
from towhee.types.video_frame import VideoFrame
from towhee.models.utils.video_transforms import transform_video, get_configs
from towhee.models.bridgeformer import bridge_former
from towhee.models.clip.clip_utils import tokenize
from transformers import AutoTokenizer
from .get_configs import configs
@ -50,8 +51,6 @@ class BridgeFormer(NNOperator):
self.model = bridge_former.create_model(pretrained=True,
weights_path=weights_path,
model_name=self.model_name)
self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False)
self.transform_cfgs = configs(self.model_name)
self.model.eval()
@ -84,8 +83,7 @@ class BridgeFormer(NNOperator):
def _inference_from_text(self, text: List[str]):
if self.model_name == "clip_initialized_model":
text_data = self.tokenizer(text, return_tensors='pt', padding='max_length', max_length=77,
truncation=True).to(self.device)
text_data = tokenize(text).to(self.device)
text_features = self.model.encode_text(text_data["input_ids"])
else:
text_data = self.tokenizer(text, return_tensors='pt').to(self.device)

Loading…
Cancel
Save