| 
					
					
						
							
						
					
					
				 | 
				@ -41,12 +41,14 @@ class Omnivore(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, | 
				 | 
				 | 
				    def __init__(self, | 
			
		
		
	
		
			
				 | 
				 | 
				                 model_name: str = 'omnivore_swinT', | 
				 | 
				 | 
				                 model_name: str = 'omnivore_swinT', | 
			
		
		
	
		
			
				 | 
				 | 
				                 framework: str = 'pytorch', | 
				 | 
				 | 
				                 framework: str = 'pytorch', | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                 input_type: str = 'video', | 
			
		
		
	
		
			
				 | 
				 | 
				                 skip_preprocess: bool = False, | 
				 | 
				 | 
				                 skip_preprocess: bool = False, | 
			
		
		
	
		
			
				 | 
				 | 
				                 classmap: dict = None, | 
				 | 
				 | 
				                 classmap: dict = None, | 
			
		
		
	
		
			
				 | 
				 | 
				                 topk: int = 5, | 
				 | 
				 | 
				                 topk: int = 5, | 
			
		
		
	
		
			
				 | 
				 | 
				                 ): | 
				 | 
				 | 
				                 ): | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__(framework=framework) | 
				 | 
				 | 
				        super().__init__(framework=framework) | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model_name = model_name | 
				 | 
				 | 
				        self.model_name = model_name | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        self.input_type = input_type | 
			
		
		
	
		
			
				 | 
				 | 
				        self.skip_preprocess = skip_preprocess | 
				 | 
				 | 
				        self.skip_preprocess = skip_preprocess | 
			
		
		
	
		
			
				 | 
				 | 
				        self.topk = topk | 
				 | 
				 | 
				        self.topk = topk | 
			
		
		
	
		
			
				 | 
				 | 
				        self.dataset_name = 'kinetics_400' | 
				 | 
				 | 
				        self.dataset_name = 'kinetics_400' | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
						
							
						
					
					
				 | 
				@ -99,11 +101,14 @@ class Omnivore(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				                    ) | 
				 | 
				 | 
				                    ) | 
			
		
		
	
		
			
				 | 
				 | 
				        inputs = data.to(self.device)[None, ...] | 
				 | 
				 | 
				        inputs = data.to(self.device)[None, ...] | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				        outs = self.model(inputs,input_type="video") | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        feats = self.model.forward_features(inputs ,input_type = self.input_type) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        features = feats.to('cpu').squeeze(0).detach().numpy() | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        outs = self.model.head(feats) | 
			
		
		
	
		
			
				 | 
				 | 
				        post_act = torch.nn.Softmax(dim=1) | 
				 | 
				 | 
				        post_act = torch.nn.Softmax(dim=1) | 
			
		
		
	
		
			
				 | 
				 | 
				        preds = post_act(outs) | 
				 | 
				 | 
				        preds = post_act(outs) | 
			
		
		
	
		
			
				 | 
				 | 
				        pred_scores, pred_classes = preds.topk(k=self.topk) | 
				 | 
				 | 
				        pred_scores, pred_classes = preds.topk(k=self.topk) | 
			
		
		
	
		
			
				 | 
				 | 
				        labels = [self.classmap[int(i)] for i in pred_classes[0]] | 
				 | 
				 | 
				        labels = [self.classmap[int(i)] for i in pred_classes[0]] | 
			
		
		
	
		
			
				 | 
				 | 
				        scores = [round(float(x), 5) for x in pred_scores[0]] | 
				 | 
				 | 
				        scores = [round(float(x), 5) for x in pred_scores[0]] | 
			
		
		
	
		
			
				 | 
				 | 
				        print(labels,scores) | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				        return labels, scores | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        return labels, scores, features | 
			
		
		
	
	
		
			
				| 
					
					
					
				 | 
				
  |