diff --git a/README.md b/README.md
index 07fa084..df590ef 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,115 @@
-# bridge-former
+# Video-Text Retrieval Embedding with BridgeFormer
+
+*author: Jinling Xu*
+
+
+
+## Description
+
+This operator extracts features from video or text with [BridgeFormer](https://arxiv.org/pdf/2201.04850.pdf), which generates embeddings for video and text by jointly training a video encoder and a text encoder to maximize the cosine similarity between matching video-text pairs.
+
+
+
+
+
+## Code Example
+
+Load a video from path './demo_video.mp4' to generate a video embedding.
+
+Read the text 'kids feeding and playing with the horse' to generate a text embedding.
+
+*Write the pipeline in the simplified style:*
+
+- Encode video (default):
+```python
+import towhee
+towhee.dc(['./demo_video.mp4']) \
+ .video_decode.ffmpeg() \
+ .video_text_embedding.bridge_former(model_name='frozen_model', modality='video') \
+ .show()
+
+```
+
+
+- Encode text:
+```python
+import towhee
+
+towhee.dc(['kids feeding and playing with the horse']) \
+ .video_text_embedding.bridge_former(model_name='frozen_model', modality='text') \
+ .show()
+```
+
+
+*Write the same pipeline with explicit input/output name specifications:*
+
+```python
+import towhee
+
+towhee.dc['path'](['./demo_video.mp4']) \
+ .video_decode.ffmpeg['path', 'frames'](sample_type='uniform_temporal_subsample', args={'num_samples': 4}) \
+ .runas_op['frames', 'frames'](func=lambda x: [y for y in x]) \
+ .video_text_embedding.bridge_former['frames', 'vec'](model_name='frozen_model', modality='video') \
+ .select['path', 'vec']() \
+ .show(formatter={'path': 'video_path'})
+
+towhee.dc['text'](["kids feeding and playing with the horse"]) \
+ .video_text_embedding.bridge_former['text','vec'](model_name='frozen_model', modality='text') \
+ .select['text', 'vec']() \
+ .show()
+```
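+
+Once both embeddings are generated, a text query can be matched against videos by comparing the two vectors with cosine similarity. The snippet below is a minimal sketch that assumes the video and text vectors from the pipelines above have been collected into numpy arrays (`video_vec` and `text_vec` are illustrative names, not outputs created by this operator):
+
+```python
+import numpy as np
+
+def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
+    # Normalize both vectors and take their dot product.
+    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
+
+# score = cosine_similarity(video_vec, text_vec)
+```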
+
+
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method:
+
+***bridge_former(model_name, modality, weights_path)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+ The name of the BridgeFormer model. Supported model names:
+- frozen_model
+- clip_initialized_model
+
+
+ ***modality:*** *str*
+
+ Which modality (*video* or *text*) is used to generate the embedding.
+
+ ***weights_path:*** *str*
+
+ The path to the pretrained model weights.
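+
+For example, a custom checkpoint can be passed through `weights_path`. The snippet below reuses the pipeline style shown above; the checkpoint path is only a placeholder for a locally available weight file:
+
+```python
+import towhee
+
+towhee.dc(['kids feeding and playing with the horse']) \
+    .video_text_embedding.bridge_former(model_name='clip_initialized_model', modality='text',
+                                        weights_path='/path/to/MCQ_CLIP.pth') \
+    .show()
+```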
+
+
+
+
+
+## Interface
+
+A video-text embedding operator takes a list of [Towhee VideoFrame](link/to/towhee/image/api/doc) or a string as input and generates an embedding in ndarray.
+
+
+**Parameters:**
+
+ ***data:*** *List[towhee.types.VideoFrame]* or *str*
+
+ The data to be embedded: a list of Towhee VideoFrame (uniformly subsampled from a video) or text, depending on the specified modality.
+
+
+
+**Returns:** *numpy.ndarray*
+
+ The data embedding extracted by the model.
+
+
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..05cdbbd
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .bridge_former import BridgeFormer
+
+
+def bridge_former(**kwargs):
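+ """Create a BridgeFormer operator; keyword arguments are forwarded to the BridgeFormer constructor."""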
+ return BridgeFormer(**kwargs)
+
diff --git a/bridge_former.py b/bridge_former.py
new file mode 100644
index 0000000..156f2c0
--- /dev/null
+++ b/bridge_former.py
@@ -0,0 +1,100 @@
+import logging
+import os
+import json
+from pathlib import Path
+from typing import List, Union
+import torch
+import numpy
+from towhee import register
+from towhee.operator.base import NNOperator
+from towhee.types.video_frame import VideoFrame
+from towhee.models.utils.video_transforms import transform_video
+from towhee.models.bridgeformer import bridge_former
+from transformers import AutoTokenizer
+
+from .get_configs import configs
+log = logging.getLogger()
+
+
+@register(output_schema=['vec'])
+class BridgeFormer(NNOperator):
+ """
+ Extracts features for video or text with the BridgeFormer model.
+ Args:
+ model_name (str):
+ BridgeFormer model name to be used in BridgeFormer
+ modality (str):
+ Flag to decide what to return
+ - 'video': return video embedding
+ - 'text': return text embedding
+ weights_path (str):
+ Path to the pretrained model weights.
+ """
+ def __init__(self,
+ model_name: str = "frozen_model",
+ modality: str = 'video',
+ weights_path: str = None,
+ framework: str = "pytorch",
+ skip_preprocess: bool = False,
+
+ ):
+ super().__init__(framework=framework)
+ self.model_name = model_name
+ self.skip_preprocess = skip_preprocess
+ self.modality = modality
+
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ if weights_path is None:
+ weights_name = {"clip_initialized_model": "MCQ_CLIP.pth", "frozen_model": "MCQ.pth"}
+ weights_path = os.path.join(str(Path(__file__).parent), weights_name[self.model_name])
+ self.model = bridge_former.create_model(pretrained=True,
+ weights_path=weights_path,
+ model_name=self.model_name)
+ self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False)
+
+ self.transform_cfgs = configs(self.model_name)
+
+ def decoder_video(self, data: List[VideoFrame]):
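+ """Convert a list of decoded VideoFrames into a normalized [B, C, T, H, W] float tensor on the model device."""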
+ # Convert list of towhee.types.VideoFrame to numpy.ndarray in float32
+ video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0)
+ assert len(video.shape) == 4
+
+ video = video.transpose(3, 0, 1, 2) # thwc -> cthw
+
+ video = transform_video(
+ video=video,
+ **self.transform_cfgs
+ )
+ # [B x C x T x H x W]
+ video = video.to(self.device)[None, ...]
+ return video
+
+ def __call__(self, data: Union[List[VideoFrame], List[str]]):
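+ """Dispatch to video or text inference according to the configured modality and return the embedding as a numpy.ndarray."""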
+ if self.modality == 'video':
+ vec = self._inference_from_video(data)
+ elif self.modality == 'text':
+ vec = self._inference_from_text(data)
+ else:
+ raise ValueError("modality[{}] not implemented.".format(self._modality))
+ return vec
+
+ def _inference_from_text(self, text: List[str]):
+ text_data = self.tokenizer(text, return_tensors='pt')
+
+ text_data = text_data.to(self.device)
+ if self.model_name == "clip_initialized_model":
+ text_features = self.model.encode_text(text_data["input_ids"])
+ else:
+ text_features = self.model.compute_text(text_data)
+ return text_features.squeeze(0).detach().flatten().cpu().numpy()
+
+ def _inference_from_video(self, data: List[VideoFrame]):
+ # [B x T x C x H x W]
+ video = self.decoder_video(data).transpose(1, 2)
+ if self.model_name == "clip_initialized_model":
+ visual_features = self.model.encode_image(video)
+ else:
+ visual_features = self.model.compute_video(video)
+ return visual_features.squeeze(0).detach().flatten().cpu().numpy()
+
+
diff --git a/demo_video.mp4 b/demo_video.mp4
new file mode 100755
index 0000000..e6fb645
Binary files /dev/null and b/demo_video.mp4 differ
diff --git a/get_configs.py b/get_configs.py
new file mode 100644
index 0000000..04eea88
--- /dev/null
+++ b/get_configs.py
@@ -0,0 +1,19 @@
+
+
+def configs(model_name):
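+ """Return the video preprocessing settings (resize/crop sizes, number of frames, normalization mean/std) used for the given model name."""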
+ args = {
+ 'clip_initialized_model':
+ {"side_size": 224,
+ "crop_size": 256,
+ "num_frames": 8,
+ "mean": [0.48145466, 0.4578275, 0.40821073],
+ "std": [0.26862954, 0.26130258, 0.27577711]},
+ 'frozen_model':
+ {"side_size": 224,
+ "crop_size": 256,
+ "num_frames": 4,
+ "mean": [0.485, 0.456, 0.406],
+ "std": [0.229, 0.224, 0.225], }
+ }
+ return args[model_name]
+