|
|
@ -1,5 +1,7 @@ |
|
|
|
import uuid |
|
|
|
import logging |
|
|
|
from typing import Union |
|
|
|
|
|
|
|
from towhee.operator import PyOperator, SharedType |
|
|
|
from pymilvus import connections, Collection |
|
|
|
|
|
|
@ -12,15 +14,29 @@ class MilvusClient(PyOperator): |
|
|
|
Milvus ANN index class. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, host: str, port: int, user: str = None, password: str = None): |
|
|
|
self._host = host |
|
|
|
self._port = port |
|
|
|
def __init__(self, |
|
|
|
uri: str = None, |
|
|
|
host: str = None, |
|
|
|
port: Union[int, str] = None, |
|
|
|
token: str = None, |
|
|
|
user: str = None, |
|
|
|
password: str = None |
|
|
|
): |
|
|
|
self._connect_name = uuid.uuid4().hex |
|
|
|
if None in [user, password]: |
|
|
|
connections.connect(alias=self._connect_name, host=self._host, port=self._port) |
|
|
|
self._connection_args = {'alias': self._connect_name} |
|
|
|
|
|
|
|
if uri is not None: |
|
|
|
self._connection_args['uri'] = uri |
|
|
|
elif all(x is not None for x in [host, port]): |
|
|
|
self._connection_args['host'] = host |
|
|
|
self._connection_args['port'] = port |
|
|
|
else: |
|
|
|
connections.connect(alias=self._connect_name, host=self._host, port=self._port, |
|
|
|
user=user, password=password, secure=True) |
|
|
|
raise ConnectionError('Received invalid connection arguments.') |
|
|
|
|
|
|
|
if any(x is None for x in [token, user, password]): |
|
|
|
self._connection_args['secure'] = True |
|
|
|
|
|
|
|
connections.connect(**self._connection_args) |
|
|
|
|
|
|
|
def __call__(self, collection_name: str, *data): |
|
|
|
""" |
|
|
|