diff --git a/__init__.py b/__init__.py new file mode 100644 index 0000000..0b76a19 --- /dev/null +++ b/__init__.py @@ -0,0 +1,4 @@ +from .milvus_client import MilvusClient + +def milvus_client(*args, **kwargs): + return MilvusClient(*args, **kwargs) diff --git a/milvus_client.py b/milvus_client.py new file mode 100644 index 0000000..e2c81dc --- /dev/null +++ b/milvus_client.py @@ -0,0 +1,44 @@ +import uuid +import logging +from towhee.operator import PyOperator +from pymilvus import connections, Collection + + +logger = logging.getLogger() + + +class MilvusClient(PyOperator): + """ + Milvus ANN index class. + """ + + def __init__(self, host: str, port: int, collection_name: str): + """ + Get an existing collection. + """ + self._host = host + self._port = port + self._collection_name = collection_name + self._connect_name = uuid.uuid4().hex + connections.connect(alias=self._connect_name, host=self._host, port=self._port) + self._collection = Collection(self._collection_name, using=self._connect_name) + + def __call__(self, *data): + """ + Insert data to Milvus. + + Args: + data (`list`): + The data to insert into milvus. + + Returns: + A MutationResult object contains `insert_count` represents how many and a `primary_keys` of primary keys. + + """ + vectors = [] + for v in data: + vectors.append(v if isinstance(v, list) else [v]) + mr = self._collection.insert(vectors) + if mr.err_count > 0: + raise RuntimeError("Insert to milvus failed") + return None diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..e0472f0 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +pymilvus