|  | @ -91,3 +91,33 @@ class DRL(NNOperator): | 
		
	
		
			
				|  |  |         visual_output = self.model.get_video_feat(video, video_mask, shaped=True)  # B(1), N_v, D |  |  |         visual_output = self.model.get_video_feat(video, video_mask, shaped=True)  # B(1), N_v, D | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  |         return visual_output.squeeze(0).detach().cpu().numpy()  # N_v, D |  |  |         return visual_output.squeeze(0).detach().cpu().numpy()  # N_v, D | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  |     def get_similarity_logits(self, text_feat, video_feat, text_mask, video_mask): | 
		
	
		
			
				|  |  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |  |         Calculate the similarity matrix from the embedding features extracted from text and video. | 
		
	
		
			
				|  |  |  |  |  |         When input batch size is 1, return one similarity value. | 
		
	
		
			
				|  |  |  |  |  |         Args: | 
		
	
		
			
				|  |  |  |  |  |             text_feat (`torch.Tensor`): | 
		
	
		
			
				|  |  |  |  |  |                  Shape is (B, N_t, D). B means batch size, N_t means token length of text, D means the dim of the network. | 
		
	
		
			
				|  |  |  |  |  |             video_feat (`torch.Tensor`): | 
		
	
		
			
				|  |  |  |  |  |                  Shape is (B, N_v, D). B means batch size, N_v means token length of video, D means the dim of the network. | 
		
	
		
			
				|  |  |  |  |  |             text_mask (`torch.Tensor`): | 
		
	
		
			
				|  |  |  |  |  |                  Shape is (B, N_t), valid token position is 1, else 0. | 
		
	
		
			
				|  |  |  |  |  |             video_mask (`torch.Tensor`): | 
		
	
		
			
				|  |  |  |  |  |                  Shape is (B, N_v), valid token position is 1, else 0. | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  |         Returns: | 
		
	
		
			
				|  |  |  |  |  |             Similarity matrix with shape (B, B) | 
		
	
		
			
				|  |  |  |  |  |         """ | 
		
	
		
			
				|  |  |  |  |  |         t2v_logits, _, _ = self.model.get_similarity_logits(text_feat, video_feat, text_mask, video_mask, shaped=True) | 
		
	
		
			
				|  |  |  |  |  |         # (B, B) | 
		
	
		
			
				|  |  |  |  |  |         return t2v_logits | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | 
 | 
		
	
		
			
				|  |  |  |  |  | # if __name__ == '__main__': | 
		
	
		
			
				|  |  |  |  |  | #     from towhee.models.clip import SimpleTokenizer | 
		
	
		
			
				|  |  |  |  |  | #     text = 'hello world' | 
		
	
		
			
				|  |  |  |  |  | #     ids = convert_tokens_to_id(SimpleTokenizer(), text, max_words=32) | 
		
	
		
			
				|  |  |  |  |  | #     print(ids) | 
		
	
		
			
				|  |  |  |  |  | #     text_mask = [[1 if i > 0 else 0 for i in ids[0]]] | 
		
	
		
			
				|  |  |  |  |  | #     print(text_mask) |