logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

80 lines
2.4 KiB

import uuid
import logging
from typing import Union
from towhee.operator import PyOperator, SharedType
from pymilvus import connections, Collection
# Module-scoped logger; using __name__ avoids logging through (and configuring) the root logger.
logger = logging.getLogger(__name__)
class MilvusClient(PyOperator):
    """
    Milvus ANN index class.

    Opens a dedicated pymilvus connection under a random alias when the
    operator is constructed, and inserts one row per call into the named
    collection over that connection.

    Args:
        uri (`str`):
            Address of the Milvus server. Used unless both ``host`` and
            ``port`` are given.
        host (`str`):
            Milvus host. Together with ``port`` it takes precedence over
            ``uri`` (``uri`` always has a default, so explicit host/port
            would otherwise be ignored).
        port (`int` or `str`):
            Milvus port; see ``host``.
        token (`str`):
            Authentication token, forwarded to the connection when given.
        user (`str`):
            User name, forwarded together with ``password`` when both are
            given.
        password (`str`):
            Password, forwarded together with ``user`` when both are given.

    Raises:
        ConnectionError: if ``uri`` is ``None`` and ``host``/``port`` are not
            both given.
    """

    def __init__(self,
                 uri: str = 'http://localhost:19530',
                 host: str = None,
                 port: Union[int, str] = None,
                 token: str = None,
                 user: str = None,
                 password: str = None
                 ):
        self._connect_name = uuid.uuid4().hex
        # NOTE(review): the original code set secure=False in every case
        # (both branches of its credential check), so TLS is never requested
        # explicitly here. Confirm whether secure=True was intended when
        # credentials are supplied.
        self._connection_args = {'alias': self._connect_name, 'secure': False}

        # `uri` is never None unless the caller passes None explicitly, so an
        # explicit host/port pair must be checked first to take effect.
        if host is not None and port is not None:
            self._connection_args['host'] = host
            self._connection_args['port'] = port
        elif uri is not None:
            self._connection_args['uri'] = uri
        else:
            raise ConnectionError('Received invalid connection arguments.')

        # A token and a user/password pair are independent credentials.
        # Forward whichever were supplied instead of requiring all three.
        if token is not None:
            self._connection_args['token'] = token
        if user is not None and password is not None:
            self._connection_args['user'] = user
            self._connection_args['password'] = password

        connections.connect(**self._connection_args)

    def __call__(self, collection_name: str, *data):
        """
        Insert one row to Milvus.

        Args:
            collection_name (`str`):
                Name of the collection to insert into, looked up on this
                operator's connection.
            data:
                Field values for the row, in field order. A `list` argument
                is expanded so that each of its elements becomes a separate
                field; any other value is used as a single field.

        Returns:
            A MutationResult object whose `insert_count` is the number of
            inserted entities and whose `primary_keys` holds their primary keys.

        Raises:
            ValueError: if no field values are given.
            RuntimeError: if the reported insert count differs from the
                number of rows sent.
        """
        self._collection = Collection(collection_name, using=self._connect_name)
        # Each field value is wrapped in a one-element list, so `row` holds
        # one single-value list per field.
        row = []
        for item in data:
            if isinstance(item, list):
                row.extend([value] for value in item)
            else:
                row.append([item])
        if not row:
            # Without this guard, `row[0]` below would raise an opaque IndexError.
            raise ValueError('No data given to insert into milvus.')
        mr = self._collection.insert(row)
        if mr.insert_count != len(row[0]):
            raise RuntimeError("Insert to milvus failed")
        return mr

    @property
    def shared_type(self):
        """Report this operator to towhee as ``SharedType.NotShareable``."""
        return SharedType.NotShareable

    # NOTE(review): nothing in this class disconnects the connection opened
    # in __init__; the automatic cleanup below is left disabled.
    # def __del__(self):
    #     if connections.has_connection(self._connect_name):
    #         try:
    #             connections.disconnect(self._connect_name)
    #         except:
    #             pass