logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

67 lines
2.0 KiB

import uuid
import logging
from towhee.operator import PyOperator, SharedType
from pymilvus import connections, Collection
# Module-scoped logger so records are attributed to this module rather than
# being emitted directly on the root logger (they still propagate to root).
logger = logging.getLogger(__name__)
class MilvusClient(PyOperator):
    """
    Milvus ANN index class.

    Towhee operator that opens its own pymilvus connection on construction
    and inserts one row into an existing collection per call.

    Args:
        host (`str`):
            Milvus server host; used when connecting without `uri`.
        port (`int`):
            Milvus server port; used when connecting without `uri`.
        collection_name (`str`):
            Name of the collection rows are inserted into.
        uri (`str`):
            Milvus endpoint URI. Used with `token` for a secure connection,
            or on its own when neither `token` nor `user`/`password` is set.
        user (`str`):
            Username; together with `password` opens a `secure=True`
            connection to `host`:`port`.
        password (`str`):
            Password for `user`.
        token (`str`):
            Token; together with `uri` opens a `secure=True` connection to `uri`.
    """

    def __init__(
        self,
        host: str = 'localhost',
        port: int = 19530,
        collection_name: str = None,
        uri: str = None,
        user: str = None,
        password: str = None,
        token: str = None,
    ):
        self._host = host
        self._port = port
        self._uri = uri
        self._collection_name = collection_name
        # Fresh alias per instance so each operator gets its own entry in the
        # shared pymilvus `connections` object.
        self._connect_name = uuid.uuid4().hex
        if uri and token:
            connections.connect(alias=self._connect_name, uri=self._uri, token=token, secure=True)
        elif user and password:
            # NOTE(review): with user/password, any `uri` is still ignored in
            # favour of host/port — confirm this is intended.
            connections.connect(alias=self._connect_name, host=self._host, port=self._port,
                                user=user, password=password, secure=True)
        elif uri:
            # An explicit uri without credentials used to be silently dropped,
            # connecting to the host/port defaults instead.
            connections.connect(alias=self._connect_name, uri=self._uri)
        else:
            connections.connect(alias=self._connect_name, host=self._host, port=self._port)
        self._collection = Collection(self._collection_name, using=self._connect_name)

    def __call__(self, *data):
        """
        Insert one row to Milvus.

        Args:
            *data:
                The field values of the row (presumably in the collection's
                field order — confirm against the schema). A `list` argument
                is expanded so each of its elements is a separate field; any
                other value (e.g. an array embedding) is a single field.

        Returns:
            The `MutationResult` from `Collection.insert`; its `insert_count`
            is the number of inserted rows and `primary_keys` their keys.

        Raises:
            ValueError: if no field values were given.
            RuntimeError: if Milvus reports an insert count other than one row.
        """
        # Each field value is wrapped in a one-element list, i.e. a single-row
        # column; the insert_count check below compares against that length.
        row = []
        for item in data:
            if isinstance(item, list):
                row.extend([i] for i in item)
            else:
                row.append([item])
        if not row:
            # Previously this surfaced as an IndexError on row[0] after insert.
            raise ValueError("No data to insert into milvus")
        mr = self._collection.insert(row)
        if mr.insert_count != len(row[0]):
            raise RuntimeError("Insert to milvus failed")
        return mr

    @property
    def shared_type(self):
        """Declare this operator as not shareable to towhee."""
        return SharedType.NotShareable

    # NOTE(review): the connection registered under self._connect_name in
    # __init__ is never disconnected. A __del__ calling
    # connections.disconnect(self._connect_name) used to live here but was
    # commented out — confirm whether that is intentional before adding
    # explicit cleanup.