transformers
              
                 
                
            
          copied
				 3 changed files with 212 additions and 0 deletions
			
			
		| @ -0,0 +1,21 @@ | |||
| # Evaluate with Similarity Search | |||
| 
 | |||
| ## Introduction | |||
| 
 | |||
| Build a classification system based on similarity search across embeddings. | |||
| The core ideas in `evaluate.py`: | |||
| 1. create a new Milvus collection each time | |||
| 2. extract embeddings using a pretrained model with model name specified by `--model` | |||
| 3. specify inference method with `--format` in value of `pytorch` or `onnx` | |||
| 4. insert & search embeddings with Milvus collection without index | |||
| 5. measure performance with accuracy at top 1, 5, 10 | |||
|    1. vote for the prediction from topk search results (most frequent one) | |||
|    2. compare final prediction with ground truth | |||
|    3. calculate percent of correct predictions over all queries | |||
| 
 | |||
| ## Example Usage | |||
| 
 | |||
| ```bash | |||
| python evaluate.py --model MODEL_NAME --format pytorch | |||
| python evaluate.py --model MODEL_NAME --format onnx | |||
| ``` | |||
| @ -0,0 +1,184 @@ | |||
| import os | |||
| import onnxruntime | |||
| 
 | |||
| import towhee | |||
| from towhee import ops | |||
| from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility | |||
| from datasets import load_dataset | |||
| 
 | |||
| from statistics import mode | |||
| import argparse | |||
| 
 | |||
| import transformers | |||
| 
 | |||
| transformers.logging.set_verbosity_error() | |||
| 
 | |||
| parser = argparse.ArgumentParser() | |||
| parser.add_argument('--model', required=True, type=str) | |||
| parser.add_argument('--dataset', type=str, default='imdb') | |||
| parser.add_argument('--insert_size', type=int, default=1000) | |||
| parser.add_argument('--query_size', type=int, default=100) | |||
| parser.add_argument('--topk', type=int, default=10) | |||
| parser.add_argument('--collection_name', type=str, default=None) | |||
| parser.add_argument('--format', type=str, required=True) | |||
| 
 | |||
| args = parser.parse_args() | |||
| model_name = args.model | |||
| dataset_name = args.dataset | |||
| insert_size = args.insert_size | |||
| query_size = args.query_size | |||
| topk = args.topk | |||
| collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_') | |||
| 
 | |||
| device = 'cpu' | |||
| host = 'localhost' | |||
| port = '19530' | |||
| index_type = 'FLAT' | |||
| metric_type = 'L2' | |||
| 
 | |||
| data = load_dataset(dataset_name).shuffle(seed=32) | |||
| assert insert_size <= len(data['train']), 'There is no enough data. Please decrease insert size.' | |||
| assert insert_size <= len(data['test']), 'There is no enough data. Please decrease query size.' | |||
| 
 | |||
| insert_data = data['train'] | |||
| insert_data = insert_data[:insert_size] if insert_size else insert_data | |||
| query_data = data['test'] | |||
| query_data = query_data[:query_size] if query_size else query_data | |||
| 
 | |||
| # Warm up | |||
| print('Warming up...') | |||
| op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op() | |||
| dim = op('This is test.').shape[-1] | |||
| print(f'output dim: {dim}') | |||
| 
 | |||
| # Prepare Milvus | |||
| print('Connecting milvus ...') | |||
| connections.connect(host=host, port=port) | |||
| 
 | |||
| 
 | |||
| def create_milvus(collection_name): | |||
|     print('Creating collection ...') | |||
|     fields = [ | |||
|         FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True), | |||
|         FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim), | |||
|         FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500) | |||
|     ] | |||
|     schema = CollectionSchema(fields=fields, description=f'text embeddings for {model_name} on {dataset_name}') | |||
|     if utility.has_collection(collection_name): | |||
|         print(f'drop old collection: {collection_name}') | |||
|         collection = Collection(collection_name) | |||
|         collection.drop() | |||
|     collection = Collection(name=collection_name, schema=schema) | |||
|     print(f'A new collection is created: {collection_name}.') | |||
|     return collection | |||
| 
 | |||
| 
 | |||
| if args.format == 'pytorch': | |||
|     collection_name = collection_name + '_pytorch' | |||
| 
 | |||
| 
 | |||
|     def insert(model_name, collection_name): | |||
|         ( | |||
|             towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() | |||
|                   .runas_op['text', 'text'](lambda s: s[:1024]) | |||
|                   .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device) | |||
|                   .runas_op['emb', 'emb'](lambda x: x[0]) | |||
|                   .runas_op['label', 'label'](lambda y: str(y)) | |||
|                   .ann_insert.milvus[('emb', 'label'), 'miluvs_insert']( | |||
|                         uri=f'tcp://{host}:{port}/{collection_name}' | |||
|                         ) | |||
|                   .show(3) | |||
|         ) | |||
|         collection = Collection(collection_name) | |||
|         return collection.num_entities | |||
| 
 | |||
| 
 | |||
|     def query(model_name, collection_name): | |||
|         benchmark = ( | |||
|             towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream() | |||
|                   .runas_op['text', 'text'](lambda s: s[:1024]) | |||
|                   .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device) | |||
|                   .runas_op['emb', 'emb'](lambda x: x[0]) | |||
|                   .runas_op['gt', 'gt'](lambda y: str(y)) | |||
|                   .ann_search.milvus['emb', 'milvus_res']( | |||
|                         uri=f'tcp://{host}:{port}/{collection_name}', | |||
|                         metric_type=metric_type, | |||
|                         limit=topk, | |||
|                         output_fields=['label'] | |||
|                         ) | |||
|                   .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() | |||
|                   .runas_op['preds', 'pred1'](lambda x: mode(x[:1])) | |||
|                   .runas_op['preds', 'pred5'](lambda x: mode(x[:5])) | |||
|                   .runas_op['preds', 'pred10'](lambda x: mode(x[:10])) | |||
|                   .with_metrics(['accuracy']) | |||
|                   .evaluate['gt', 'pred1']('pred1') | |||
|                   .evaluate['gt', 'pred5']('pred5') | |||
|                   .evaluate['gt', 'pred10']('pred10') | |||
|                   .report() | |||
|         ) | |||
|         return benchmark | |||
| elif args.format == 'onnx': | |||
|     collection_name = collection_name + '_onnx' | |||
|     saved_name = model_name.replace('/', '-') | |||
|     onnx_path = f'saved/onnx/{saved_name}.onnx' | |||
|     if not os.path.exists(onnx_path): | |||
|         op.save_model(format='onnx') | |||
|     sess = onnxruntime.InferenceSession(onnx_path, | |||
|                                         providers=onnxruntime.get_available_providers()) | |||
| 
 | |||
|     @towhee.register | |||
|     def run_onnx(txt): | |||
|         inputs = op.tokenizer(txt, return_tensors='np') | |||
|         onnx_inputs = [x.name for x in sess.get_inputs()] | |||
|         new_inputs = {} | |||
|         for k in onnx_inputs: | |||
|             new_inputs[k] = inputs[k] | |||
|         outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs)) | |||
|         return outs[0].squeeze(0) | |||
| 
 | |||
|     def insert(model_name, collection_name): | |||
|         ( | |||
|             towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream() | |||
|                   .runas_op['text', 'text'](lambda s: s[:1024]) | |||
|                   .run_onnx['text', 'emb']() | |||
|                   .runas_op['emb', 'emb'](lambda x: x[0]) | |||
|                   .runas_op['label', 'label'](lambda y: str(y)) | |||
|                   .ann_insert.milvus[('emb', 'label'), 'miluvs_insert']( | |||
|                         uri=f'tcp://{host}:{port}/{collection_name}' | |||
|                         ) | |||
|                   .show(3) | |||
|         ) | |||
|         collection = Collection(collection_name) | |||
|         return collection.num_entities | |||
| 
 | |||
|     def query(model_name, collection_name): | |||
|         benchmark = ( | |||
|             towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream() | |||
|                   .runas_op['text', 'text'](lambda s: s[:1024]) | |||
|                   .run_onnx['text', 'emb']() | |||
|                   .runas_op['emb', 'emb'](lambda x: x[0]) | |||
|                   .runas_op['gt', 'gt'](lambda y: str(y)) | |||
|                   .ann_search.milvus['emb', 'milvus_res']( | |||
|                            uri=f'tcp://{host}:{port}/{collection_name}', | |||
|                            metric_type=metric_type, | |||
|                            limit=topk, | |||
|                            output_fields=['label'] | |||
|                            ) | |||
|                   .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream() | |||
|                   .runas_op['preds', 'pred1'](lambda x: mode(x[:1])) | |||
|                   .runas_op['preds', 'pred5'](lambda x: mode(x[:5])) | |||
|                   .runas_op['preds', 'pred10'](lambda x: mode(x[:10])) | |||
|                   .with_metrics(['accuracy']) | |||
|                   .evaluate['gt', 'pred1']('pred1') | |||
|                   .evaluate['gt', 'pred5']('pred5') | |||
|                   .evaluate['gt', 'pred10']('pred10') | |||
|                   .report() | |||
|         ) | |||
|         return benchmark | |||
| else: | |||
|     raise AttributeError('Only support "pytorch" and "onnx" as format.') | |||
| 
 | |||
| collection = create_milvus(collection_name) | |||
| insert_count = insert(model_name, collection_name) | |||
| print('Total data inserted:', insert_count) | |||
| benchmark = query(model_name, collection_name) | |||
| @ -0,0 +1,7 @@ | |||
| #!/bin/bash | |||
| 
 | |||
| for name in allenai/led-base-16384 flaubert/flaubert_base_cased flaubert/flaubert_base_uncased flaubert/flaubert_large_cased flaubert/flaubert_small_cased funnel-transformer/intermediate-base funnel-transformer/large-base funnel-transformer/medium-base funnel-transformer/small-base funnel-transformer/xlarge-base google/mobilebert-uncased tau/splinter-base tau/splinter-base-qass tau/splinter-large | |||
| do | |||
| 	python evaluate.py --model ${name} --format pytorch | |||
| 	python evaluate.py --model ${name} --format onnx | |||
| done | |||
					Loading…
					
					
				
		Reference in new issue
	
	