diff --git a/README.md b/README.md index 67a9ef8..a5f26c9 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,13 @@ Create the operator via the following factory method: **Parameters:** +All connection arguments are set as None by default. You must provide at least one valid value for uri or host/port. +The order of use: uri > host/port, token > user/password. + +***uri:*** *str* + +The uri for Milvus. + ***host:*** *str* The host for Milvus. @@ -60,6 +67,10 @@ The host for Milvus. The port for Milvus. +***token:*** *str* + +The token for Milvus. + ***user:*** *str* The user for Zilliz Cloud, defaults to None. diff --git a/milvus_client.py b/milvus_client.py index 75a75c8..c270c13 100644 --- a/milvus_client.py +++ b/milvus_client.py @@ -1,5 +1,7 @@ import uuid import logging +from typing import Union + from towhee.operator import PyOperator, SharedType from pymilvus import connections, Collection @@ -12,15 +14,29 @@ class MilvusClient(PyOperator): Milvus ANN index class. """ - def __init__(self, host: str, port: int, user: str = None, password: str = None): - self._host = host - self._port = port + def __init__(self, + uri: str = None, + host: str = None, + port: Union[int, str] = None, + token: str = None, + user: str = None, + password: str = None + ): self._connect_name = uuid.uuid4().hex - if None in [user, password]: - connections.connect(alias=self._connect_name, host=self._host, port=self._port) + self._connection_args = {'alias': self._connect_name} + + if uri is not None: + self._connection_args['uri'] = uri + elif all(x is not None for x in [host, port]): + self._connection_args['host'] = host + self._connection_args['port'] = port else: - connections.connect(alias=self._connect_name, host=self._host, port=self._port, - user=user, password=password, secure=True) + raise ConnectionError('Received invalid connection arguments.') + + if any(x is None for x in [token, user, password]): + self._connection_args['secure'] = True + + connections.connect(**self._connection_args) def __call__(self, collection_name: str, *data): """