logo
Browse Source

Add osschat-milvus

main
shiyu22 11 months ago
parent
commit
2a57822207
  1. 59
      README.md
  2. 5
      __init__.py
  3. 102
      milvus_client.py
  4. 1
      requirements.txt

59
README.md

@ -1,2 +1,59 @@
# osschat-milvus
# ANN Search Operator: MilvusClient
*author: junjie.jiangjjj*
<br />
## Description
Search for embeddings in [Milvus](https://milvus.io/). **Please make sure you have inserted data into the Milvus collection before searching.**
<br />
## Code Example
> Please make sure you have inserted data into Milvus and [loaded the collection](https://milvus.io/docs/v2.1.x/load_collection.md) into memory.
```python
from towhee import pipe, ops, DataCollection
p = pipe.input('collection_name', 'text') \
.map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) \
.flat_map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) \
.map('rows', ('id', 'score', 'text'), lambda x: (x[0], x[1], x[2])) \
.output('id', 'score', 'text')
DataCollection(p('test_collection', 'cat')).show()
# result:
```
```python
from towhee import pipe, ops
# search additional info url:
from towhee import pipe, ops, DataCollection
p = pipe.input('collection_name', 'text') \
.map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) \
.map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) \
.output('rows')
DataCollection(p('test_collection', 'cat')).show()
```
<br />
## Factory Constructor
Create the operator via the following factory method:
***ann_search.milvus_client(host='127.0.0.1', port='19530')***
<br />

5
__init__.py

@ -0,0 +1,5 @@
from .milvus_client import MilvusClient
def osschat_milvus(*args, **kwargs):
    """Factory entry point: build a ``MilvusClient`` operator.

    All positional and keyword arguments are forwarded unchanged to the
    ``MilvusClient`` constructor, and the new instance is returned.
    """
    client = MilvusClient(*args, **kwargs)
    return client

102
milvus_client.py

@ -0,0 +1,102 @@
from pymilvus import connections, Collection
from towhee.operator import PyOperator, SharedType
import uuid
class MilvusClient(PyOperator):
    """
    Search for embedding vectors in a Milvus collection.

    The collection must already contain data before it is searched. The
    collection name is supplied per call, so one operator instance can query
    several collections over the same connection.

    Args:
        host (`str`):
            Milvus server host, defaults to 'localhost'.
        port (`int`):
            Milvus server port, defaults to 19530.
        user (`str`):
            User name. Credentials are only used when both `user` and
            `password` are given; the connection is then opened with
            `secure=True`.
        password (`str`):
            Password, see `user`.
        kwargs
            Extra keyword arguments forwarded to `collection.search`, refer to
            https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters.
            When not supplied, `anns_field` defaults to the collection's vector
            field name and `limit` defaults to 10. When `param` is not
            supplied it is derived from the type of the collection's first
            index (the IVF_FLAT default is used if the collection has no
            index), and its `metric_type` is taken from the `metric_type`
            kwarg, defaulting to 'L2'. Default `param` per index type:
                FLAT: {"params": {"nprobe": 10}},
                IVF_FLAT: {"params": {"nprobe": 10}},
                IVF_SQ8: {"params": {"nprobe": 10}},
                IVF_PQ: {"params": {"nprobe": 10}},
                HNSW: {"params": {"ef": 10}},
                IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}},
                RHNSW_FLAT: {"params": {"ef": 10}},
                RHNSW_SQ: {"params": {"ef": 10}},
                RHNSW_PQ: {"params": {"ef": 10}},
                ANNOY: {"params": {"search_k": 10}},
                AUTOINDEX: {}.
            Any other index type gets an empty default `param`.
    """

    def __init__(self, host: str = 'localhost', port: int = 19530,
                 user: str = None, password: str = None, **kwargs):
        """
        Open a dedicated Milvus connection and remember the search kwargs.
        """
        self._host = host
        self._port = port
        self.kwargs = kwargs
        # A random alias gives every operator instance its own connection.
        self._connect_name = uuid.uuid4().hex
        if user is None or password is None:
            connections.connect(alias=self._connect_name, host=self._host, port=self._port)
        else:
            connections.connect(alias=self._connect_name, host=self._host, port=self._port,
                                user=user, password=password, secure=True)

    def __call__(self, collection_name: str, query: 'ndarray'):
        """
        Search `collection_name` for the vectors nearest to `query`.

        Args:
            collection_name (`str`):
                Name of the collection to search.
            query (`ndarray`):
                A single query vector; it is sent to Milvus as a one-element batch.

        Returns:
            A list with one row per hit: `[id, score, *output_field_values]`,
            where the field values follow the order of the `output_fields` kwarg.
        """
        self._collection = Collection(collection_name, using=self._connect_name)

        # Resolve defaults on a per-call copy. Writing them back into
        # self.kwargs would pin the first collection's vector field and index
        # parameters onto every later call, even for a different collection.
        search_kwargs = dict(self.kwargs)

        if 'anns_field' not in search_kwargs:
            for field in self._collection.schema.fields:
                # 100 / 101 are presumably the pymilvus DataType codes for
                # BINARY_VECTOR / FLOAT_VECTOR -- confirm against pymilvus.
                if field.dtype in (101, 100):
                    search_kwargs['anns_field'] = field.name

        search_kwargs.setdefault('limit', 10)

        if 'param' not in search_kwargs:
            search_kwargs['param'] = self._default_search_param(
                search_kwargs.get('metric_type', 'L2'))

        self._collection.load()
        milvus_result = self._collection.search(
            data=[query],
            **search_kwargs
        )

        output_fields = search_kwargs.get('output_fields') or ()
        # Only one query vector was sent, so only the first result set is read.
        return [
            [hit.id, hit.score, *(hit.entity.get(name) for name in output_fields)]
            for hit in milvus_result[0]
        ]

    def _default_search_param(self, metric_type: str) -> dict:
        """Build a fresh default `param` dict from the current collection's first index."""
        index_params = {
            'FLAT': {'params': {'nprobe': 10}},
            'IVF_FLAT': {'params': {'nprobe': 10}},
            'IVF_SQ8': {'params': {'nprobe': 10}},
            'IVF_PQ': {'params': {'nprobe': 10}},
            'HNSW': {'params': {'ef': 10}},
            'RHNSW_FLAT': {'params': {'ef': 10}},
            'RHNSW_SQ': {'params': {'ef': 10}},
            'RHNSW_PQ': {'params': {'ef': 10}},
            'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}},
            'ANNOY': {'params': {'search_k': 10}},
            'AUTOINDEX': {}
        }
        indexes = self._collection.indexes
        if indexes:
            index_type = indexes[0].params['index_type']
            # An index type missing from the table gets an empty param (like
            # AUTOINDEX) instead of failing the search with a bare KeyError.
            param = index_params.get(index_type, {})
        else:
            param = index_params['IVF_FLAT']
        param['metric_type'] = metric_type
        return param

    @property
    def shared_type(self):
        """Declare the operator as not shareable."""
        return SharedType.NotShareable
# def __del__(self):
# if connections.has_connection(self._connect_name):
# try:
# connections.disconnect(self._connect_name)
# except:
# pass

1
requirements.txt

@ -0,0 +1 @@
pymilvus
Loading…
Cancel
Save