diff --git a/milvus_client.py b/milvus_client.py index fb1cce2..47034b7 100644 --- a/milvus_client.py +++ b/milvus_client.py @@ -1,6 +1,7 @@ from pymilvus import connections, Collection from towhee.operator import PyOperator, SharedType import uuid +from typing import Union class MilvusClient(PyOperator): @@ -25,20 +26,33 @@ class MilvusClient(PyOperator): ANNOY: {"params": {"search_k": 10}}. """ - def __init__(self, host: str = 'localhost', port: int = 19530, - user: str = None, password: str = None, **kwargs): + def __init__(self, + uri: str = 'http://localhost:19530', + host: str = None, + port: Union[int, str] = None, + token: str = None, + user: str = None, + password: str = None, + **kwargs): """ Get an existing collection. """ - self._host = host - self._port = port - self.kwargs = kwargs self._connect_name = uuid.uuid4().hex - if None in [user, password]: - connections.connect(alias=self._connect_name, host=self._host, port=self._port) + self._connection_args = {'alias': self._connect_name} + self.kwargs = kwargs + + if uri is not None: + self._connection_args['uri'] = uri + elif all(x is not None for x in [host, port]): + self._connection_args['host'] = host + self._connection_args['port'] = port else: - connections.connect(alias=self._connect_name, host=self._host, port=self._port, - user=user, password=password, secure=True) + raise ConnectionError('Received invalid connection arguments.') + + if any(x is None for x in [token, user, password]): + self._connection_args['secure'] = True + + connections.connect(**self._connection_args) def __call__(self, collection_name: str, query: 'ndarray'):