Browse Source
        
      
      Fix for triton device
      
        Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
      
      
        main
      
      
     
    
    
    
	
		
			
				 3 changed files with 
5 additions and 
5 deletions
			 
			
		 
		
			
				- 
					
					
					 
					benchmark/qps_test.py
				
- 
					
						
							
								BIN
							
						
					 benchmark/towhee.jpeg
- 
					
					
					 
					isc.py
				
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -34,7 +34,7 @@ p = ( | 
			
		
	
		
			
				
					|  |  |  |         .output('vec') | 
			
		
	
		
			
				
					|  |  |  | ) | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | data = 'towhee.jpeg' | 
			
		
	
		
			
				
					|  |  |  | data = '../towhee.jpeg' | 
			
		
	
		
			
				
					|  |  |  | out1 = p(data).get()[0] | 
			
		
	
		
			
				
					|  |  |  | print('Pipe: OK') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | 
 | 
			
		
	
								
							
						
					 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
					
						
							
								
									
	| 
			
			
				
					
						
						
							
							![]()  
								
									Width: 
									 | 
									Height: 
									 | 
								
								Size: 49 KiB
							 | 
								
							
						 
					
				 
			 
		
			
			
			
			
			
			
				
				
					
						
							
								
									
	
		
			
				
					|  |  | @ -106,7 +106,7 @@ class Isc(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |             img = img if self.skip_tfms else self.tfms(img) | 
			
		
	
		
			
				
					|  |  |  |             img_list.append(img) | 
			
		
	
		
			
				
					|  |  |  |         inputs = torch.stack(img_list) | 
			
		
	
		
			
				
					|  |  |  |         inputs = inputs.to(self.device) | 
			
		
	
		
			
				
					|  |  |  |         inputs = inputs | 
			
		
	
		
			
				
					|  |  |  |         features = self.model(inputs) | 
			
		
	
		
			
				
					|  |  |  |         features = features.to('cpu') | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -138,7 +138,7 @@ class Isc(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |                 path = path + '.onnx' | 
			
		
	
		
			
				
					|  |  |  |             else: | 
			
		
	
		
			
				
					|  |  |  |                 raise ValueError(f'Invalid format {format}.') | 
			
		
	
		
			
				
					|  |  |  |         dummy_input = torch.rand(1, 3, 224, 224).to(self.device) | 
			
		
	
		
			
				
					|  |  |  |         dummy_input = torch.rand(1, 3, 224, 224) | 
			
		
	
		
			
				
					|  |  |  |         if format == 'pytorch': | 
			
		
	
		
			
				
					|  |  |  |             torch.save(self._model, path) | 
			
		
	
		
			
				
					|  |  |  |         elif format == 'torchscript': | 
			
		
	
	
		
			
				
					|  |  | @ -153,7 +153,7 @@ class Isc(NNOperator): | 
			
		
	
		
			
				
					|  |  |  |                 raise RuntimeError(f'Fail to save as torchscript: {e}.') | 
			
		
	
		
			
				
					|  |  |  |         elif format == 'onnx': | 
			
		
	
		
			
				
					|  |  |  |             try: | 
			
		
	
		
			
				
					|  |  |  |                 torch.onnx.export(self._model, | 
			
		
	
		
			
				
					|  |  |  |                 torch.onnx.export(self._model.to('cpu'), | 
			
		
	
		
			
				
					|  |  |  |                                   dummy_input, | 
			
		
	
		
			
				
					|  |  |  |                                   path, | 
			
		
	
		
			
				
					|  |  |  |                                   input_names=['input_0'], | 
			
		
	
	
		
			
				
					|  |  | 
 |