osschat-milvus
              
                
                
            
          copied
				 4 changed files with 166 additions and 1 deletions
			
			
		@ -1,2 +1,59 @@ | 
				
			|||
# osschat-milvus | 
				
			|||
# ANN Search Operator: MilvusClient | 
				
			|||
 | 
				
			|||
*author: junjie.jiangjjj* | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
## Desription | 
				
			|||
Search embedding in [Milvus](https://milvus.io/), **please make sure you have inserted data to Milvus Collection**. | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
## Code Example | 
				
			|||
 | 
				
			|||
> Please make sure you have inserted data into Milvus and [load the collection](https://milvus.io/docs/v2.1.x/load_collection.md) to memory. | 
				
			|||
 | 
				
			|||
 | 
				
			|||
```python | 
				
			|||
from towhee import pipe, ops, DataCollection | 
				
			|||
 | 
				
			|||
p = pipe.input('collection_name', 'text')  \ | 
				
			|||
    .map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2'))  \ | 
				
			|||
    .flat_map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']}))  \ | 
				
			|||
    .map('rows', ('id', 'score', 'text'), lambda x: (x[0], x[1], x[2]))  \ | 
				
			|||
    .output('id', 'score', 'text') | 
				
			|||
 | 
				
			|||
DataCollection(p('test_collection', 'cat')).show() | 
				
			|||
 | 
				
			|||
# result: | 
				
			|||
 | 
				
			|||
``` | 
				
			|||
 | 
				
			|||
```python | 
				
			|||
from towhee import pipe, ops | 
				
			|||
 | 
				
			|||
# search additional info url: | 
				
			|||
from towhee import pipe, ops, DataCollection | 
				
			|||
 | 
				
			|||
p = pipe.input('collection_name', 'text')  \ | 
				
			|||
    .map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2'))  \ | 
				
			|||
    .map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']}))  \ | 
				
			|||
    .output('rows') | 
				
			|||
 | 
				
			|||
DataCollection(p('test_collection', 'cat')).show() | 
				
			|||
``` | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
 | 
				
			|||
 | 
				
			|||
## Factory Constructor | 
				
			|||
 | 
				
			|||
Create the operator via the following factory method: | 
				
			|||
 | 
				
			|||
***ann_search.milvus_client(host='127.0.0.1', port='19530')*** | 
				
			|||
 | 
				
			|||
 | 
				
			|||
<br /> | 
				
			|||
 | 
				
			|||
@ -0,0 +1,5 @@ | 
				
			|||
from .milvus_client import MilvusClient | 
				
			|||
 | 
				
			|||
 | 
				
			|||
def osschat_milvus(*args, **kwargs): | 
				
			|||
    return MilvusClient(*args, **kwargs) | 
				
			|||
@ -0,0 +1,102 @@ | 
				
			|||
from pymilvus import connections, Collection | 
				
			|||
from towhee.operator import PyOperator, SharedType | 
				
			|||
import uuid | 
				
			|||
 | 
				
			|||
 | 
				
			|||
class MilvusClient(PyOperator): | 
				
			|||
    """ | 
				
			|||
    Search for embedding vectors in Milvus. Note that the Milvus collection has data before searching, | 
				
			|||
 | 
				
			|||
    Args: | 
				
			|||
        collection (`str`): | 
				
			|||
            The collection name. | 
				
			|||
        kwargs | 
				
			|||
            The kwargs with collection.search, refer to https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters. | 
				
			|||
            And the `anns_field` defaults to the vector field name, `limit` defaults to 10, and `metric_type` in `param` defaults to 'L2' | 
				
			|||
            if there has no index(FLAT), and for default index `param`: | 
				
			|||
                IVF_FLAT: {"params": {"nprobe": 10}}, | 
				
			|||
                IVF_SQ8: {"params": {"nprobe": 10}}, | 
				
			|||
                IVF_PQ: {"params": {"nprobe": 10}}, | 
				
			|||
                HNSW: {"params": {"ef": 10}}, | 
				
			|||
                IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}}, | 
				
			|||
                RHNSW_FLAT: {"params": {"ef": 10}}, | 
				
			|||
                RHNSW_SQ: {"params": {"ef": 10}}, | 
				
			|||
                RHNSW_PQ: {"params": {"ef": 10}}, | 
				
			|||
                ANNOY: {"params": {"search_k": 10}}. | 
				
			|||
    """ | 
				
			|||
 | 
				
			|||
    def __init__(self, host: str = 'localhost', port: int = 19530, | 
				
			|||
                 user: str = None, password: str = None, **kwargs): | 
				
			|||
        """ | 
				
			|||
        Get an existing collection. | 
				
			|||
        """ | 
				
			|||
        self._host = host | 
				
			|||
        self._port = port | 
				
			|||
        self.kwargs = kwargs | 
				
			|||
        self._connect_name = uuid.uuid4().hex | 
				
			|||
        if None in [user, password]: | 
				
			|||
            connections.connect(alias=self._connect_name, host=self._host, port=self._port) | 
				
			|||
        else: | 
				
			|||
            connections.connect(alias=self._connect_name, host=self._host, port=self._port, | 
				
			|||
                                user=user, password=password, secure=True) | 
				
			|||
         | 
				
			|||
 | 
				
			|||
    def __call__(self, collection_name: str, query: 'ndarray'): | 
				
			|||
        self._collection = Collection(collection_name, using=self._connect_name) | 
				
			|||
        if 'anns_field' not in self.kwargs: | 
				
			|||
            fields_schema = self._collection.schema.fields | 
				
			|||
            for schema in fields_schema: | 
				
			|||
                if schema.dtype in (101, 100): | 
				
			|||
                    self.kwargs['anns_field'] = schema.name | 
				
			|||
        if 'limit' not in self.kwargs: | 
				
			|||
            self.kwargs['limit'] = 10 | 
				
			|||
        index_params = { | 
				
			|||
            'FLAT': {'params': {'nprobe': 10}}, | 
				
			|||
            'IVF_FLAT': {'params': {'nprobe': 10}}, | 
				
			|||
            'IVF_SQ8': {'params': {'nprobe': 10}}, | 
				
			|||
            'IVF_PQ': {'params': {'nprobe': 10}}, | 
				
			|||
            'HNSW': {'params': {'ef': 10}}, | 
				
			|||
            'RHNSW_FLAT': {'params': {'ef': 10}}, | 
				
			|||
            'RHNSW_SQ': {'params': {'ef': 10}}, | 
				
			|||
            'RHNSW_PQ': {'params': {'ef': 10}}, | 
				
			|||
            'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}}, | 
				
			|||
            'ANNOY': {'params': {'search_k': 10}}, | 
				
			|||
            'AUTOINDEX': {} | 
				
			|||
        } | 
				
			|||
        if 'param' not in self.kwargs: | 
				
			|||
            if len(self._collection.indexes) != 0: | 
				
			|||
                index_type = self._collection.indexes[0].params['index_type'] | 
				
			|||
                self.kwargs['param'] = index_params[index_type] | 
				
			|||
            else: | 
				
			|||
                self.kwargs['param'] = index_params['IVF_FLAT'] | 
				
			|||
            if 'metric_type' in self.kwargs: | 
				
			|||
                self.kwargs['param']['metric_type'] = self.kwargs['metric_type'] | 
				
			|||
            else: | 
				
			|||
                self.kwargs['param']['metric_type'] = 'L2' | 
				
			|||
         | 
				
			|||
        self._collection.load() | 
				
			|||
        milvus_result = self._collection.search( | 
				
			|||
            data=[query], | 
				
			|||
            **self.kwargs | 
				
			|||
        ) | 
				
			|||
 | 
				
			|||
        result = [] | 
				
			|||
        for hit in milvus_result[0]: | 
				
			|||
            row = [] | 
				
			|||
            row.extend([hit.id, hit.score]) | 
				
			|||
            if 'output_fields' in self.kwargs: | 
				
			|||
                for k in self.kwargs['output_fields']: | 
				
			|||
                 row.append(hit.entity.get(k)) | 
				
			|||
            result.append(row) | 
				
			|||
        return result | 
				
			|||
 | 
				
			|||
    @property | 
				
			|||
    def shared_type(self): | 
				
			|||
        return SharedType.NotShareable | 
				
			|||
 | 
				
			|||
    # def __del__(self): | 
				
			|||
    #     if connections.has_connection(self._connect_name): | 
				
			|||
    #         try: | 
				
			|||
    #             connections.disconnect(self._connect_name) | 
				
			|||
    #         except: | 
				
			|||
    #             pass | 
				
			|||
@ -0,0 +1 @@ | 
				
			|||
pymilvus | 
				
			|||
					Loading…
					
					
				
		Reference in new issue