# sbert → ONNX export smoke test: for each model, load the towhee op,
# export to ONNX, validate the graph, run it with onnxruntime, and compare
# outputs; results are written to onnx.csv and sbert.log.
# --- imports ---------------------------------------------------------------
import logging
import os
import platform
from pathlib import Path

import numpy
import onnx
import onnxruntime
import psutil
from towhee import ops

# --- configuration ---------------------------------------------------------
# Resolve the op from the hub once up front so per-model loads below only
# pay for the model weights, not the op resolution.
op = ops.sentence_embedding.sbert().get_op()

# NOTE(review): the model list used to be op.supported_model_names()
# filtered against the models already validated in ONNX format; it is
# pinned to a single model for this run.
models = ['all-MiniLM-L12-v2']

test_txt = 'hello, world.'  # probe sentence fed to every model
atol = 1e-3                 # max elementwise drift tolerated between op and ONNX outputs
log_path = 'sbert.log'      # detailed per-model trace (see logger setup below)

# CSV report: one row per model, one 'success'/'fail' column per stage.
f = open('onnx.csv', 'w+')
f.write('model,load_op,save_onnx,check_onnx,run_onnx,accuracy\n')
# --- logging ---------------------------------------------------------------
# Everything from DEBUG up goes to the log file; only ERROR and above
# reaches the console, so the terminal stays quiet while the file keeps
# the full trace.
logger = logging.getLogger('sbert_onnx')
logger.setLevel(logging.DEBUG)

formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
fh = logging.FileHandler(log_path)
ch = logging.StreamHandler()
for handler, level in ((fh, logging.DEBUG), (ch, logging.ERROR)):
    handler.setLevel(level)
    handler.setFormatter(formatter)
    logger.addHandler(handler)

# Record the host environment so results in the log can be interpreted later.
logger.debug(f'machine: {platform.platform()}-{platform.processor()}')
logger.debug(f'free/available/total mem: {round(psutil.virtual_memory().free / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().available / (1024.0 ** 3))}'
             f'/{round(psutil.virtual_memory().total / (1024.0 ** 3))} GB')
logger.debug(f'cpu: {psutil.cpu_count()}')
def _check_model(name):
    """Run the five-stage ONNX export check for one sbert model.

    Stages (matching the CSV columns): load the towhee op, export it to
    ONNX, validate the exported graph, run it with onnxruntime, and compare
    the ONNX output against the op's reference output within ``atol``.

    Parameters
    ----------
    name : str
        Model name as accepted by ``ops.sentence_embedding.sbert``.

    Returns
    -------
    list[str]
        ``[name, load_op, save_onnx, check_onnx, run_onnx, accuracy]`` with
        each stage entry either ``'success'`` or ``'fail'``.
    """
    logger.info(f'***{name}***')
    saved_name = name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    status = [name] + ['fail'] * 5

    # Stage 1: load the op on CPU and get a reference embedding.
    try:
        op = ops.sentence_embedding.sbert(model_name=name, device='cpu').get_op()
        out1 = op(test_txt)
        logger.info('OP LOADED.')
        status[1] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO LOAD OP: {e}')
        return status

    # Stage 2: export to ONNX (the op writes to onnx_path).
    try:
        op.save_model('onnx')
        logger.info('ONNX SAVED.')
        status[2] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO SAVE ONNX: {e}')
        return status

    # Stage 3: structural check. Large models keep weights as external
    # data, which can make the first check fail; retry without loading the
    # external data so the checker only sees the graph. A failure here is
    # logged but does not abort the runtime stages below.
    try:
        try:
            onnx.checker.check_model(onnx.load(onnx_path))
        except Exception:
            saved_onnx = onnx.load(onnx_path, load_external_data=False)
            onnx.checker.check_model(saved_onnx)
        logger.info('ONNX CHECKED.')
        status[3] = 'success'
    except Exception as e:
        logger.error(f'FAIL TO CHECK ONNX: {e}')

    # Stages 4-5: run the exported model and compare with the reference.
    try:
        inputs = op._model.tokenize([test_txt])
        sess = onnxruntime.InferenceSession(onnx_path, providers=onnxruntime.get_available_providers())
        # Feed only the tensors the ONNX graph actually declares as inputs.
        onnx_inputs = {}
        for n in sess.get_inputs():
            k = n.name
            if k in inputs:
                onnx_inputs[k] = inputs[k].cpu().detach().numpy()
        out2 = sess.run(None, input_feed=onnx_inputs)[0].squeeze(0)
        logger.info('ONNX WORKED.')
        status[4] = 'success'
        if numpy.allclose(out1, out2, atol=atol):
            logger.info('Check accuracy: OK')
            status[5] = 'success'
        else:
            logger.info(f'Check accuracy: atol is larger than {atol}.')
    except Exception as e:
        logger.error(f'FAIL TO RUN ONNX: {e}')

    return status


# One CSV row per model. The original deferred each row's write to the top
# of the next iteration (plus a trailing flush); writing the returned row
# immediately is equivalent and simpler. Close the report explicitly —
# the original leaked the handle to interpreter shutdown.
for name in models:
    f.write(','.join(_check_model(name)) + '\n')
f.close()

print('Finished.')
# end of script