diff --git a/__init__.py b/__init__.py index 675ca86..cb89df9 100644 --- a/__init__.py +++ b/__init__.py @@ -1,4 +1,4 @@ -from .milvus_client import MilvusClientls +from .milvus_client import MilvusClient def milvus_client(*args, **kwargs): diff --git a/milvus_client.py b/milvus_client.py index 2ff8056..33d5792 100644 --- a/milvus_client.py +++ b/milvus_client.py @@ -3,7 +3,7 @@ from towhee.operator import PyOperator import uuid -class Milvus(PyOperator): +class MilvusClient(PyOperator): """ Search for embedding vectors in Milvus. Note that the Milvus collection has data before searching, @@ -69,22 +69,18 @@ class Milvus(PyOperator): else: self.kwargs['param']['metric_type'] = 'L2' - def __call__(self, query: list): + def __call__(self, query: 'ndarray'): milvus_result = self._collection.search( data=[query], **self.kwargs ) result = [] - for re in milvus_result: + for hit in milvus_result[0]: row = [] - for hit in re: - row.extend([hit.id, hit.score]) - if 'output_fields' in self.kwargs: - for k in self.kwargs['output_fields']: - row.append(hit.entity._row_data[k]) - result.append(row) + row.extend([hit.id, hit.score]) + if 'output_fields' in self.kwargs: + for k in self.kwargs['output_fields']: + row.append(hit.entity._row_data[k]) + result.append(row) return result - - def __del__(self): - connections.disconnect(self._connect_name)