Browse Source
        
      
      Allow to pass checkpoint path
      
        Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 1 changed files with 
10 additions and 
4 deletions
			 
			
		 
		
			
				- 
					
					
					 
					timm_image.py
				
 
			
		
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				| 
					
					
						
							
						
					
					
				 | 
				@ -54,8 +54,12 @@ def torch_no_grad(f): | 
			
		
		
	
		
			
				 | 
				 | 
				
 | 
				 | 
				 | 
				
 | 
			
		
		
	
		
			
				 | 
				 | 
				@accelerate | 
				 | 
				 | 
				@accelerate | 
			
		
		
	
		
			
				 | 
				 | 
				class Model: | 
				 | 
				 | 
				class Model: | 
			
		
		
	
		
			
				 | 
				 | 
				    def __init__(self, model_name, device, num_classes): | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				    def __init__(self, model_name, device, num_classes, checkpoint_path=None): | 
			
		
		
	
		
			
				 | 
				 | 
				        self.device = device | 
				 | 
				 | 
				        self.device = device | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        if checkpoint_path: | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            assert os.path.exists(checkpoint_path), f'File not found: {checkpoint_path}' | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				            self.model = create_model(model_name, checkpoint_path, num_classes=num_classes) | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				        else: | 
			
		
		
	
		
			
				 | 
				 | 
				            self.model = create_model(model_name, pretrained=True, num_classes=num_classes)         | 
				 | 
				 | 
				            self.model = create_model(model_name, pretrained=True, num_classes=num_classes)         | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model.eval() | 
				 | 
				 | 
				        self.model.eval() | 
			
		
		
	
		
			
				 | 
				 | 
				        self.model.to(device) | 
				 | 
				 | 
				        self.model.to(device) | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -81,7 +85,8 @@ class TimmImage(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				                 model_name: str = None, | 
				 | 
				 | 
				                 model_name: str = None, | 
			
		
		
	
		
			
				 | 
				 | 
				                 num_classes: int = 1000, | 
				 | 
				 | 
				                 num_classes: int = 1000, | 
			
		
		
	
		
			
				 | 
				 | 
				                 skip_preprocess: bool = False, | 
				 | 
				 | 
				                 skip_preprocess: bool = False, | 
			
		
		
	
		
			
				 | 
				 | 
				                 device: str = None | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                 device: str = None, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                 checkpoint_path: str = None | 
			
		
		
	
		
			
				 | 
				 | 
				                 ) -> None: | 
				 | 
				 | 
				                 ) -> None: | 
			
		
		
	
		
			
				 | 
				 | 
				        super().__init__() | 
				 | 
				 | 
				        super().__init__() | 
			
		
		
	
		
			
				 | 
				 | 
				        if device is None: | 
				 | 
				 | 
				        if device is None: | 
			
		
		
	
	
		
			
				| 
					
					
					
						
							
						
					
				 | 
				@ -92,7 +97,8 @@ class TimmImage(NNOperator): | 
			
		
		
	
		
			
				 | 
				 | 
				            self.model = Model( | 
				 | 
				 | 
				            self.model = Model( | 
			
		
		
	
		
			
				 | 
				 | 
				                model_name=model_name, | 
				 | 
				 | 
				                model_name=model_name, | 
			
		
		
	
		
			
				 | 
				 | 
				                device=self.device, | 
				 | 
				 | 
				                device=self.device, | 
			
		
		
	
		
			
				 | 
				 | 
				                num_classes=num_classes | 
				 | 
				 | 
				 | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                num_classes=num_classes, | 
			
		
		
	
		
			
				 | 
				 | 
				 | 
				 | 
				 | 
				                checkpoint_path=checkpoint_path | 
			
		
		
	
		
			
				 | 
				 | 
				            ) | 
				 | 
				 | 
				            ) | 
			
		
		
	
		
			
				 | 
				 | 
				            try: | 
				 | 
				 | 
				            try: | 
			
		
		
	
		
			
				 | 
				 | 
				                self.tfms = create_transform( | 
				 | 
				 | 
				                self.tfms = create_transform( | 
			
		
		
	
	
		
			
				| 
					
						
							
						
					
					
					
				 | 
				
  |