from pymilvus import connections, Collection
from towhee.operator import PyOperator, SharedType

import uuid
from typing import Union


class MilvusClient(PyOperator):
    """
    Search for embedding vectors in a Milvus collection.

    The collection must already contain data before it is searched.

    Extra keyword arguments given to the constructor are forwarded to
    ``Collection.search`` on every call, see
    https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters.
    Arguments that are not supplied are resolved per call:

    - ``anns_field`` defaults to the name of the collection's vector field
      (the last one declared in the schema when there are several).
    - ``limit`` defaults to 10.
    - ``param`` defaults to a preset chosen from the type of the
      collection's first index (``IVF_FLAT`` preset when the collection
      has no index), with ``metric_type`` set to the ``metric_type``
      keyword argument, or ``'L2'`` when that is absent:

        FLAT: {"params": {"nprobe": 10}},
        IVF_FLAT: {"params": {"nprobe": 10}},
        IVF_SQ8: {"params": {"nprobe": 10}},
        IVF_PQ: {"params": {"nprobe": 10}},
        HNSW: {"params": {"ef": 10}},
        IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}},
        RHNSW_FLAT: {"params": {"ef": 10}},
        RHNSW_SQ: {"params": {"ef": 10}},
        RHNSW_PQ: {"params": {"ef": 10}},
        ANNOY: {"params": {"search_k": 10}},
        AUTOINDEX and any other index type: {}.

      A ``param`` supplied by the caller is passed through unchanged.
    """

    # Values of pymilvus ``DataType.FLOAT_VECTOR`` (101) and
    # ``DataType.BINARY_VECTOR`` (100); fields of these types are searchable.
    _VECTOR_DTYPES = (101, 100)

    def __init__(self, uri: str = 'http://localhost:19530', host: str = None,
                 port: Union[int, str] = None, token: str = None,
                 user: str = None, password: str = None, **kwargs):
        """
        Open a dedicated Milvus connection for this operator.

        Args:
            uri: Milvus server URI. Takes precedence over ``host``/``port``;
                pass ``uri=None`` to connect with ``host`` and ``port``.
            host: Server host, used only when ``uri`` is None.
            port: Server port, used only when ``uri`` is None.
            token: Authentication token, forwarded when given.
            user: User name, forwarded when given.
            password: Password, forwarded when given.
            **kwargs: Default arguments for ``Collection.search``.

        Raises:
            ConnectionError: If ``uri`` is None and ``host`` or ``port`` is
                missing.
        """
        super().__init__()
        # A random alias keeps this operator's connection separate from any
        # other connection registered in the process.
        self._connect_name = uuid.uuid4().hex
        self._connection_args = {'alias': self._connect_name}
        self.kwargs = kwargs

        if uri is not None:
            self._connection_args['uri'] = uri
        elif host is not None and port is not None:
            self._connection_args['host'] = host
            self._connection_args['port'] = port
        else:
            raise ConnectionError('Received invalid connection arguments.')

        credentials = {'user': user, 'password': password, 'token': token}
        # `secure` keeps its original rule: enabled only when all three
        # credentials are supplied.
        self._connection_args['secure'] = all(
            value is not None for value in credentials.values()
        )
        # Forward every credential that was given. Previously a partial set
        # (e.g. user/password without a token) was silently discarded.
        self._connection_args.update(
            {key: value for key, value in credentials.items() if value is not None}
        )

        connections.connect(**self._connection_args)

    def __call__(self, collection_name: str, query: 'ndarray'):
        """
        Search ``collection_name`` for the vectors nearest to ``query``.

        Args:
            collection_name: Name of an existing Milvus collection.
            query: A single query vector.

        Returns:
            One row per hit, ``[id, score, *output_field_values]``, where the
            trailing values follow the order of the ``output_fields`` search
            argument (none when it was not given).
        """
        self._collection = Collection(collection_name, using=self._connect_name)
        search_kwargs = self._build_search_kwargs(self._collection)

        self._collection.load()
        milvus_result = self._collection.search(data=[query], **search_kwargs)

        output_fields = search_kwargs.get('output_fields') or ()
        # A single query vector is sent, so only the first result set is read.
        return [
            [hit.id, hit.score, *(hit.entity.get(name) for name in output_fields)]
            for hit in milvus_result[0]
        ]

    def _build_search_kwargs(self, collection) -> dict:
        """Return the search arguments for ``collection`` with defaults filled in."""
        # Work on a copy: defaults derived from one collection (vector field
        # name, index type) must not leak into searches on another collection.
        search_kwargs = dict(self.kwargs)

        if 'anns_field' not in search_kwargs:
            for field in collection.schema.fields:
                if field.dtype in self._VECTOR_DTYPES:
                    search_kwargs['anns_field'] = field.name

        search_kwargs.setdefault('limit', 10)

        if 'param' not in search_kwargs:
            indexes = collection.indexes
            index_type = indexes[0].params['index_type'] if indexes else 'IVF_FLAT'
            param = self._default_search_param(index_type)
            param['metric_type'] = search_kwargs.get('metric_type', 'L2')
            search_kwargs['param'] = param

        return search_kwargs

    @staticmethod
    def _default_search_param(index_type: str) -> dict:
        """Return a fresh default search ``param`` dict for ``index_type``."""
        presets = {
            'FLAT': {'params': {'nprobe': 10}},
            'IVF_FLAT': {'params': {'nprobe': 10}},
            'IVF_SQ8': {'params': {'nprobe': 10}},
            'IVF_PQ': {'params': {'nprobe': 10}},
            'HNSW': {'params': {'ef': 10}},
            'RHNSW_FLAT': {'params': {'ef': 10}},
            'RHNSW_SQ': {'params': {'ef': 10}},
            'RHNSW_PQ': {'params': {'ef': 10}},
            'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}},
            'ANNOY': {'params': {'search_k': 10}},
            'AUTOINDEX': {},
        }
        # Index types without a preset get no tuning parameters instead of
        # failing with a KeyError.
        return presets.get(index_type, {})

    @property
    def shared_type(self):
        """Sharing policy reported for this operator: ``SharedType.NotShareable``."""
        return SharedType.NotShareable

    # def __del__(self):
    #     if connections.has_connection(self._connect_name):
    #         try:
    #             connections.disconnect(self._connect_name)
    #         except:
    #             pass