from urllib.parse import urlsplit

import numpy as np
from pymilvus import connections, Collection
from towhee import register
from towhee.operator import PyOperator


@register(output_schema=['mr'])
class Milvus(PyOperator):
    """
    Milvus ANN index class.

    Wraps an existing Milvus collection and inserts the data passed to
    ``__call__`` into it.
    """

    def __init__(self, uri: str = None, host: str = 'localhost', port: int = 19530, collection: str = None):
        """
        Get an existing collection.

        Args:
            uri (`str`):
                Optional address of the form
                ``tcp://<host>:<port>/<collection>``. When given, the values
                parsed from it override `host`, `port` and `collection`.
            host (`str`):
                Milvus server host, used when `uri` is not given.
            port (`int`):
                Milvus server port, used when `uri` is not given.
            collection (`str`):
                Name of the collection. A string triggers a connection and is
                wrapped in a `Collection`; any other value is stored as-is
                (presumably an already constructed `Collection` -- confirm
                against callers).

        Raises:
            ValueError: If `uri` is given but cannot be parsed.
        """
        self._uri = uri
        if uri:
            host, port, collection = self._parse_uri()
        self._host = host
        self._port = port
        if isinstance(collection, str):
            # A connection must exist before a Collection can be looked up by name.
            self.connect()
            collection = Collection(collection)
        self._collection = collection

    def __call__(self, *data):
        """
        Insert data to Milvus.

        Args:
            data (`list`):
                The data to insert into milvus, one positional argument per
                field. An argument that is not already a list is wrapped in a
                single-element list.

        Returns:
            The result of `Collection.insert`: a MutationResult object whose
            `insert_count` is the number of inserted entities and whose
            `primary_keys` lists their primary keys.
        """
        vectors = [v if isinstance(v, list) else [v] for v in data]
        return self._collection.insert(vectors)

    def _parse_uri(self):
        """
        Split ``self._uri`` into its host, port and collection parts.

        Returns:
            A ``(host, port, collection)`` tuple, where `port` is an `int`
            and `collection` is the uri path without surrounding slashes.

        Raises:
            ValueError: If the network location is not exactly ``host:port``
                or the port is not an integer.
        """
        try:
            results = urlsplit(self._uri)
            host, port = results.netloc.split(':')
            # Keep the port type consistent with the `port: int` constructor
            # argument; a non-numeric port is reported by the handler below.
            port = int(port)
            collection = results.path.strip('/')
            return host, port, collection
        except ValueError as e:
            raise ValueError('The input uri does not match: \'tcp://<host>:<port>/<collection>\', '
                             'such as \'tcp://localhost:19530/my_collection\'') from e

    def connect(self):
        """
        Connect to the stored host and port unless a connection with the
        alias 'default' already exists.
        """
        if not connections.has_connection('default'):
            connections.connect(host=self._host, port=self._port)

    def disconnect(self):
        """
        Disconnect the connection with the alias 'default'.
        """
        connections.disconnect('default')