logo
Browse Source

get_similarity_logits

main
ChengZi 3 years ago
parent
commit
82bda5fd39
  1. 30
      drl.py
  2. 1
      requirements.txt

30
drl.py

@ -91,3 +91,33 @@ class DRL(NNOperator):
visual_output = self.model.get_video_feat(video, video_mask, shaped=True) # B(1), N_v, D visual_output = self.model.get_video_feat(video, video_mask, shaped=True) # B(1), N_v, D
return visual_output.squeeze(0).detach().cpu().numpy() # N_v, D return visual_output.squeeze(0).detach().cpu().numpy() # N_v, D
def get_similarity_logits(self, text_feat, video_feat, text_mask, video_mask):
"""
Input is the embedding feature extracted from text or video, calculate the similarity matrix.
When input batch size is 1, return one similarity value.
Args:
text_feat (`torch.Tensor`):
Shape is (B, N_t, D). B means batch size, N_t means token length of text, D mean the dim of the network.
video_feat (`torch.Tensor`):
Shape is (B, N_v, D). B means batch size, N_t means token length of video, D mean the dim of the network.
text_mask (`torch.Tensor`):
Shape is (B, N_t), valid token position is 1, else 0.
video_mask (`torch.Tensor`):
Shape is (B, N_v), valid token position is 1, else 0.
Returns:
Similarity matrix with shape (B, B)
"""
t2v_logits, _, _ = self.model.get_similarity_logits(text_feat, video_feat, text_mask, video_mask, shaped=True)
# (B, B)
return t2v_logits
# if __name__ == '__main__':
# from towhee.models.clip import SimpleTokenizer
# text = 'hello world'
# ids = convert_tokens_to_id(SimpleTokenizer(), text, max_words=32)
# print(ids)
# text_mask = [[1 if i > 0 else 0 for i in ids[0]]]
# print(text_mask)

1
requirements.txt

@ -1,4 +1,5 @@
torchvision torchvision
torch torch
numpy
towhee>=0.7.0 towhee>=0.7.0
towhee.models>=0.7.0 towhee.models>=0.7.0
Loading…
Cancel
Save