|
@ -21,9 +21,11 @@ parser.add_argument('--query_size', type=int, default=100) |
|
|
parser.add_argument('--topk', type=int, default=10) |
|
|
parser.add_argument('--topk', type=int, default=10) |
|
|
parser.add_argument('--collection_name', type=str, default=None) |
|
|
parser.add_argument('--collection_name', type=str, default=None) |
|
|
parser.add_argument('--format', type=str, required=True) |
|
|
parser.add_argument('--format', type=str, required=True) |
|
|
|
|
|
parser.add_argument('--onnx_dir', type=str, default='../saved/onnx') |
|
|
|
|
|
|
|
|
args = parser.parse_args() |
|
|
args = parser.parse_args() |
|
|
model_name = args.model |
|
|
model_name = args.model |
|
|
|
|
|
onnx_path = os.path.join(args.onnx_dir, model_name.replace('/', '-'), '.onnx') |
|
|
dataset_name = args.dataset |
|
|
dataset_name = args.dataset |
|
|
insert_size = args.insert_size |
|
|
insert_size = args.insert_size |
|
|
query_size = args.query_size |
|
|
query_size = args.query_size |
|
@ -76,7 +78,6 @@ def create_milvus(collection_name): |
|
|
if args.format == 'pytorch': |
|
|
if args.format == 'pytorch': |
|
|
collection_name = collection_name + '_pytorch' |
|
|
collection_name = collection_name + '_pytorch' |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def insert(model_name, collection_name): |
|
|
def insert(model_name, collection_name): |
|
|
( |
|
|
( |
|
|
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() |
|
|
towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() |
|
@ -120,9 +121,8 @@ if args.format == 'pytorch': |
|
|
elif args.format == 'onnx': |
|
|
elif args.format == 'onnx': |
|
|
collection_name = collection_name + '_onnx' |
|
|
collection_name = collection_name + '_onnx' |
|
|
saved_name = model_name.replace('/', '-') |
|
|
saved_name = model_name.replace('/', '-') |
|
|
onnx_path = f'saved/onnx/{saved_name}.onnx' |
|
|
|
|
|
if not os.path.exists(onnx_path): |
|
|
if not os.path.exists(onnx_path): |
|
|
op.save_model(format='onnx') |
|
|
|
|
|
|
|
|
op.save_model(format='onnx', path=onnx_path[:-5]) |
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
sess = onnxruntime.InferenceSession(onnx_path, |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
providers=onnxruntime.get_available_providers()) |
|
|
|
|
|
|
|
|