@@ -45,6 +45,7 @@ log = logging.getLogger('timm_op')
 class Model:
     def __init__(self, model_name, device, num_classes):
         self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
+        self.config = resolve_data_config({}, model=self.model)
         self.model.eval()
         self.model.to(device)
 
@@ -77,13 +78,12 @@ class TimmImage(NNOperator):
         self.device = device
         self.model_name = model_name
         if self.model_name:
-            self.accelerate_model = Model(
+            self.model = Model(
                 model_name=model_name,
                 device=self.device,
                 num_classes=num_classes
             )
-            self.model = self.accelerate_model.model
-            self.config = resolve_data_config({}, model=self.model)
+            self.config = self.model.config
             self.tfms = create_transform(**self.config)
             self.skip_tfms = skip_preprocess
         else:
@@ -102,7 +102,7 @@ class TimmImage(NNOperator):
             img_list.append(img)
         inputs = torch.stack(img_list)
         inputs = inputs.to(self.device)
-        features = self.accelerate_model(inputs)
+        features = self.model(inputs)
         if features.dim() == 4:
             global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
             features = global_pool(features)
@@ -113,6 +113,10 @@ class TimmImage(NNOperator):
             vecs = features.squeeze(0).detach().numpy()
         return vecs
 
+    @property
+    def _model(self):
+        return self.model.model
+
     @arg(1, to_image_color('RGB'))
     def convert_img(self, img: towhee._types.Image):
         img = PILImage.fromarray(img.astype('uint8'), 'RGB')
@@ -133,21 +137,21 @@ class TimmImage(NNOperator):
                 raise AttributeError(f'Invalid format {format}.')
         dummy_input = torch.rand((1,) + self.config['input_size'])
         if format == 'pytorch':
-            torch.save(self.model, path)
+            torch.save(self._model, path)
         elif format == 'torchscript':
             try:
                 try:
-                    jit_model = torch.jit.script(self.model)
+                    jit_model = torch.jit.script(self._model)
                 except Exception:
-                    jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
+                    jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
                 torch.jit.save(jit_model, path)
             except Exception as e:
                 log.error(f'Fail to save as torchscript: {e}.')
                 raise RuntimeError(f'Fail to save as torchscript: {e}.')
         elif format == 'onnx':
-            self.model.forward = self.model.forward_features
+            self._model.forward = self._model.forward_features
             try:
-                torch.onnx.export(self.model,
+                torch.onnx.export(self._model,
                                   dummy_input,
                                   path,
                                   input_names=['input_0'],
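
For review context, a minimal sketch (not the repository's full code) of the pattern this diff moves to: the operator's self.model becomes a Model wrapper that owns both the timm network and its resolved data config, while the new _model property unwraps to the raw nn.Module so that torch.save / torch.jit / torch.onnx keep receiving a plain module. The import paths below are assumptions; the file's real import block is outside these hunks.

# sketch.py -- illustrative only, assumes timm is installed
from timm import create_model
from timm.data import resolve_data_config


class Model:
    def __init__(self, model_name, device, num_classes):
        # build the pretrained timm network and keep its resolved data config (added in hunk 1)
        self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
        self.config = resolve_data_config({}, model=self.model)
        self.model.eval()
        self.model.to(device)


if __name__ == '__main__':
    wrapper = Model('resnet18', device='cpu', num_classes=1000)
    raw_module = wrapper.model            # what the operator's new _model property returns
    print(wrapper.config['input_size'])   # e.g. (3, 224, 224); used to build dummy_input in save_model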