from auto_transformers import AutoTransformers

import numpy
import onnx
import onnxruntime

import os
from pathlib import Path
import logging
import platform
import psutil

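# Models to test. The commented-out lines build the list from all supported model
# names minus those already returned for format='onnx'.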
					
						
# full_models = AutoTransformers.supported_model_names()
# checked_models = AutoTransformers.supported_model_names(format='onnx')
# models = [x for x in full_models if x not in checked_models]
models = ['bert-base-cased', 'distilbert-base-cased']
test_txt = 'hello, world.'
atol = 1e-3
log_path = 'transformers_onnx.log'
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')

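# Route all log records to the log file; echo only ERROR and above to the console.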
					
						
logger = logging.getLogger('transformers_onnx')
logger.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
fh.setLevel(logging.DEBUG)
fh.setFormatter(formatter)
logger.addHandler(fh)
ch = logging.StreamHandler()
ch.setLevel(logging.ERROR)
ch.setFormatter(formatter)
logger.addHandler(ch)

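# Record the test environment: platform, memory in GB, and CPU count.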
					
						
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')

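# For each model: load the operator, save it as ONNX, check the graph, run it with
# onnxruntime, and compare outputs; each step's result is written to the CSV report.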
					
						
status = None
for name in models:
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}/model.onnx'
    if status:
        f.write(','.join(status) + '\n')
    status = [name] + ['fail'] * 5
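    # Step 1: load the operator and run it on the sample text.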
					
						
    try:
        op = AutoTransformers(model_name=name)
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        continue
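    # Step 2: export the model to ONNX.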
					
						
    try:
        op.save_model(format='onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        continue
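    # Step 3: check the saved ONNX graph; if the full load fails, retry the check
    # without loading external data.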
					
						
    try:
        try:
            onnx_model = onnx.load(onnx_path)
            onnx.checker.check_model(onnx_model)
        except Exception:
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')
        continue
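    # Step 4: run the ONNX model with onnxruntime and compare its output to the
    # original operator's output within atol.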
					
						
    try:
        sess = onnxruntime.InferenceSession(onnx_path,
                                            providers=onnxruntime.get_available_providers())
        inputs = op.tokenizer(test_txt, return_tensors='np')
        out2 = sess.run(output_names=['output_0'], input_feed=dict(inputs))
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: mismatch is larger than atol={atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')
        continue

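# Write the status row of the last model processed.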
					
						
if status:
    f.write(','.join(status) + '\n')
f.close()

print('Finished.')