import logging
import os
import json
from pathlib import Path
from typing import List, Union
import torch
import numpy
from towhee import register
from towhee.operator.base import NNOperator
from towhee.types.video_frame import VideoFrame
from towhee.models.utils.video_transforms import transform_video
from towhee.models.bridgeformer import bridge_former
from transformers import AutoTokenizer
from .get_configs import configs
log = logging.getLogger()


@register(output_schema=['labels', 'scores', 'features'])
class BridgeFormer(NNOperator):
    """
    Extracts features for video or text with the BridgeFormer model.

    Args:
        model_name (str):
            Name of the BridgeFormer model to use.
        modality (str):
            Flag deciding what to return:
                - 'video': return the video embedding
                - 'text': return the text embedding
        weights_path (str):
            Path to the pretrained model weights.
    """
    def __init__(self,
                 model_name: str = "frozen_model",
                 modality: str = 'video',
                 weights_path: str = None,
                 framework: str = "pytorch",
                 skip_preprocess: bool = False,
                 ):
        super().__init__(framework=framework)
        self.model_name = model_name
        self.skip_preprocess = skip_preprocess
        self.modality = modality
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Fall back to the default weights file shipped alongside this module
        # when no explicit path is given.
        if weights_path is None:
            weights_name = {"clip_initialized_model": "MCQ_CLIP.pth", "frozen_model": "MCQ.pth"}
            weights_path = os.path.join(str(Path(__file__).parent), weights_name[self.model_name])
        self.model = bridge_former.create_model(pretrained=True,
                                                weights_path=weights_path,
                                                model_name=self.model_name)
        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased', TOKENIZERS_PARALLELISM=False)
        self.transform_cfgs = configs(self.model_name)

    def decoder_video(self, data: List[VideoFrame]):
        # Convert the list of towhee VideoFrame to a float32 numpy.ndarray scaled to [0, 1]
        video = numpy.stack([img.astype(numpy.float32) / 255. for img in data], axis=0)
        assert len(video.shape) == 4
        video = video.transpose(3, 0, 1, 2)  # T x H x W x C -> C x T x H x W
        video = transform_video(
            video=video,
            **self.transform_cfgs
        )
        # Add the batch dimension: [B x C x T x H x W]
        video = video.to(self.device)[None, ...]
        return video

    def __call__(self, data: Union[List[VideoFrame], List[str]]):
        if self.modality == 'video':
            vec = self._inference_from_video(data)
        elif self.modality == 'text':
            vec = self._inference_from_text(data)
        else:
            raise ValueError("modality[{}] not implemented.".format(self.modality))
        return vec

    def _inference_from_text(self, text: List[str]):
        text_data = self.tokenizer(text, return_tensors='pt')
        text_data = text_data.to(self.device)
        if self.model_name == "clip_initialized_model":
            text_features = self.model.encode_text(text_data["input_ids"])
        else:
            text_features = self.model.compute_text(text_data)
        return text_features.squeeze(0).detach().flatten().cpu().numpy()

    def _inference_from_video(self, data: List[VideoFrame]):
        # Rearrange to [B x T x C x H x W]
        video = self.decoder_video(data).transpose(1, 2)
        if self.model_name == "clip_initialized_model":
            visual_features = self.model.encode_image(video)
        else:
            visual_features = self.model.compute_video(video)
        return visual_features.squeeze(0).detach().flatten().cpu().numpy()
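

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustration, not part of the operator interface).
# It assumes the default weights file ("MCQ.pth") sits next to this module and
# that the 'distilbert-base-uncased' tokenizer can be downloaded; in practice
# the operator is normally invoked from a towhee pipeline rather than directly.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    text_op = BridgeFormer(model_name='frozen_model', modality='text')
    text_vec = text_op(['a person riding a bicycle'])  # 1-D numpy embedding
    print('text embedding shape:', text_vec.shape)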