5 changed files with 208 additions and 1 deletion

@@ -1,2 +1,127 @@
# collaborative-experts
# Video-Text Retrieval Embedding with Collaborative Experts

*author: Chen Zhang*

<br />

## Description

This operator extracts features for video or text with [Use What You Have: Video Retrieval Using Representations From Collaborative Experts](https://arxiv.org/pdf/1907.13487v2.pdf).
For the video encoder, this operator exploits embeddings with different modality information extracted from pre-trained expert models, such as motion, appearance, scene, ASR or OCR.
For the text query encoder, it exploits text embeddings extracted from pre-trained models such as word2vec or GPT.
This operator is a collaborative experts model, which aggregates information from these different pre-trained expert models and outputs the video embeddings and text embeddings.

<br />

## Code Example

For video input, load embeddings extracted from different upstream expert networks, such as audio, face, action, RGB, OCR and so on. They can come from upstream operators or from disk.
For text query input, load text embeddings extracted from pre-trained models. They can come from upstream operators or from disk.
For the `ind` input, if the data for one of the modalities is invalid (such as NaN) or you do not want to use it, set the corresponding value in `ind` to 0; otherwise set it to 1.

*Write the pipeline code*:
```python
import torch
from towhee import Entity
import towhee

torch.manual_seed(42)

batch_size = 8
experts = {"audio": torch.rand(batch_size, 29, 128),
           "face": torch.rand(batch_size, 512),
           "i3d.i3d.0": torch.rand(batch_size, 1024),
           "imagenet.resnext101_32x48d.0": torch.rand(batch_size, 2048),
           "imagenet.senet154.0": torch.rand(batch_size, 2048),
           "ocr": torch.rand(batch_size, 49, 300),
           "r2p1d.r2p1d-ig65m.0": torch.rand(batch_size, 512),
           "scene.densenet161.0": torch.rand(batch_size, 2208),
           "speech": torch.rand(batch_size, 32, 300)
           }
ind = {
    "r2p1d.r2p1d-ig65m.0": torch.ones(batch_size,),
    "imagenet.senet154.0": torch.ones(batch_size,),
    "imagenet.resnext101_32x48d.0": torch.ones(batch_size,),
    "scene.densenet161.0": torch.ones(batch_size,),
    "audio": torch.ones(batch_size,),
    "speech": torch.ones(batch_size,),
    "ocr": torch.randint(low=0, high=2, size=(batch_size,)),
    "face": torch.randint(low=0, high=2, size=(batch_size,)),
    "i3d.i3d.0": torch.ones(batch_size,),
}

text = torch.randn(batch_size, 1, 37, 768)

towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \
    .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')]().show()
```
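As a quick sanity check, the operator can also be called outside the pipeline. The snippet below is a minimal sketch, assuming the `CollaborativeExperts` class defined later in this PR is importable locally (the import path is hypothetical) and that the bundled `trained_model.pth` weights file is present; it reuses the `experts`, `ind` and `text` tensors constructed above.

```python
# Minimal sketch: calling the operator class directly (the import path is an
# assumption; the class and its bundled weights live in this repository).
from collaborative_experts import CollaborativeExperts

op = CollaborativeExperts(device='cpu')
text_embds, vid_embds = op(experts, ind, text)

# Both outputs are dicts keyed by expert modality; each value is a numpy array.
print(sorted(vid_embds.keys()))
print(vid_embds['audio'].shape)  # per the Interface section below: (batch_size, 768)
```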
| 

<br />

## Factory Constructor

Create the operator via the following factory method:

***collaborative_experts(config: Dict = None, weights_path: str = None, device: str = None)***

**Parameters:**

   ***config:*** *Optional[dict]*

   The default is None. If None, the [default config](https://github.com/towhee-io/towhee/blob/a713ea2deaa0273f0b6af28354a36572e8eba604/towhee/models/collaborative_experts/collaborative_experts.py#L1130) is used, which is the same as in the original paper and repo.

   ***weights_path:*** *Optional[str]*

   Path to pretrained model weights. If None, the weights bundled with this operator are used.

   ***device:*** *Optional[str]*

   The device to run the model on, `cpu` or `cuda`.
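For example, the factory arguments can be passed where the operator is invoked in the pipeline from the Code Example above. This is only a sketch: `config=None` keeps the linked default config, and the device string is an illustrative choice.

```python
# Sketch: passing factory arguments in the pipeline call from the Code Example
# above (config=None keeps the default config; the device value is illustrative).
towhee.dc([Entity(experts=experts, ind=ind, text=text)]) \
    .video_text_embedding.collaborative_experts[('experts', 'ind', 'text'), ('text_embds', 'vid_embds')](
        config=None, weights_path=None, device='cpu') \
    .show()
```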
| 
<br />

## Interface

For the video modality, it takes video embeddings extracted from different upstream expert networks, such as video, RGB and audio.
For the text modality, it takes a text embedding generated from the query text.

**Parameters:**

	***experts:*** *dict*

  Embeddings extracted from different upstream expert networks, such as audio, face, action, RGB, OCR and so on. They can come from upstream operators or from disk.

	***ind:*** *dict*

  If the data for one of the modalities is invalid (such as NaN) or you do not want to use it, the corresponding value in `ind` is 0; otherwise it is 1.

	***text:*** *Tensor*

  Text embeddings extracted from pre-trained models. They can come from upstream operators or from disk.

**Returns:** *Tuple[Dict[str, numpy.ndarray], Dict[str, numpy.ndarray]]*

   Text embeddings and video embeddings. Each is a dict keyed by modality (the keys match the input modalities), and each value is an array with shape (batch_size, 768).
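For illustration, a sketch of inspecting the two returned dicts, assuming `text_embds` and `vid_embds` came from a call like the one in the Code Example above:

```python
# Sketch: inspecting the returned embeddings (text_embds / vid_embds are the
# two dicts returned by the operator call shown in the Code Example above).
for modality, emb in vid_embds.items():
    print('video', modality, emb.shape)  # each value has shape (batch_size, 768)
for modality, emb in text_embds.items():
    print('text', modality, emb.shape)   # keys mirror the video expert modalities
```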
| 

@@ -0,0 +1,20 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .collaborative_experts import CollaborativeExperts


def collaborative_experts(**kwargs):
    return CollaborativeExperts(**kwargs)
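# Usage sketch (not part of this file): the factory simply forwards kwargs,
# e.g. collaborative_experts(device='cpu') returns a CollaborativeExperts
# instance with the default config and the bundled weights.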
| 
@@ -0,0 +1,59 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import Dict

from pathlib import Path
from towhee.operator.base import NNOperator
from towhee import register
from towhee.models import collaborative_experts

import warnings

warnings.filterwarnings('ignore')


@register(output_schema=['vec'])
class CollaborativeExperts(NNOperator):
    """
    CollaborativeExperts embedding operator
    """

    def __init__(self, config: Dict = None, weights_path: str = None, device: str = None):
        super().__init__()
        if weights_path is None:
            weights_path = str(Path(__file__).parent / 'trained_model.pth')
        self.ce_net_model = collaborative_experts.create_model(config, weights_path, device)

    def __call__(self, experts, ind, text):
        out = self.ce_net_model(experts, ind, text)
        text_embds = {k: v.squeeze(1).detach().cpu().numpy() for k, v in out['text_embds'].items()}
        vid_embds = {k: v.detach().cpu().numpy() for k, v in out['vid_embds'].items()}
        return text_embds, vid_embds

    def get_text_embds(self, experts, ind, text):
        out = self.ce_net_model(experts, ind, text)
        text_embds = {k: v.squeeze(1).detach().cpu().numpy() for k, v in out['text_embds'].items()}
        return text_embds

    def get_vid_embds(self, experts, ind, text):
        out = self.ce_net_model(experts, ind, text)
        vid_embds = {k: v.detach().cpu().numpy() for k, v in out['vid_embds'].items()}
        return vid_embds

    def get_cross_view_conf_matrix(self, experts, ind, text):
        out = self.ce_net_model(experts, ind, text)
        return out['cross_view_conf_matrix'].detach().cpu().numpy()
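# Usage sketch (not part of this file): with `experts`, `ind` and `text`
# prepared as in the README example, the helper methods above can be called
# directly, e.g.
#
#   op = CollaborativeExperts(device='cpu')
#   text_embds = op.get_text_embds(experts, ind, text)
#   vid_embds = op.get_vid_embds(experts, ind, text)
#   conf_matrix = op.get_cross_view_conf_matrix(experts, ind, text)
#
# conf_matrix is the cross-view confusion (text-video similarity) matrix as a
# numpy array.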
| 
(new image file, 152 KiB)

@@ -0,0 +1,3 @@
torch
towhee.models
towhee