milvus-client
copied
3 changed files with 96 additions and 0 deletions
@ -0,0 +1,5 @@ |
|||
from .milvus_client import MilvusClientls |
|||
|
|||
|
|||
def milvus_client(*args, **kwargs): |
|||
return MilvusClient(*args, **kwargs) |
@ -0,0 +1,90 @@ |
|||
from pymilvus import connections, Collection |
|||
from towhee.operator import PyOperator |
|||
import uuid |
|||
|
|||
|
|||
class Milvus(PyOperator): |
|||
""" |
|||
Search for embedding vectors in Milvus. Note that the Milvus collection has data before searching, |
|||
|
|||
Args: |
|||
collection (`str`): |
|||
The collection name. |
|||
kwargs |
|||
The kwargs with collection.search, refer to https://milvus.io/docs/v2.0.x/search.md#Prepare-search-parameters. |
|||
And the `anns_field` defaults to the vector field name, `limit` defaults to 10, and `metric_type` in `param` defaults to 'L2' |
|||
if there has no index(FLAT), and for default index `param`: |
|||
IVF_FLAT: {"params": {"nprobe": 10}}, |
|||
IVF_SQ8: {"params": {"nprobe": 10}}, |
|||
IVF_PQ: {"params": {"nprobe": 10}}, |
|||
HNSW: {"params": {"ef": 10}}, |
|||
IVF_HNSW: {"params": {"nprobe": 10, "ef": 10}}, |
|||
RHNSW_FLAT: {"params": {"ef": 10}}, |
|||
RHNSW_SQ: {"params": {"ef": 10}}, |
|||
RHNSW_PQ: {"params": {"ef": 10}}, |
|||
ANNOY: {"params": {"search_k": 10}}. |
|||
""" |
|||
|
|||
def __init__(self, host: str = 'localhost', port: int = 19530, collection_name: str = None, **kwargs): |
|||
""" |
|||
Get an existing collection. |
|||
""" |
|||
self._host = host |
|||
self._port = port |
|||
self._collection_name = collection_name |
|||
self._connect_name = uuid.uuid4().hex |
|||
connections.connect(alias=self._connect_name, host=self._host, port=self._port) |
|||
self._collection = Collection(self._collection_name, using=self._connect_name) |
|||
|
|||
self.kwargs = kwargs |
|||
if 'anns_field' not in self.kwargs: |
|||
fields_schema = self._collection.schema.fields |
|||
for schema in fields_schema: |
|||
if schema.dtype in (101, 100): |
|||
self.kwargs['anns_field'] = schema.name |
|||
|
|||
if 'limit' not in self.kwargs: |
|||
self.kwargs['limit'] = 10 |
|||
|
|||
index_params = { |
|||
'IVF_FLAT': {'params': {'nprobe': 10}}, |
|||
'IVF_SQ8': {'params': {'nprobe': 10}}, |
|||
'IVF_PQ': {'params': {'nprobe': 10}}, |
|||
'HNSW': {'params': {'ef': 10}}, |
|||
'RHNSW_FLAT': {'params': {'ef': 10}}, |
|||
'RHNSW_SQ': {'params': {'ef': 10}}, |
|||
'RHNSW_PQ': {'params': {'ef': 10}}, |
|||
'IVF_HNSW': {'params': {'nprobe': 10, 'ef': 10}}, |
|||
'ANNOY': {'params': {'search_k': 10}} |
|||
} |
|||
|
|||
if 'param' not in self.kwargs: |
|||
if len(self._collection.indexes) != 0: |
|||
index_type = self._collection.indexes[0].params['index_type'] |
|||
self.kwargs['param'] = index_params[index_type] |
|||
else: |
|||
self.kwargs['param'] = index_params['IVF_FLAT'] |
|||
if 'metric_type' in self.kwargs: |
|||
self.kwargs['param']['metric_type'] = self.kwargs['metric_type'] |
|||
else: |
|||
self.kwargs['param']['metric_type'] = 'L2' |
|||
|
|||
def __call__(self, query: list): |
|||
milvus_result = self._collection.search( |
|||
data=[query], |
|||
**self.kwargs |
|||
) |
|||
|
|||
result = [] |
|||
for re in milvus_result: |
|||
row = [] |
|||
for hit in re: |
|||
row.extend([hit.id, hit.score]) |
|||
if 'output_fields' in self.kwargs: |
|||
for k in self.kwargs['output_fields']: |
|||
row.append(hit.entity._row_data[k]) |
|||
result.append(row) |
|||
return result |
|||
|
|||
def __del__(self): |
|||
connections.disconnect(self._connect_name) |
@ -0,0 +1 @@ |
|||
pymilvus |
Loading…
Reference in new issue