@@ -10,6 +10,7 @@ from towhee.operator.base import NNOperator
from towhee.types.video_frame import VideoFrame
from towhee.models.utils.video_transforms import transform_video, get_configs
from towhee.models.bridgeformer import bridge_former
from towhee.models.clip.clip_utils import tokenize
from transformers import AutoTokenizer

from .get_configs import configs
@@ -50,8 +51,6 @@ class BridgeFormer(NNOperator):
        self.model = bridge_former.create_model(pretrained=True,
                                                weights_path=weights_path,
                                                model_name=self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False)

        self.transform_cfgs = configs(self.model_name)

        self.model.eval()
@@ -84,8 +83,7 @@ class BridgeFormer(NNOperator):
    def _inference_from_text(self, text: List[str]):

        if self.model_name == "clip_initialized_model":
            text_data = self.tokenizer(text, return_tensors='pt', padding='max_length', max_length=77,
                                       truncation=True).to(self.device)
            text_data = tokenize(text).to(self.device)
            text_features = self.model.encode_text(text_data["input_ids"])
        else:
            text_data = self.tokenizer(text, return_tensors='pt').to(self.device)
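For context on the last hunk, here is a minimal sketch of the two tokenization paths that `_inference_from_text` switches between: the CLIP-initialized variant uses CLIP-style tokenization with its fixed 77-token context (which the `max_length=77` call above reflects), while the other BridgeFormer variants use a DistilBERT tokenizer that produces `input_ids`/`attention_mask`. The helper name `prepare_text`, the `device` argument, and the standalone tokenizer construction are illustrative only, not part of the operator.

```python
# Sketch of the two tokenization paths (not the operator's code).
from typing import List

from transformers import AutoTokenizer
from towhee.models.clip.clip_utils import tokenize  # same helper imported in the diff


def prepare_text(text: List[str], model_name: str, device: str = 'cpu'):
    if model_name == 'clip_initialized_model':
        # CLIP-style BPE tokenization: fixed 77-token context, returns a tensor of token ids.
        return tokenize(text).to(device)
    # Other variants: DistilBERT tokenizer, returns a BatchEncoding with
    # input_ids / attention_mask ready for the text encoder.
    tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')
    return tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
```

In the operator itself the tokenizer is built once in `__init__` (second hunk) and the resulting ids are passed to `self.model.encode_text`, as the diff shows.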