osschat-milvus
copied
4 changed files with 166 additions and 1 deletions
@ -1,2 +1,59 @@ |
|||||
# osschat-milvus |
|
||||
|
# ANN Search Operator: MilvusClient |
||||
|
|
||||
|
*author: junjie.jiangjjj* |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
## Description |
||||
|
Search for embeddings in [Milvus](https://milvus.io/). **Please make sure you have inserted data into the Milvus collection before searching.** |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
|
||||
|
## Code Example |
||||
|
|
||||
|
> Please make sure you have inserted data into Milvus and [loaded the collection](https://milvus.io/docs/v2.1.x/load_collection.md) into memory. |
||||
|
|
||||
|
|
||||
|
```python |
||||
|
from towhee import pipe, ops, DataCollection |
||||
|
|
||||
|
p = pipe.input('collection_name', 'text') \ |
||||
|
.map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) \ |
||||
|
.flat_map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) \ |
||||
|
.map('rows', ('id', 'score', 'text'), lambda x: (x[0], x[1], x[2])) \ |
||||
|
.output('id', 'score', 'text') |
||||
|
|
||||
|
DataCollection(p('test_collection', 'cat')).show() |
||||
|
|
||||
|
# result: |
||||
|
|
||||
|
``` |
||||
|
|
||||
|
```python |
||||
|
from towhee import pipe, ops |
||||
|
|
||||
|
# search additional info url: |
||||
|
from towhee import pipe, ops, DataCollection |
||||
|
|
||||
|
p = pipe.input('collection_name', 'text') \ |
||||
|
.map('text', 'vec', ops.sentence_embedding.transformers(model_name='all-MiniLM-L12-v2')) \ |
||||
|
.map(('collection_name', 'vec'), 'rows', ops.ann_search.milvus_client(host='127.0.0.1', port='19530', **{'output_fields': ['text']})) \ |
||||
|
.output('rows') |
||||
|
|
||||
|
DataCollection(p('test_collection', 'cat')).show() |
||||
|
``` |
||||
|
|
||||
|
<br /> |
||||
|
|
||||
|
|
||||
|
|
||||
|
## Factory Constructor |
||||
|
|
||||
|
Create the operator via the following factory method: |
||||
|
|
||||
|
***ann_search.milvus_client(host='127.0.0.1', port='19530')*** |
||||
|
|
||||
|
|
||||
|
<br /> |
||||
|
@ -0,0 +1,5 @@ |
|||||
|
from .milvus_client import MilvusClient |
||||
|
|
||||
|
|
||||
|
def osschat_milvus(*args, **kwargs):
    """Factory entry point: build a `MilvusClient` operator.

    All positional and keyword arguments are forwarded unchanged to the
    `MilvusClient` constructor.
    """
    operator = MilvusClient(*args, **kwargs)
    return operator
@ -0,0 +1,102 @@ |
|||||
|
from pymilvus import connections, Collection |
||||
|
from towhee.operator import PyOperator, SharedType |
||||
|
import uuid |
||||
|
|
||||
|
|
||||
|
class MilvusClient(PyOperator):
    """
    Search for embedding vectors in a Milvus collection.

    The collection must already contain data before it is searched. The
    collection name is supplied on every call, so one operator instance can
    query several collections through the same connection.

    Args:
        host (`str`):
            The Milvus host, defaults to 'localhost'.
        port (`int`):
            The Milvus port, defaults to 19530.
        user (`str`):
            The user name; credentials are only used when both `user` and
            `password` are given, in which case the connection is opened with
            `secure=True`.
        password (`str`):
            The password that goes with `user`.
        kwargs
            Extra keyword arguments forwarded to `collection.search`, refer to
            https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters.
            `anns_field` defaults to the collection's vector field, `limit`
            defaults to 10, and `metric_type` inside `param` defaults to 'L2'
            unless `param` already carries one or a `metric_type` keyword is
            given. When `param` is omitted it is chosen from the type of the
            collection's first index (IVF_FLAT settings if there is no index):
                FLAT: {"params": {"nprobe": 10}},
                IVF_FLAT: {"params": {"nprobe": 10}},
                IVF_SQ8: {"params": {"nprobe": 10}},
                IVF_PQ: {"params": {"nprobe": 10}},
                HNSW: {"params": {"ef": 10}},
                IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}},
                RHNSW_FLAT: {"params": {"ef": 10}},
                RHNSW_SQ: {"params": {"ef": 10}},
                RHNSW_PQ: {"params": {"ef": 10}},
                ANNOY: {"params": {"search_k": 10}},
                AUTOINDEX and unknown index types: no extra params.
    """

    # Schema `dtype` codes treated as vector fields.
    # NOTE(review): presumably pymilvus DataType.FLOAT_VECTOR (101) and
    # DataType.BINARY_VECTOR (100) - confirm against the installed pymilvus.
    _VECTOR_DTYPES = (101, 100)

    # Default search `param` per index type, used when the caller gives none.
    _DEFAULT_INDEX_PARAMS = {
        'FLAT': {'params': {'nprobe': 10}},
        'IVF_FLAT': {'params': {'nprobe': 10}},
        'IVF_SQ8': {'params': {'nprobe': 10}},
        'IVF_PQ': {'params': {'nprobe': 10}},
        'HNSW': {'params': {'ef': 10}},
        'RHNSW_FLAT': {'params': {'ef': 10}},
        'RHNSW_SQ': {'params': {'ef': 10}},
        'RHNSW_PQ': {'params': {'ef': 10}},
        'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}},
        'ANNOY': {'params': {'search_k': 10}},
        'AUTOINDEX': {},
    }

    def __init__(self, host: str = 'localhost', port: int = 19530,
                 user: str = None, password: str = None, **kwargs):
        """
        Open a dedicated Milvus connection and remember the search options.
        """
        self._host = host
        self._port = port
        self.kwargs = kwargs
        # A random alias gives every operator instance its own connection.
        self._connect_name = uuid.uuid4().hex
        if user is None or password is None:
            connections.connect(alias=self._connect_name, host=self._host, port=self._port)
        else:
            connections.connect(alias=self._connect_name, host=self._host, port=self._port,
                                user=user, password=password, secure=True)

    def __call__(self, collection_name: str, query: 'ndarray'):
        """
        Search `collection_name` for the vectors nearest to `query`.

        Returns:
            A list with one row per hit: `[id, score, *output_field_values]`,
            where the trailing values follow the order of `output_fields`.
        """
        self._collection = Collection(collection_name, using=self._connect_name)
        search_kwargs = self._build_search_kwargs(self._collection)

        self._collection.load()
        milvus_result = self._collection.search(data=[query], **search_kwargs)

        output_fields = search_kwargs.get('output_fields') or ()
        # Only one query vector is sent, so only the first result set is read.
        return [
            [hit.id, hit.score, *(hit.entity.get(name) for name in output_fields)]
            for hit in milvus_result[0]
        ]

    def _build_search_kwargs(self, collection):
        """Return per-call search kwargs with defaults filled in for `collection`."""
        # Work on a copy: defaults derived from one collection must not leak
        # into later calls that may target a different collection.
        search_kwargs = dict(self.kwargs)

        if 'anns_field' not in search_kwargs:
            # No early exit: when several vector fields exist the last one wins.
            for field in collection.schema.fields:
                if field.dtype in self._VECTOR_DTYPES:
                    search_kwargs['anns_field'] = field.name

        search_kwargs.setdefault('limit', 10)

        if 'param' in search_kwargs:
            # Copy so the caller's dict is never modified.
            param = dict(search_kwargs['param'])
        else:
            indexes = collection.indexes
            index_type = indexes[0].params['index_type'] if indexes else 'IVF_FLAT'
            # Unknown index types fall back to "no extra params" instead of KeyError.
            defaults = self._DEFAULT_INDEX_PARAMS.get(index_type, {})
            param = {key: dict(value) for key, value in defaults.items()}

        if 'metric_type' in search_kwargs:
            param['metric_type'] = search_kwargs['metric_type']
        else:
            # Keep a metric the caller already placed inside `param`.
            param.setdefault('metric_type', 'L2')
        search_kwargs['param'] = param
        return search_kwargs

    @property
    def shared_type(self):
        """Each pipeline node gets its own instance (and its own connection)."""
        return SharedType.NotShareable
||||
|
|
||||
|
# def __del__(self): |
||||
|
# if connections.has_connection(self._connect_name): |
||||
|
# try: |
||||
|
# connections.disconnect(self._connect_name) |
||||
|
# except: |
||||
|
# pass |
@ -0,0 +1 @@ |
|||||
|
pymilvus |
Loading…
Reference in new issue