logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

64 lines
1.9 KiB

import numpy as np
from towhee import register
from urllib.parse import urlsplit
from towhee.operator import PyOperator
from pymilvus import connections, Collection
@register(output_schema=['mr'])
class Milvus(PyOperator):
    """
    Milvus ANN index class.

    Holds a Milvus collection and inserts the data passed to `__call__` into it.
    """

    def __init__(self, uri: str = None, host: str = 'localhost', port: int = 19530, collection: str = None):
        """
        Get an existing collection.

        Args:
            uri (`str`):
                Optional address of the form
                'tcp://<milvus-host>:<milvus-port>/<collection-name>'.
                When given, it overrides `host`, `port` and `collection`.
            host (`str`):
                Milvus host, used when `uri` is not given.
            port (`int`):
                Milvus port, used when `uri` is not given.
            collection (`str`):
                Name of the collection to use. A non-string value is stored
                as-is and no connection is opened for it.

        Raises:
            ValueError: If `uri` is given but does not match the expected format.
        """
        self._uri = uri
        if uri:
            host, port, collection = self._parse_uri()
        self._host = host
        self._port = port
        if isinstance(collection, str):
            # Connect first, then turn the name into a Collection object.
            self.connect()
            collection = Collection(collection)
        self._collection = collection

    def __call__(self, *data):
        """
        Insert data to Milvus.

        Args:
            data (`list`):
                The data to insert into milvus. Every positional argument is one
                entry; an argument that is not a list is wrapped in a
                single-element list.

        Returns:
            A MutationResult object contains `insert_count` represents how many and a `primary_keys` of primary keys.
        """
        vectors = [v if isinstance(v, list) else [v] for v in data]
        return self._collection.insert(vectors)

    def _parse_uri(self):
        """
        Split `self._uri` into host, port and collection name.

        Returns:
            A `(host, port, collection)` tuple where `port` is an `int`, so that
            it has the same type as the `port` argument of `__init__`.

        Raises:
            ValueError: If the uri does not contain exactly one '<host>:<port>'
                pair or the port is not an integer.
        """
        try:
            results = urlsplit(self._uri)
            host, port = results.netloc.split(':')
            # Convert here so a non-numeric port is reported as a malformed uri.
            port = int(port)
            collection = results.path.strip('/')
            return host, port, collection
        except ValueError as e:
            raise ValueError('The input uri is not match: \'tcp://<milvus-host>:<milvus-port>/<collection-name>\', '
                             'such as \'tcp://localhost:19530/my_collection\'') from e

    def connect(self):
        """
        Open the 'default' Milvus connection unless one already exists.
        """
        if not connections.has_connection('default'):
            connections.connect(host=self._host, port=self._port)

    def disconnect(self):
        """
        Close the 'default' Milvus connection.
        """
        connections.disconnect('default')