import uuid
import logging
from typing import Optional, Union

from towhee.operator import PyOperator, SharedType
from pymilvus import connections, Collection

logger = logging.getLogger()


class MilvusClient(PyOperator):
    """
    Operator that inserts single rows into a Milvus collection.

    Every instance registers its own Milvus connection under a randomly
    generated alias, so several operators can coexist without sharing a
    connection entry.

    Args:
        uri (`str`):
            Milvus endpoint. Takes precedence over `host`/`port` whenever it is
            not None. Because it defaults to a non-None value, `host`/`port`
            are only used when `uri=None` is passed explicitly.
        host (`str`):
            Server host, used together with `port` when `uri` is None.
        port (`int` or `str`):
            Server port, used together with `host` when `uri` is None.
        token (`str`):
            Optional token forwarded to `connections.connect`.
        user (`str`):
            Optional user name forwarded to `connections.connect`.
        password (`str`):
            Optional password forwarded to `connections.connect`.

    Raises:
        ConnectionError: If `uri` is None and `host` or `port` is missing.
    """

    def __init__(self,
                 uri: Optional[str] = 'http://localhost:19530',
                 host: Optional[str] = None,
                 port: Optional[Union[int, str]] = None,
                 token: Optional[str] = None,
                 user: Optional[str] = None,
                 password: Optional[str] = None
                 ):
        super().__init__()

        # A unique alias keeps this operator's connection separate from any
        # other connection registered with pymilvus in the same process.
        self._connect_name = uuid.uuid4().hex
        self._connection_args = {'alias': self._connect_name}

        if uri is not None:
            self._connection_args['uri'] = uri
        elif host is not None and port is not None:
            self._connection_args['host'] = host
            self._connection_args['port'] = port
        else:
            raise ConnectionError('Received invalid connection arguments.')

        # NOTE(review): `secure` was False on every path in the original code
        # and is kept that way; confirm whether authenticated connections are
        # expected to enable TLS here.
        self._connection_args['secure'] = False

        # Forward every credential the caller actually supplied. Previously
        # they were only forwarded when all three were given, so passing just
        # user/password (or just a token) silently connected without them.
        credentials = {'user': user, 'password': password, 'token': token}
        self._connection_args.update(
            {key: value for key, value in credentials.items() if value is not None}
        )

        connections.connect(**self._connection_args)

    def __call__(self, collection_name: str, *data):
        """
        Insert one row to Milvus.

        Args:
            collection_name (`str`):
                Name of the collection that receives the row.
            data:
                The field values of the row, in column order. An argument that
                is a `list` is flattened: each of its elements becomes its own
                column.

        Returns:
            The MutationResult returned by `Collection.insert`; the code reads
            its `insert_count` to verify the insertion.

        Raises:
            ValueError: If no field values are given.
            RuntimeError: If Milvus does not report exactly one inserted row.
        """
        if not data:
            # Fail early with a clear message instead of an IndexError later.
            raise ValueError('No data given to insert into milvus.')

        self._collection = Collection(collection_name, using=self._connect_name)

        # Column-based payload: one single-element list per field.
        row = []
        for item in data:
            if isinstance(item, list):
                row.extend([value] for value in item)
            else:
                row.append([item])

        mr = self._collection.insert(row)

        # Every column carries exactly one value, so one row must be inserted.
        expected_count = 1
        if mr.insert_count != expected_count:
            raise RuntimeError("Insert to milvus failed")
        return mr

    @property
    def shared_type(self):
        """Sharing policy of this operator: always `SharedType.NotShareable`."""
        return SharedType.NotShareable

    # NOTE: explicit disconnection is currently disabled, so the connection
    # alias created in __init__ is not released by this class.
    # def __del__(self):
    #     if connections.has_connection(self._connect_name):
    #         try:
    #             connections.disconnect(self._connect_name)
    #         except:
    #             pass