5 changed files with 208 additions and 1 deletion
			
			
@@ -1,2 +1,127 @@
# Video-Text Retrieval Embedding with Collaborative Experts

*author: Chen Zhang*


<br />


## Description

This operator extracts features for video or text with [Use What You Have: Video Retrieval Using Representations From Collaborative Experts](https://arxiv.org/pdf/1907.13487v2.pdf).
For the video encoder, the operator exploits embeddings with different modality information extracted from pre-trained expert models, covering motion, appearance, scene, ASR and OCR.
For the text query encoder, it exploits text embeddings extracted from pre-trained models such as word2vec or GPT.
The collaborative experts model aggregates the information from these different pre-trained expert models and outputs the video embeddings and text embeddings.
|  | 
 | ||||
|  | 
 | ||||
|  | <br /> | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | ## Code Example | ||||
|  | 
 | ||||
|  | For video input, load embeddings extracted from different upstream expert networks, such as audio, face, action, RGB, OCR and so on. They can be from upstream operators or from disk .   | ||||
|  | For text query input, load text embeddings extracted from pre-trained models. They can be from upstream operators or from disk .   | ||||
|  | For `ind` input, if a data in one of the modalities is invalid(such as NaN) or you do not want use it, the corresponding value in `ind` is 0, else 1. | ||||
|  | 
 | ||||
|  | 
 | ||||
|  |  *Write the pipeline code*: | ||||
|  | 
 | ||||
```python
import torch
from towhee import Entity
import towhee

torch.manual_seed(42)

batch_size = 8
experts = {"audio": torch.rand(batch_size, 29, 128),
           "face": torch.rand(batch_size, 512),
           "i3d.i3d.0": torch.rand(batch_size, 1024),
           "imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048),
           "imagenet.senet154.0": torch.rand(batch_size, 2048),
           "ocr": torch.rand(batch_size, 49, 300),
           "r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512),
           "scene.densenet161.0": torch.rand(batch_size, 2208),
           "speech": torch.rand(batch_size, 32, 300)
           }
ind = {
    "r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,),
    "imagenet.senet154.0": torch.ones(batch_size,),
    "imagenet.resnext101_32x48d.0": torch.ones(batch_size,),
    "scene.densenet161.0": torch.ones(batch_size,),
    "audio": torch.ones(batch_size,),
    "speech": torch.ones(batch_size,),
    "ocr": torch.randint(low=0, high=2, size=(batch_size,)),
    "face": torch.randint(low=0, high=2, size=(batch_size,)),
    "i3d.i3d.0": torch.ones(batch_size,),
}

text = torch.randn(batch_size, 1, 37, 768)

towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \
    .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')]().show()
```


<br />


## Factory Constructor

Create the operator via the following factory method:

***collaborative_experts(config: Dict = None, weights_path: str = None, device: str = None)***

**Parameters:**

***config:*** *Optional[dict]*

Default is None. If None, the operator uses the [default config](https://github.com/towhee-io/towhee/blob/a713ea2deaa0273f0b6af28354a36572e8eba604/towhee/models/collaborative_experts/collaborative_experts.py#L1130), which is the same as in the original paper and repo.

***weights_path:*** *Optional[str]*

Path to the pretrained model weights. If None, the weights bundled with this operator are used.

***device:*** *Optional[str]*

The device to run the model on, 'cpu' or 'cuda'.

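As a hedged sketch, the factory arguments can be passed where the empty parentheses appear in the pipeline example above; here only `device='cpu'` is set and the other arguments keep their defaults (the inputs are the same randomly generated tensors as in the code example):

```python
# Sketch: same pipeline as above, but with an explicit device argument for the operator.
# config and weights_path are left at their defaults (default config, bundled weights).
towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \
    .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')](device='cpu') \
    .show()
```
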

<br />


## Interface

For the video modality, load the video embeddings extracted from different upstream expert networks, such as audio, face, action, RGB and OCR.
For the text modality, load the text embedding generated from the query text.

**Parameters:**

***experts:*** *dict*

Embeddings extracted from different upstream expert networks, such as audio, face, action, RGB and OCR. They can come from upstream operators or be loaded from disk.

***ind:*** *dict*

If the data for one of the modalities is invalid (such as NaN) or you do not want to use it, the corresponding value in `ind` is 0; otherwise it is 1.

***text:*** *Tensor*

Text embeddings extracted from pre-trained models. They can come from upstream operators or be loaded from disk.

**Returns:** *Tuple[Dict[str, numpy.ndarray], Dict[str, numpy.ndarray]]*

The text embeddings and the video embeddings. Each is a dict keyed by modality (the same keys as the input), and each value is an array with shape (batch size, 768).
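
As a rough usage sketch reusing the pipeline from the code example above, the returned dicts can be pulled out of the result for inspection; `to_list()` and the attribute-style access on the entity are assumptions about the DataCollection API used here:

```python
# Sketch: run the same pipeline as above, but keep the result instead of showing it.
dc = towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \
    .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')]()

# Each entity now carries the output fields named in the schema above.
entity = dc.to_list()[0]
for modality, emb in entity.vid_embds.items():
    print(modality, emb.shape)   # each video embedding is (batch_size, 768)
for modality, emb in entity.text_embds.items():
    print(modality, emb.shape)   # each text embedding is (batch_size, 768)
```
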
@@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .collaborative_experts import CollaborativeExperts


def collaborative_experts(**kwargs):
    return CollaborativeExperts(**kwargs)

@@ -0,0 +1,59 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

from pathlib import Path
from towhee.operator.base import NNOperator
from towhee import register
from towhee.models import collaborative_experts

import warnings

warnings.filterwarnings('ignore')

@register(output_schema=['vec'])
class CollaborativeExperts(NNOperator):
    """
    CollaborativeExperts embedding operator.
    """

    def __init__(self, config: Dict = None, weights_path: str = None, device: str = None):
        super().__init__()
        # Fall back to the weights bundled with this operator when no path is given.
        if weights_path is None:
            weights_path = str(Path(__file__).parent / 'trained_model.pth')
        self.ce_net_model = collaborative_experts.create_model(config, weights_path, device)

    def __call__(self, experts, ind, text):
        # Run the collaborative experts network and return both embeddings as numpy arrays.
        out = self.ce_net_model(experts, ind, text)
        text_embds = {k: v.squeeze(1).detach().cpu().numpy() for k, v in out['text_embds'].items()}
        vid_embds = {k: v.detach().cpu().numpy() for k, v in out['vid_embds'].items()}
        return text_embds, vid_embds

    def get_text_embds(self, experts, ind, text):
        # Return only the per-modality text embeddings.
        out = self.ce_net_model(experts, ind, text)
        text_embds = {k: v.squeeze(1).detach().cpu().numpy() for k, v in out['text_embds'].items()}
        return text_embds

    def get_vid_embds(self, experts, ind, text):
        # Return only the per-modality video embeddings.
        out = self.ce_net_model(experts, ind, text)
        vid_embds = {k: v.detach().cpu().numpy() for k, v in out['vid_embds'].items()}
        return vid_embds

    def get_cross_view_conf_matrix(self, experts, ind, text):
        # Return the text-video cross-view confidence (similarity) matrix.
        out = self.ce_net_model(experts, ind, text)
        return out['cross_view_conf_matrix'].detach().cpu().numpy()

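The `cross_view_conf_matrix` scores every text query in the batch against every video, so it can be used directly for retrieval ranking. A minimal sketch, assuming an instance `op` of the class above, the `experts`/`ind`/`text` inputs from the README example, and that rows index text queries and columns index videos (the orientation is an assumption worth checking against the model output):

```python
import numpy as np

# Sketch only: `op` is a CollaborativeExperts instance; inputs as in the README example.
conf = op.get_cross_view_conf_matrix(experts, ind, text)  # assumed shape: (batch_size, batch_size)

# Rank videos for each text query by descending confidence.
ranking = np.argsort(-conf, axis=1)
print(ranking[0])  # video indices ordered from best to worst match for the first query
```
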
New image file added (Size: 152 KiB)
@@ -0,0 +1,3 @@
torch
towhee.models
towhee