transformers

3 changed files with 212 additions and 0 deletions
@@ -0,0 +1,21 @@
# Evaluate with Similarity Search

## Introduction

Build a classification system based on similarity search across embeddings.
The core ideas in `evaluate.py`:
1. create a new Milvus collection for each run
2. extract embeddings using a pretrained model specified by `--model`
3. select the inference backend with `--format`, either `pytorch` or `onnx`
4. insert and search embeddings with the Milvus collection (no index is built)
5. measure performance with accuracy at top 1, 5, and 10 (see the sketch after this list):
   1. vote for the prediction from the top-k search results (the most frequent label wins)
   2. compare the final prediction with the ground truth
   3. calculate the percentage of correct predictions over all queries

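A minimal sketch of the voting and accuracy computation, assuming `results` is a list of `(ground_truth_label, top_k_labels)` pairs collected from the search step (the names here are illustrative, not part of `evaluate.py`):

```python
from statistics import mode

def accuracy_at_k(results, k):
    """Fraction of queries whose majority-vote label among the first k hits matches the ground truth."""
    correct = 0
    for gt, topk_labels in results:
        # the most frequent label among the first k neighbours becomes the prediction
        if mode(topk_labels[:k]) == gt:
            correct += 1
    return correct / len(results)
```
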
## Example Usage

```bash
python evaluate.py --model MODEL_NAME --format pytorch
python evaluate.py --model MODEL_NAME --format onnx
```
@@ -0,0 +1,184 @@
import os
import onnxruntime

import towhee
from towhee import ops
from pymilvus import connections, DataType, FieldSchema, Collection, CollectionSchema, utility
from datasets import load_dataset

from statistics import mode
import argparse

import transformers

transformers.logging.set_verbosity_error()

parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str)
parser.add_argument('--dataset', type=str, default='imdb')
parser.add_argument('--insert_size', type=int, default=1000)
parser.add_argument('--query_size', type=int, default=100)
parser.add_argument('--topk', type=int, default=10)
parser.add_argument('--collection_name', type=str, default=None)
parser.add_argument('--format', type=str, required=True)

args = parser.parse_args()
model_name = args.model
dataset_name = args.dataset
insert_size = args.insert_size
query_size = args.query_size
topk = args.topk
collection_name = args.collection_name if args.collection_name else model_name.replace('-', '_').replace('/', '_')

device = 'cpu'
host = 'localhost'
port = '19530'
index_type = 'FLAT'
metric_type = 'L2'

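# Load the dataset with a fixed shuffle seed; the train split supplies the texts whose embeddings
# are inserted into Milvus, and the test split supplies the query texts.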
data = load_dataset(dataset_name).shuffle(seed=32)
assert insert_size <= len(data['train']), 'There is not enough data. Please decrease insert_size.'
assert query_size <= len(data['test']), 'There is not enough data. Please decrease query_size.'

insert_data = data['train']
insert_data = insert_data[:insert_size] if insert_size else insert_data
query_data = data['test']
query_data = query_data[:query_size] if query_size else query_data

# Warm up: run one embedding to detect the output dimension needed for the collection schema
print('Warming up...')
op = ops.text_embedding.transformers(model_name=model_name, device=device).get_op()
dim = op('This is test.').shape[-1]
print(f'output dim: {dim}')

# Prepare Milvus
print('Connecting to Milvus ...')
connections.connect(host=host, port=port)


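# Create a fresh collection with auto-generated int64 ids, float-vector embeddings of the detected
# dimension, and a varchar label field; any existing collection with the same name is dropped first.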
def create_milvus(collection_name):
    print('Creating collection ...')
    fields = [
        FieldSchema(name='id', dtype=DataType.INT64, description='embedding id', is_primary=True, auto_id=True),
        FieldSchema(name='embedding', dtype=DataType.FLOAT_VECTOR, description='text embedding', dim=dim),
        FieldSchema(name='label', dtype=DataType.VARCHAR, description='label', max_length=500)
    ]
    schema = CollectionSchema(fields=fields, description=f'text embeddings for {model_name} on {dataset_name}')
    if utility.has_collection(collection_name):
        print(f'drop old collection: {collection_name}')
        collection = Collection(collection_name)
        collection.drop()
    collection = Collection(name=collection_name, schema=schema)
    print(f'A new collection is created: {collection_name}.')
    return collection


if args.format == 'pytorch':
    collection_name = collection_name + '_pytorch'


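    # Insert pipeline (PyTorch): truncate each text to 1024 characters, embed it with the
    # transformers operator, keep the first row of the token-level output, and write
    # (embedding, label) pairs into the Milvus collection.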
    def insert(model_name, collection_name):
        (
            towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
                  .runas_op['text', 'text'](lambda s: s[:1024])
                  .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
                  .runas_op['emb', 'emb'](lambda x: x[0])
                  .runas_op['label', 'label'](lambda y: str(y))
                  .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
                        uri=f'tcp://{host}:{port}/{collection_name}'
                        )
                  .show(3)
        )
        collection = Collection(collection_name)
        return collection.num_entities


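    # Query pipeline (PyTorch): embed each test text, search the collection for the top-k nearest
    # neighbours, vote on the returned labels, and report accuracy at 1, 5, and 10.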
    def query(model_name, collection_name):
        benchmark = (
            towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
                  .runas_op['text', 'text'](lambda s: s[:1024])
                  .text_embedding.transformers['text', 'emb'](model_name=model_name, device=device)
                  .runas_op['emb', 'emb'](lambda x: x[0])
                  .runas_op['gt', 'gt'](lambda y: str(y))
                  .ann_search.milvus['emb', 'milvus_res'](
                        uri=f'tcp://{host}:{port}/{collection_name}',
                        metric_type=metric_type,
                        limit=topk,
                        output_fields=['label']
                        )
                  .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
                  .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
                  .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
                  .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
                  .with_metrics(['accuracy'])
                  .evaluate['gt', 'pred1']('pred1')
                  .evaluate['gt', 'pred5']('pred5')
                  .evaluate['gt', 'pred10']('pred10')
                  .report()
        )
        return benchmark
elif args.format == 'onnx':
    collection_name = collection_name + '_onnx'
    saved_name = model_name.replace('/', '-')
    onnx_path = f'saved/onnx/{saved_name}.onnx'
    if not os.path.exists(onnx_path):
        op.save_model(format='onnx')
    sess = onnxruntime.InferenceSession(onnx_path,
                                        providers=onnxruntime.get_available_providers())

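    # Wrap ONNX Runtime inference as a towhee operator: tokenize with the original operator's
    # tokenizer, feed only the inputs the ONNX graph expects, and return the last hidden state.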
    @towhee.register
    def run_onnx(txt):
        inputs = op.tokenizer(txt, return_tensors='np')
        onnx_inputs = [x.name for x in sess.get_inputs()]
        new_inputs = {}
        for k in onnx_inputs:
            new_inputs[k] = inputs[k]
        outs = sess.run(output_names=['last_hidden_state'], input_feed=dict(new_inputs))
        return outs[0].squeeze(0)

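    # Insert pipeline (ONNX): same flow as the PyTorch branch, but embeddings come from run_onnx.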
    def insert(model_name, collection_name):
        (
            towhee.dc['text', 'label'](zip(insert_data['text'], insert_data['label'])).stream()
                  .runas_op['text', 'text'](lambda s: s[:1024])
                  .run_onnx['text', 'emb']()
                  .runas_op['emb', 'emb'](lambda x: x[0])
                  .runas_op['label', 'label'](lambda y: str(y))
                  .ann_insert.milvus[('emb', 'label'), 'milvus_insert'](
                        uri=f'tcp://{host}:{port}/{collection_name}'
                        )
                  .show(3)
        )
        collection = Collection(collection_name)
        return collection.num_entities

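    # Query pipeline (ONNX): same evaluation as the PyTorch branch, with run_onnx producing the
    # query embeddings.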
    def query(model_name, collection_name):
        benchmark = (
            towhee.dc['text', 'gt'](zip(query_data['text'], query_data['label'])).stream()
                  .runas_op['text', 'text'](lambda s: s[:1024])
                  .run_onnx['text', 'emb']()
                  .runas_op['emb', 'emb'](lambda x: x[0])
                  .runas_op['gt', 'gt'](lambda y: str(y))
                  .ann_search.milvus['emb', 'milvus_res'](
                           uri=f'tcp://{host}:{port}/{collection_name}',
                           metric_type=metric_type,
                           limit=topk,
                           output_fields=['label']
                           )
                  .runas_op['milvus_res', 'preds'](lambda x: [y.label for y in x]).unstream()
                  .runas_op['preds', 'pred1'](lambda x: mode(x[:1]))
                  .runas_op['preds', 'pred5'](lambda x: mode(x[:5]))
                  .runas_op['preds', 'pred10'](lambda x: mode(x[:10]))
                  .with_metrics(['accuracy'])
                  .evaluate['gt', 'pred1']('pred1')
                  .evaluate['gt', 'pred5']('pred5')
                  .evaluate['gt', 'pred10']('pred10')
                  .report()
        )
        return benchmark
else:
    raise ValueError('Only "pytorch" and "onnx" are supported for --format.')

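# End-to-end run: create (or recreate) the collection, insert the reference embeddings,
# then run the queries and print accuracy at top 1, 5, and 10.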
collection = create_milvus(collection_name)
insert_count = insert(model_name, collection_name)
print('Total data inserted:', insert_count)
benchmark = query(model_name, collection_name)
@@ -0,0 +1,7 @@
#!/bin/bash

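# Evaluate each model with both the PyTorch and the ONNX backend.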
for name in allenai/led-base-16384 flaubert/flaubert_base_cased flaubert/flaubert_base_uncased flaubert/flaubert_large_cased flaubert/flaubert_small_cased funnel-transformer/intermediate-base funnel-transformer/large-base funnel-transformer/medium-base funnel-transformer/small-base funnel-transformer/xlarge-base google/mobilebert-uncased tau/splinter-base tau/splinter-base-qass tau/splinter-large
do
	python evaluate.py --model ${name} --format pytorch
	python evaluate.py --model ${name} --format onnx
done