logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

122 lines
4.6 KiB

from pymilvus import connections, Collection
from towhee.operator import PyOperator, SharedType
import uuid
from typing import Union
class MilvusClient(PyOperator):
    """
    Search for embedding vectors in Milvus.

    Note that the Milvus collection must already contain data before searching.

    Args:
        uri (`str`):
            Address of the Milvus server, defaults to 'http://localhost:19530'.
        host (`str`):
            Server host. Used together with `port` when `uri` is None or left
            at its default value.
        port (`int` or `str`):
            Server port. Used together with `host`.
        token (`str`):
            Access token forwarded to `connections.connect` when given.
        user (`str`):
            User name forwarded to `connections.connect` when given.
        password (`str`):
            Password forwarded to `connections.connect` when given.
        kwargs:
            The kwargs passed to collection.search, refer to
            https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters.
            `anns_field` defaults to the vector field name, `limit` defaults
            to 10, and `metric_type` in `param` defaults to 'L2'. When no
            `param` is given it is chosen from the collection's first index
            (the FLAT entry is used if there is no index):
                FLAT: {"params": {"nprobe": 10}},
                IVF_FLAT: {"params": {"nprobe": 10}},
                IVF_SQ8: {"params": {"nprobe": 10}},
                IVF_PQ: {"params": {"nprobe": 10}},
                HNSW: {"params": {"ef": 10}},
                IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}},
                RHNSW_FLAT: {"params": {"ef": 10}},
                RHNSW_SQ: {"params": {"ef": 10}},
                RHNSW_PQ: {"params": {"ef": 10}},
                ANNOY: {"params": {"search_k": 10}},
                AUTOINDEX and any other index type: {}.
    """

    # Mirrors the default of the `uri` parameter; lets __init__ tell an
    # untouched default apart from an address the caller chose.
    _DEFAULT_URI = 'http://localhost:19530'

    # Numeric values of pymilvus DataType.BINARY_VECTOR / FLOAT_VECTOR.
    _VECTOR_DTYPES = (100, 101)

    def __init__(self,
                 uri: Union[str, None] = 'http://localhost:19530',
                 host: Union[str, None] = None,
                 port: Union[int, str, None] = None,
                 token: Union[str, None] = None,
                 user: Union[str, None] = None,
                 password: Union[str, None] = None,
                 **kwargs):
        """
        Open a dedicated Milvus connection and remember the search kwargs.

        Raises:
            ConnectionError: if `uri` is None and `host`/`port` are not both given.
        """
        self._connect_name = uuid.uuid4().hex
        self._connection_args = {'alias': self._connect_name}
        self.kwargs = kwargs
        # Fully resolved search kwargs, cached per collection name.
        self._resolved_kwargs = {}

        endpoint_given = host is not None and port is not None
        if endpoint_given and uri in (None, self._DEFAULT_URI):
            # An explicit host/port pair wins over the untouched default URI;
            # otherwise these two arguments could never take effect.
            self._connection_args['host'] = host
            self._connection_args['port'] = port
        elif uri is not None:
            self._connection_args['uri'] = uri
        else:
            raise ConnectionError('Received invalid connection arguments.')

        credentials = {'user': user, 'password': password, 'token': token}
        # `secure` keeps its previous meaning: True only when all three
        # credentials are supplied.
        self._connection_args['secure'] = all(
            value is not None for value in credentials.values()
        )
        # Forward every credential that was actually given, so a token-only or
        # user/password-only login is no longer silently dropped.
        self._connection_args.update(
            {key: value for key, value in credentials.items() if value is not None}
        )
        connections.connect(**self._connection_args)

    def __call__(self, collection_name: str, query: 'ndarray'):
        """
        Search `collection_name` for the vectors closest to `query`.

        Returns:
            A list with one row per hit: `[id, score, *output_field_values]`.
        """
        self._collection = Collection(collection_name, using=self._connect_name)

        # Defaults depend on the collection's schema and index, so they are
        # resolved per collection instead of being written into self.kwargs,
        # where they would leak into searches against other collections.
        search_kwargs = self._resolved_kwargs.get(collection_name)
        if search_kwargs is None:
            search_kwargs = self._resolve_search_kwargs()
            self._resolved_kwargs[collection_name] = search_kwargs

        self._collection.load()
        milvus_result = self._collection.search(data=[query], **search_kwargs)

        output_fields = search_kwargs.get('output_fields') or ()
        return [
            [hit.id, hit.score, *(hit.entity.get(name) for name in output_fields)]
            for hit in milvus_result[0]
        ]

    def _resolve_search_kwargs(self):
        """Return the user's search kwargs completed with defaults for the current collection."""
        search_kwargs = dict(self.kwargs)

        if 'anns_field' not in search_kwargs:
            # No break on purpose: as before, the last vector field in the
            # schema is the one that gets searched.
            for field in self._collection.schema.fields:
                if field.dtype in self._VECTOR_DTYPES:
                    search_kwargs['anns_field'] = field.name

        search_kwargs.setdefault('limit', 10)

        if 'param' in search_kwargs:
            # Copy so the dict handed in by the caller is never modified.
            param = dict(search_kwargs['param'] or {})
        else:
            defaults = self._default_index_params()
            indexes = self._collection.indexes
            if indexes:
                index_type = indexes[0].params['index_type']
                # Index types missing from the table get empty params, the
                # same treatment AUTOINDEX already receives.
                param = defaults.get(index_type, {})
            else:
                param = defaults['FLAT']

        if 'metric_type' in search_kwargs:
            param['metric_type'] = search_kwargs['metric_type']
        else:
            # 'L2' is only a default: a metric already named inside a
            # caller-supplied `param` must not be overwritten.
            param.setdefault('metric_type', 'L2')
        search_kwargs['param'] = param
        return search_kwargs

    @staticmethod
    def _default_index_params():
        """Return a fresh table of default search params keyed by index type."""
        return {
            'FLAT': {'params': {'nprobe': 10}},
            'IVF_FLAT': {'params': {'nprobe': 10}},
            'IVF_SQ8': {'params': {'nprobe': 10}},
            'IVF_PQ': {'params': {'nprobe': 10}},
            'HNSW': {'params': {'ef': 10}},
            'RHNSW_FLAT': {'params': {'ef': 10}},
            'RHNSW_SQ': {'params': {'ef': 10}},
            'RHNSW_PQ': {'params': {'ef': 10}},
            'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}},
            'ANNOY': {'params': {'search_k': 10}},
            'AUTOINDEX': {},
        }

    @property
    def shared_type(self):
        return SharedType.NotShareable

    # NOTE(review): disconnect-on-delete was already disabled in the original;
    # the reason is not visible in this file, so it is left disabled.
    # def __del__(self):
    #     if connections.has_connection(self._connect_name):
    #         try:
    #             connections.disconnect(self._connect_name)
    #         except:
    #             pass