diff --git a/README.md b/README.md index bbc2d8f..e185cbf 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,59 @@ -# osschat-milvus +# ANN Search Operator: MilvusClient +*author: junjie.jiangjjj* + +
+ +## Description +Search for embeddings in [Milvus](https://milvus.io/). **Please make sure you have inserted data into the Milvus collection before searching.** + +
+ + + +## Code Example + +> Please make sure you have inserted data into Milvus and [load the collection](https://milvus.io/docs/v2.1.x/load_collection.md) to memory. + + +```python +from towhee import pipe, ops, DataCollection + +p = pipe.input('collection_name', 'text') \ + .map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) \ + .flat_map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) \ + .map('rows', ('id', 'score', 'text'), lambda x: (x[0], x[1], x[2])) \ + .output('id', 'score', 'text') + +DataCollection(p('test_collection', 'cat')).show() + +# result: + +``` + +```python +from towhee import pipe, ops + +# search additional info url: +from towhee import pipe, ops, DataCollection + +p = pipe.input('collection_name', 'text') \ + .map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) \ + .map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) \ + .output('rows') + +DataCollection(p('test_collection', 'cat')).show() +``` + +
+ + + +## Factory Constructor + +Create the operator via the following factory method: + +***ann_search.milvus_client(host='127.0.0.1', port='19530')*** + + +
# ---------------------------------------------------------------------------
# diff --git a/__init__.py b/__init__.py  (new file)
#
# In the package layout this module obtains the operator class with
#     from .milvus_client import MilvusClient
# and exposes the factory function below.
# ---------------------------------------------------------------------------


def osschat_milvus(*args, **kwargs):
    """Build a ``MilvusClient`` operator, forwarding every argument unchanged."""
    return MilvusClient(*args, **kwargs)


# ---------------------------------------------------------------------------
# diff --git a/milvus_client.py b/milvus_client.py  (new file)
# ---------------------------------------------------------------------------
from pymilvus import connections, Collection
from towhee.operator import PyOperator, SharedType
import uuid


class MilvusClient(PyOperator):
    """
    Search for embedding vectors in Milvus. The collection being searched must
    already contain data.

    Args:
        host (`str`):
            Host of the Milvus server, defaults to 'localhost'.
        port (`int`):
            Port of the Milvus server, defaults to 19530.
        user (`str`):
            User name. Credentials are only sent when both `user` and
            `password` are given; the connection is then opened with
            `secure=True`.
        password (`str`):
            Password that belongs to `user`.
        kwargs
            Forwarded to `Collection.search`, refer to
            https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters.
            Missing entries are filled in per searched collection:
            `anns_field` defaults to the (last) vector field of the schema,
            `limit` defaults to 10, and `param` defaults to the entry below
            that matches the type of the collection's first index (the
            IVF_FLAT entry is used when the collection has no index), with
            `metric_type` taken from kwargs or 'L2':
                FLAT: {"params": {"nprobe": 10}},
                IVF_FLAT: {"params": {"nprobe": 10}},
                IVF_SQ8: {"params": {"nprobe": 10}},
                IVF_PQ: {"params": {"nprobe": 10}},
                HNSW: {"params": {"ef": 10}},
                IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}},
                RHNSW_FLAT: {"params": {"ef": 10}},
                RHNSW_SQ: {"params": {"ef": 10}},
                RHNSW_PQ: {"params": {"ef": 10}},
                ANNOY: {"params": {"search_k": 10}},
                AUTOINDEX and any index type not listed: {}.
    """

    # Field dtypes treated as vector fields. NOTE(review): these numeric values
    # presumably correspond to pymilvus DataType.FLOAT_VECTOR (101) and
    # DataType.BINARY_VECTOR (100) -- confirm against the installed pymilvus.
    _VECTOR_DTYPES = (101, 100)
    _DEFAULT_LIMIT = 10
    _DEFAULT_METRIC_TYPE = 'L2'
    # Entry of the default table used when the collection has no index.
    _NO_INDEX_FALLBACK = 'IVF_FLAT'

    def __init__(self, host: str = 'localhost', port: int = 19530,
                 user: str = None, password: str = None, **kwargs):
        """
        Open a dedicated Milvus connection for this operator instance.
        """
        self._host = host
        self._port = port
        self.kwargs = kwargs
        # A random alias gives every operator instance its own connection.
        self._connect_name = uuid.uuid4().hex
        # Derived search defaults, keyed by collection name, so that defaults
        # computed for one collection are never applied to another one.
        self._defaults_by_collection = {}

        connect_args = {'alias': self._connect_name, 'host': self._host, 'port': self._port}
        if user is not None and password is not None:
            connect_args.update(user=user, password=password, secure=True)
        connections.connect(**connect_args)

    def __call__(self, collection_name: str, query: 'ndarray'):
        """
        Search `collection_name` for the vectors closest to `query`.

        Returns:
            A list with one row per hit of the first (and only) query:
            `[id, score, *output_field_values]`, where the trailing values
            follow the order of `output_fields` in kwargs (if given).
        """
        self._collection = Collection(collection_name, using=self._connect_name)
        search_kwargs = self._search_kwargs_for(collection_name, self._collection)

        self._collection.load()
        milvus_result = self._collection.search(data=[query], **search_kwargs)

        output_fields = search_kwargs.get('output_fields') or ()
        return [
            [hit.id, hit.score, *(hit.entity.get(name) for name in output_fields)]
            for hit in milvus_result[0]
        ]

    def _search_kwargs_for(self, collection_name, collection):
        """Return the user kwargs completed with defaults for this collection."""
        defaults = self._defaults_by_collection.get(collection_name)
        if defaults is None:
            defaults = self._derive_defaults(collection)
            self._defaults_by_collection[collection_name] = defaults
        # Merge into a fresh dict: `self.kwargs` stays exactly what the caller
        # passed, and explicit user settings always win over derived defaults.
        return {**defaults, **self.kwargs}

    def _derive_defaults(self, collection):
        """Compute the search arguments the user did not supply."""
        defaults = {}
        if 'anns_field' not in self.kwargs:
            for field in collection.schema.fields:
                if field.dtype in self._VECTOR_DTYPES:
                    # No early exit: with several vector fields the last one wins.
                    defaults['anns_field'] = field.name
        if 'limit' not in self.kwargs:
            defaults['limit'] = self._DEFAULT_LIMIT
        if 'param' not in self.kwargs:
            indexes = collection.indexes
            if indexes:
                index_type = indexes[0].params['index_type']
            else:
                index_type = self._NO_INDEX_FALLBACK
            # Unknown index types fall back to empty params instead of raising
            # KeyError, the same treatment AUTOINDEX gets in the table.
            param = self._default_index_params().get(index_type, {})
            param['metric_type'] = self.kwargs.get('metric_type', self._DEFAULT_METRIC_TYPE)
            defaults['param'] = param
        return defaults

    @staticmethod
    def _default_index_params():
        """Return a fresh table of default search params per index type."""
        return {
            'FLAT': {'params': {'nprobe': 10}},
            'IVF_FLAT': {'params': {'nprobe': 10}},
            'IVF_SQ8': {'params': {'nprobe': 10}},
            'IVF_PQ': {'params': {'nprobe': 10}},
            'HNSW': {'params': {'ef': 10}},
            'RHNSW_FLAT': {'params': {'ef': 10}},
            'RHNSW_SQ': {'params': {'ef': 10}},
            'RHNSW_PQ': {'params': {'ef': 10}},
            'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}},
            'ANNOY': {'params': {'search_k': 10}},
            'AUTOINDEX': {},
        }

    @property
    def shared_type(self):
        return SharedType.NotShareable

    # NOTE(review): disconnecting on garbage collection is disabled in the
    # original change; the connection opened in __init__ is never closed here.
    # def __del__(self):
    #     if connections.has_connection(self._connect_name):
    #         try:
    #             connections.disconnect(self._connect_name)
    #         except:
    #             pass


# ---------------------------------------------------------------------------
# diff --git a/requirements.txt b/requirements.txt  (new file)
#     pymilvus
# ---------------------------------------------------------------------------