logo
Readme
Files and versions

60 lines
1.8 KiB

import uuid
import logging
from towhee.operator import PyOperator, SharedType
from pymilvus import connections, Collection
# Module-scoped logger (named after this module) instead of the root logger, so
# applications can configure or filter this operator's messages independently;
# records still propagate up to the root logger's handlers.
logger = logging.getLogger(__name__)
class MilvusClient(PyOperator):
    """
    Milvus ANN index class.

    Towhee operator that inserts one row per call into a Milvus collection.
    Each instance opens its own connection under a unique alias.

    Args:
        host (`str`):
            Milvus server host.
        port (`int`):
            Milvus server port.
        collection_name (`str`):
            Name of the target collection. No schema is passed to `Collection`,
            so the collection is expected to already exist.
        user (`str`, optional):
            User name; authentication is used only if `password` is also given.
        password (`str`, optional):
            Password; authentication is used only if `user` is also given.
    """

    def __init__(self, host: str, port: int, collection_name: str, user: str = None, password: str = None):
        self._host = host
        self._port = port
        self._collection_name = collection_name
        # A random alias per instance keeps each operator's connection separate
        # in pymilvus' global connection registry.
        self._connect_name = uuid.uuid4().hex
        if user is None or password is None:
            if (user is None) != (password is None):
                # A lone credential cannot be used; keep the original fallback
                # (unauthenticated connect) but don't drop it silently.
                logger.warning('Only one of user/password was provided; '
                               'connecting to Milvus without authentication.')
            connections.connect(alias=self._connect_name, host=self._host, port=self._port)
        else:
            # Authenticated connections are always opened with secure=True.
            connections.connect(alias=self._connect_name, host=self._host, port=self._port,
                                user=user, password=password, secure=True)
        self._collection = Collection(self._collection_name, using=self._connect_name)

    def __call__(self, *data):
        """
        Insert one row to Milvus.

        Args:
            *data:
                The field values of the row (presumably in the collection's
                field order). A `list` argument is flattened: each of its
                elements becomes a separate field value.

        Returns:
            A MutationResult object contains `insert_count` represents how many and a `primary_keys` of primary keys.

        Raises:
            ValueError: If no field values were supplied.
            RuntimeError: If Milvus did not report exactly one inserted row.
        """
        # Build column-wise data: one single-element column per field value.
        row = []
        for item in data:
            if isinstance(item, list):
                row.extend([[i] for i in item])
            else:
                row.append([item])
        if not row:
            # Fail fast instead of sending an empty insert and then hitting an
            # IndexError on row[0] below.
            raise ValueError('No data to insert into milvus.')
        mr = self._collection.insert(row)
        # Every column holds exactly one value, so one row is expected.
        if mr.insert_count != len(row[0]):
            raise RuntimeError("Insert to milvus failed")
        return mr

    @property
    def shared_type(self):
        # Mark the instance as not shareable; presumably because each instance
        # owns its own connection alias.
        return SharedType.NotShareable

    def close(self):
        """
        Disconnect this instance's Milvus connection alias if it is still open.

        Idempotent: calling it again after the alias is gone does nothing.
        Explicit cleanup is offered instead of a `__del__` finalizer, whose
        timing (e.g. at interpreter shutdown) cannot be relied on.
        """
        if connections.has_connection(self._connect_name):
            connections.disconnect(self._connect_name)