osschat-milvus
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
76 lines
2.2 KiB
76 lines
2.2 KiB
import uuid
|
|
import logging
|
|
from typing import Union
|
|
|
|
from towhee.operator import PyOperator, SharedType
|
|
from pymilvus import connections, Collection
|
|
|
|
|
|
# Module-scoped logger: records are attributed to this module and propagate to
# the root handlers, instead of being emitted directly on the root logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MilvusClient(PyOperator):
    """
    Milvus ANN index class.

    Towhee operator that opens a pymilvus connection under a random,
    per-instance alias and, on each call, inserts one row into the named
    collection through that connection.
    """

    def __init__(self,
                 uri: str = 'http://localhost:19530',
                 host: str = None,
                 port: Union[int, str] = None,
                 token: str = None,
                 user: str = None,
                 password: str = None
                 ):
        """
        Connect to a Milvus server.

        Args:
            uri: Server URI. Used unless both ``host`` and ``port`` are given.
            host: Server host; only used together with ``port``.
            port: Server port; only used together with ``host``.
            token: Optional authentication token, forwarded to
                ``connections.connect`` when given.
            user: Optional user name, forwarded together with ``password``
                when both are given.
            password: Optional password; see ``user``.

        Raises:
            ConnectionError: if neither a URI nor a complete host/port pair
                is available.
        """
        # A unique alias keeps this operator's connection separate from any
        # other connection registered in pymilvus' global ``connections``.
        self._connect_name = uuid.uuid4().hex
        self._connection_args = {'alias': self._connect_name}

        # ``uri`` has a non-None default, so an explicit host/port pair must
        # take precedence; otherwise host/port could only ever be used if the
        # caller also passed uri=None.
        if host is not None and port is not None:
            self._connection_args['host'] = host
            self._connection_args['port'] = port
        elif uri is not None:
            self._connection_args['uri'] = uri
        else:
            raise ConnectionError('Received invalid connection arguments.')

        # Forward credentials only when supplied, so unauthenticated setups
        # pass exactly the same arguments to connect() as before.
        if token is not None:
            self._connection_args['token'] = token
        if user is not None and password is not None:
            self._connection_args['user'] = user
            self._connection_args['password'] = password

        # TLS is explicitly disabled regardless of credentials.
        self._connection_args['secure'] = False

        connections.connect(**self._connection_args)

    def __call__(self, collection_name: str, *data):
        """
        Insert one row into a Milvus collection.

        Args:
            collection_name: Name of the target collection. It is resolved on
                every call using this operator's connection alias.
            *data: Field values for the row, in schema order. A ``list``
                argument is expanded so that each of its elements becomes a
                separate field value; any other value is a single field.

        Returns:
            The ``MutationResult`` returned by ``Collection.insert``. It
            exposes ``insert_count`` and ``primary_keys``.

        Raises:
            ValueError: if no field values were supplied.
            RuntimeError: if Milvus reports a different number of inserted
                rows than were sent.
        """
        # Column-based insert: one list per field, each holding a single value,
        # so exactly one row is sent.
        row = []
        for item in data:
            if isinstance(item, list):
                row.extend([value] for value in item)
            else:
                row.append([item])
        # Fail before contacting the server; an empty row would also make the
        # ``row[0]`` check below meaningless.
        if not row:
            raise ValueError('No data to insert into milvus.')

        self._collection = Collection(collection_name, using=self._connect_name)
        mr = self._collection.insert(row)
        if mr.insert_count != len(row[0]):
            raise RuntimeError("Insert to milvus failed")
        return mr

    @property
    def shared_type(self):
        """Tell towhee this operator is ``NotShareable``.

        Presumably this is because each instance holds its own connection
        alias.
        """
        return SharedType.NotShareable

    # NOTE(review): nothing here disconnects the per-instance connection alias
    # opened in __init__, so each instance keeps its connection registered.
    # A best-effort cleanup was left disabled; confirm whether it is needed:
    # def __del__(self):
    #     if connections.has_connection(self._connect_name):
    #         try:
    #             connections.disconnect(self._connect_name)
    #         except:
    #             pass