milvus-client
              
                 
                
            
          copied
				 3 changed files with 96 additions and 0 deletions
			
			
		| @ -0,0 +1,5 @@ | |||||
|  | from .milvus_client import MilvusClientls | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | def milvus_client(*args, **kwargs): | ||||
|  |     return MilvusClient(*args, **kwargs) | ||||
| @ -0,0 +1,90 @@ | |||||
|  | from pymilvus import connections, Collection | ||||
|  | from towhee.operator import PyOperator | ||||
|  | import uuid | ||||
|  | 
 | ||||
|  | 
 | ||||
|  | class Milvus(PyOperator): | ||||
|  |     """ | ||||
|  |     Search for embedding vectors in Milvus. Note that the Milvus collection has data before searching, | ||||
|  | 
 | ||||
|  |     Args: | ||||
|  |         collection (`str`): | ||||
|  |             The collection name. | ||||
|  |         kwargs | ||||
|  |             The kwargs with collection.search, refer to https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters. | ||||
|  |             And the `anns_field` defaults to the vector field name, `limit` defaults to 10, and `metric_type` in `param` defaults to 'L2' | ||||
|  |             if there has no index(FLAT), and for default index `param`: | ||||
|  |                 IVF_FLAT: {"params": {"nprobe": 10}}, | ||||
|  |                 IVF_SQ8: {"params": {"nprobe": 10}}, | ||||
|  |                 IVF_PQ: {"params": {"nprobe": 10}}, | ||||
|  |                 HNSW: {"params": {"ef": 10}}, | ||||
|  |                 IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}}, | ||||
|  |                 RHNSW_FLAT: {"params": {"ef": 10}}, | ||||
|  |                 RHNSW_SQ: {"params": {"ef": 10}}, | ||||
|  |                 RHNSW_PQ: {"params": {"ef": 10}}, | ||||
|  |                 ANNOY: {"params": {"search_k": 10}}. | ||||
|  |     """ | ||||
|  | 
 | ||||
|  |     def __init__(self, host: str = 'localhost', port: int = 19530, collection_name: str = None, **kwargs): | ||||
|  |         """ | ||||
|  |         Get an existing collection. | ||||
|  |         """ | ||||
|  |         self._host = host | ||||
|  |         self._port = port | ||||
|  |         self._collection_name = collection_name | ||||
|  |         self._connect_name = uuid.uuid4().hex | ||||
|  |         connections.connect(alias=self._connect_name, host=self._host, port=self._port) | ||||
|  |         self._collection = Collection(self._collection_name, using=self._connect_name) | ||||
|  | 
 | ||||
|  |         self.kwargs = kwargs | ||||
|  |         if 'anns_field' not in self.kwargs: | ||||
|  |             fields_schema = self._collection.schema.fields | ||||
|  |             for schema in fields_schema: | ||||
|  |                 if schema.dtype in (101, 100): | ||||
|  |                     self.kwargs['anns_field'] = schema.name | ||||
|  | 
 | ||||
|  |         if 'limit' not in self.kwargs: | ||||
|  |             self.kwargs['limit'] = 10 | ||||
|  | 
 | ||||
|  |         index_params = { | ||||
|  |             'IVF_FLAT': {'params': {'nprobe': 10}}, | ||||
|  |             'IVF_SQ8': {'params': {'nprobe': 10}}, | ||||
|  |             'IVF_PQ': {'params': {'nprobe': 10}}, | ||||
|  |             'HNSW': {'params': {'ef': 10}}, | ||||
|  |             'RHNSW_FLAT': {'params': {'ef': 10}}, | ||||
|  |             'RHNSW_SQ': {'params': {'ef': 10}}, | ||||
|  |             'RHNSW_PQ': {'params': {'ef': 10}}, | ||||
|  |             'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}}, | ||||
|  |             'ANNOY': {'params': {'search_k': 10}} | ||||
|  |         } | ||||
|  | 
 | ||||
|  |         if 'param' not in self.kwargs: | ||||
|  |             if len(self._collection.indexes) != 0: | ||||
|  |                 index_type = self._collection.indexes[0].params['index_type'] | ||||
|  |                 self.kwargs['param'] = index_params[index_type] | ||||
|  |             else: | ||||
|  |                 self.kwargs['param'] = index_params['IVF_FLAT'] | ||||
|  |             if 'metric_type' in self.kwargs: | ||||
|  |                 self.kwargs['param']['metric_type'] = self.kwargs['metric_type'] | ||||
|  |             else: | ||||
|  |                 self.kwargs['param']['metric_type'] = 'L2' | ||||
|  | 
 | ||||
|  |     def __call__(self, query: list): | ||||
|  |         milvus_result = self._collection.search( | ||||
|  |             data=[query], | ||||
|  |             **self.kwargs | ||||
|  |         ) | ||||
|  | 
 | ||||
|  |         result = [] | ||||
|  |         for re in milvus_result: | ||||
|  |             row = [] | ||||
|  |             for hit in re: | ||||
|  |                 row.extend([hit.id, hit.score]) | ||||
|  |                 if 'output_fields' in self.kwargs: | ||||
|  |                     for k in self.kwargs['output_fields']: | ||||
|  |                         row.append(hit.entity._row_data[k]) | ||||
|  |                 result.append(row) | ||||
|  |         return result | ||||
|  | 
 | ||||
|  |     def __del__(self): | ||||
|  |         connections.disconnect(self._connect_name) | ||||
| @ -0,0 +1 @@ | |||||
|  | pymilvus | ||||
					Loading…
					
					
				
		Reference in new issue
	
	