transformers
              
                 
                
            
          copied
			You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
			
			Readme
Files and versions
		
      
        
        
          
            38 lines
          
        
        
          
            1.1 KiB
          
        
        
      
		
    
      
      
    
	
  
	
            38 lines
          
        
        
          
            1.1 KiB
          
        
        
      | from auto_transformers import AutoTransformers | |
| import torch | |
| 
 | |
| f = open('torchscript.csv', 'a+') | |
| f.write('model_name,run_op,save_torchscript,check_result\n') | |
| 
 | |
| # models = AutoTransformers.supported_model_names()[:1] | |
| models = ['bert-base-cased', 'distilbert-base-cased'] | |
| 
 | |
| for name in models: | |
|     f.write(f'{name},') | |
|     try: | |
|         op = AutoTransformers(model_name=name) | |
|         out1 = op('hello, world.') | |
|         f.write('success,') | |
|     except Exception as e: | |
|         f.write('fail') | |
|         print(f'Fail to load op for {name}: {e}') | |
|         continue | |
|     try: | |
|         op.save_model(format='torchscript') | |
|         f.write('success,') | |
|     except Exception as e: | |
|         f.write('fail') | |
|         print(f'Fail to save onnx for {name}: {e}') | |
|         continue | |
|     try: | |
|         saved_name = name.replace('/', '-') | |
|         op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') | |
|         out2 = op('hello, world.') | |
|         assert (out1 == out2).all() | |
|         f.write('success') | |
|     except Exception as e: | |
|         f.write('fail') | |
|         print(f'Fail to check onnx for {name}: {e}') | |
|         continue | |
|     f.write('\n') | |
| print('Finished.')
 | 
