@@ -91,3 +91,33 @@ class DRL(NNOperator):
        visual_output = self.model.get_video_feat(video, video_mask, shaped=True)  # B(1), N_v, D
        return visual_output.squeeze(0).detach().cpu().numpy()  # N_v, D

    def get_similarity_logits(self, text_feat, video_feat, text_mask, video_mask):
""" |
|
|
|
|
|
Input is the embedding feature extracted from text or video, calculate the similarity matrix. |
|
|
|
|
|
When input batch size is 1, return one similarity value. |
|
|
|
|
|
Args: |
|
|
|
|
|
text_feat (`torch.Tensor`): |
|
|
|
|
|
Shape is (B, N_t, D). B means batch size, N_t means token length of text, D mean the dim of the network. |
|
|
|
|
|
video_feat (`torch.Tensor`): |
|
|
|
|
|
Shape is (B, N_v, D). B means batch size, N_t means token length of video, D mean the dim of the network. |
|
|
|
|
|
text_mask (`torch.Tensor`): |
|
|
|
|
|
Shape is (B, N_t), valid token position is 1, else 0. |
|
|
|
|
|
video_mask (`torch.Tensor`): |
|
|
|
|
|
Shape is (B, N_v), valid token position is 1, else 0. |
|
|
|
|
|
|
|
|
|
|
|
Returns: |
|
|
|
|
|
Similarity matrix with shape (B, B) |
|
|
|
|
|
""" |
|
|
|
|
|
        t2v_logits, _, _ = self.model.get_similarity_logits(text_feat, video_feat, text_mask, video_mask, shaped=True)
        # (B, B)
        return t2v_logits

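# Usage sketch (illustrative only, not part of the operator): `op` is assumed to be an
# instantiated DRL operator, and the feature/mask tensors below are random stand-ins with
# the shapes described in the docstring above. With batch size 1 the returned similarity
# matrix is (1, 1), i.e. a single text-to-video similarity value.
#
# import torch
# B, N_t, N_v, D = 1, 32, 12, 512          # assumed dimensions for illustration
# text_feat = torch.randn(B, N_t, D)       # text token embeddings
# video_feat = torch.randn(B, N_v, D)      # video frame embeddings
# text_mask = torch.ones(B, N_t)           # all text positions valid
# video_mask = torch.ones(B, N_v)          # all video positions valid
# sim = op.get_similarity_logits(text_feat, video_feat, text_mask, video_mask)
# print(sim.shape)                         # expected: torch.Size([1, 1])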
# if __name__ == '__main__':
#     from towhee.models.clip import SimpleTokenizer
#     text = 'hello world'
#     ids = convert_tokens_to_id(SimpleTokenizer(), text, max_words=32)
#     print(ids)
#     text_mask = [[1 if i > 0 else 0 for i in ids[0]]]
#     print(text_mask)
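#     # A hypothetical continuation (not in the original demo): convert the id/mask lists
#     # into torch tensors; text_mask then has the (B, N_t) layout that
#     # get_similarity_logits expects for the text side.
#     import torch
#     text_ids = torch.tensor(ids)         # (1, 32) token ids, padded to max_words
#     text_mask = torch.tensor(text_mask)  # (1, 32), 1 for valid tokens, 0 for padding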