|  | @ -2,27 +2,26 @@ from auto_transformers import AutoTransformers | 
		
	
		
			
				|  |  | import torch |  |  | import torch | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | f = open('torchscript.csv', 'a+') |  |  | f = open('torchscript.csv', 'a+') | 
		
	
		
			
				|  |  | f.write('model_name, run op, save_torchscript, check_result\n') |  |  |  | 
		
	
		
			
				|  |  |  |  |  | f.write('model_name,run_op,save_torchscript,check_result\n') | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | models = AutoTransformers.supported_model_names()[:1] |  |  |  | 
		
	
		
			
				|  |  |  |  |  | # models = AutoTransformers.supported_model_names()[:1] | 
		
	
		
			
				|  |  |  |  |  | models = ['bert-base-cased', 'distilbert-base-cased'] | 
		
	
		
			
				|  |  | 
 |  |  | 
 | 
		
	
		
			
				|  |  | for name in models: |  |  | for name in models: | 
		
	
		
			
				|  |  |     line = f'{name}, ' |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     f.write(f'{name},') | 
		
	
		
			
				|  |  |     try: |  |  |     try: | 
		
	
		
			
				|  |  |         op = AutoTransformers(model_name=name) |  |  |         op = AutoTransformers(model_name=name) | 
		
	
		
			
				|  |  |         out1 = op('hello, world.') |  |  |         out1 = op('hello, world.') | 
		
	
		
			
				|  |  |         line += 'success, ' |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         f.write('success,') | 
		
	
		
			
				|  |  |     except Exception as e: |  |  |     except Exception as e: | 
		
	
		
			
				|  |  |         line += 'fail\n' |  |  |  | 
		
	
		
			
				|  |  |         f.write(line) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         f.write('fail') | 
		
	
		
			
				|  |  |         print(f'Fail to load op for {name}: {e}') |  |  |         print(f'Fail to load op for {name}: {e}') | 
		
	
		
			
				|  |  |         continue |  |  |         continue | 
		
	
		
			
				|  |  |     try: |  |  |     try: | 
		
	
		
			
				|  |  |         op.save_model(format='torchscript') |  |  |         op.save_model(format='torchscript') | 
		
	
		
			
				|  |  |         line += 'success, ' |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         f.write('success,') | 
		
	
		
			
				|  |  |     except Exception as e: |  |  |     except Exception as e: | 
		
	
		
			
				|  |  |         line += 'fail\n' |  |  |  | 
		
	
		
			
				|  |  |         f.write(line) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         f.write('fail') | 
		
	
		
			
				|  |  |         print(f'Fail to save onnx for {name}: {e}') |  |  |         print(f'Fail to save onnx for {name}: {e}') | 
		
	
		
			
				|  |  |         continue |  |  |         continue | 
		
	
		
			
				|  |  |     try: |  |  |     try: | 
		
	
	
		
			
				|  | @ -30,11 +29,10 @@ for name in models: | 
		
	
		
			
				|  |  |         op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') |  |  |         op.model = torch.jit.load(f'saved/torchscript/{saved_name}.pt') | 
		
	
		
			
				|  |  |         out2 = op('hello, world.') |  |  |         out2 = op('hello, world.') | 
		
	
		
			
				|  |  |         assert (out1 == out2).all() |  |  |         assert (out1 == out2).all() | 
		
	
		
			
				|  |  |         line += 'success' |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         f.write('success') | 
		
	
		
			
				|  |  |     except Exception as e: |  |  |     except Exception as e: | 
		
	
		
			
				|  |  |         line += 'fail\n' |  |  |  | 
		
	
		
			
				|  |  |         f.write(line) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |         f.write('fail') | 
		
	
		
			
				|  |  |         print(f'Fail to check onnx for {name}: {e}') |  |  |         print(f'Fail to check onnx for {name}: {e}') | 
		
	
		
			
				|  |  |         continue |  |  |         continue | 
		
	
		
			
				|  |  |     line += '\n' |  |  |  | 
		
	
		
			
				|  |  |     f.write(line) |  |  |  | 
		
	
		
			
				|  |  |  |  |  |     f.write('\n') | 
		
	
		
			
				|  |  |  |  |  | print('Finished.') | 
		
	
	
		
			
				|  | 
 |