|
@ -1,6 +1,7 @@ |
|
|
import numpy as np |
|
|
import numpy as np |
|
|
from towhee import register |
|
|
from towhee import register |
|
|
from pymilvus import Collection |
|
|
|
|
|
|
|
|
from urllib.parse import urlsplit |
|
|
|
|
|
from pymilvus import connections, Collection |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['mr']) |
|
|
@register(output_schema=['mr']) |
|
@ -9,15 +10,22 @@ class Milvus: |
|
|
Milvus ANN index class. |
|
|
Milvus ANN index class. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, collection): |
|
|
|
|
|
|
|
|
def __init__(self, uri: str = None, host: str = 'localhost', port: int = 19530, collection: str = None): |
|
|
""" |
|
|
""" |
|
|
Get an existing collection. |
|
|
Get an existing collection. |
|
|
""" |
|
|
""" |
|
|
|
|
|
self._uri = uri |
|
|
|
|
|
if uri: |
|
|
|
|
|
host, port, collection = self._parse_uri() |
|
|
|
|
|
self._host = host |
|
|
|
|
|
self._port = port |
|
|
|
|
|
|
|
|
if isinstance(collection, str): |
|
|
if isinstance(collection, str): |
|
|
|
|
|
self.connect() |
|
|
collection = Collection(collection) |
|
|
collection = Collection(collection) |
|
|
self._collection = collection |
|
|
self._collection = collection |
|
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
|
|
|
|
|
def __call__(self, *data): |
|
|
""" |
|
|
""" |
|
|
Insert data to Milvus. |
|
|
Insert data to Milvus. |
|
|
|
|
|
|
|
@ -30,9 +38,26 @@ class Milvus: |
|
|
|
|
|
|
|
|
""" |
|
|
""" |
|
|
vectors = [] |
|
|
vectors = [] |
|
|
if isinstance(data, np.ndarray): |
|
|
|
|
|
data = [data] |
|
|
|
|
|
for v in data: |
|
|
for v in data: |
|
|
vectors.append(v if isinstance(v, list) else [v]) |
|
|
vectors.append(v if isinstance(v, list) else [v]) |
|
|
mr = self._collection.insert(vectors) |
|
|
mr = self._collection.insert(vectors) |
|
|
return mr |
|
|
return mr |
|
|
|
|
|
|
|
|
|
|
|
def _parse_uri(self): |
|
|
|
|
|
try: |
|
|
|
|
|
results = urlsplit(self._uri) |
|
|
|
|
|
host, port = results.netloc.split(':') |
|
|
|
|
|
collection = results.path.strip('/') |
|
|
|
|
|
return host, port, collection |
|
|
|
|
|
except ValueError as e: |
|
|
|
|
|
raise ValueError('The input uri is not match: \'tcp://<milvus-host>:<milvus-port>/<collection-name>\', ' |
|
|
|
|
|
'such as \'tcp://localhost:19530/my_collection\'') from e |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def connect(self): |
|
|
|
|
|
if not connections.has_connection('default'): |
|
|
|
|
|
connections.connect(host=self._host, port=self._port) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def disconnect(self): |
|
|
|
|
|
connections.disconnect('default') |
|
|