diff --git a/drl.py b/drl.py
index 96e6727..20505b7 100644
--- a/drl.py
+++ b/drl.py
@@ -91,3 +91,33 @@ class DRL(NNOperator):
 
         visual_output = self.model.get_video_feat(video, video_mask, shaped=True) # B(1), N_v, D
         return visual_output.squeeze(0).detach().cpu().numpy() # N_v, D
+
+    def get_similarity_logits(self, text_feat, video_feat, text_mask, video_mask):
+        """
+        Calculate the similarity matrix from text and video embedding features.
+        When the input batch size is 1, a single similarity value is returned.
+        Args:
+            text_feat (`torch.Tensor`):
+                Shape is (B, N_t, D). B is the batch size, N_t is the token length of the text, D is the feature dim of the network.
+            video_feat (`torch.Tensor`):
+                Shape is (B, N_v, D). B is the batch size, N_v is the token length of the video, D is the feature dim of the network.
+            text_mask (`torch.Tensor`):
+                Shape is (B, N_t); valid token positions are 1, else 0.
+            video_mask (`torch.Tensor`):
+                Shape is (B, N_v); valid token positions are 1, else 0.
+
+        Returns:
+            Similarity matrix with shape (B, B).
+        """
+        t2v_logits, _, _ = self.model.get_similarity_logits(text_feat, video_feat, text_mask, video_mask, shaped=True)
+        # (B, B)
+        return t2v_logits
+
+
+# if __name__ == '__main__':
+#     from towhee.models.clip import SimpleTokenizer
+#     text = 'hello world'
+#     ids = convert_tokens_to_id(SimpleTokenizer(), text, max_words=32)
+#     print(ids)
+#     text_mask = [[1 if i > 0 else 0 for i in ids[0]]]
+#     print(text_mask)
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index 3b2f37c..b50071d 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,4 +1,5 @@
 torchvision
 torch
+numpy
 towhee>=0.7.0
 towhee.models>=0.7.0
\ No newline at end of file
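
A minimal usage sketch of the new method, not part of the patch: `op` stands for an already-constructed DRL operator instance, and random tensors stand in for real CLIP text/video features; the feature dim D = 512 is only an assumption and depends on the underlying backbone.

import torch

# Hypothetical sketch: `op` is a DRL operator instance constructed elsewhere;
# D = 512 is an assumed feature dim, N_t / N_v are example token lengths.
D, N_t, N_v = 512, 32, 12
text_feat = torch.randn(1, N_t, D)    # stand-in for a real text feature batch (B=1)
video_feat = torch.randn(1, N_v, D)   # stand-in for a real video feature batch (B=1)
text_mask = torch.ones(1, N_t, dtype=torch.long)    # all text token positions valid
video_mask = torch.ones(1, N_v, dtype=torch.long)   # all video token positions valid

sim = op.get_similarity_logits(text_feat, video_feat, text_mask, video_mask)
print(sim.shape)  # (1, 1): with batch size 1, a single text-to-video similarity score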