diff --git a/README.md b/README.md index 673e70d..2da4242 100644 --- a/README.md +++ b/README.md @@ -17,15 +17,6 @@ and connected Milvus before loading the data.** ## Code Example -### Get the Collection first - -```python -from pymilvus import Collection, connections - -connections.connect(host='localhost', port='19530') -collection = Collection('your_collection_name') -``` - ### Example *Write the pipeline in simplified style:* @@ -34,7 +25,7 @@ collection = Collection('your_collection_name') import towhee towhee.dc(your_embeddings) \ - .ann_insert.milvus(collection=collection) + .ann_insert.milvus(uri='tcp://localhost:19530/my_collection') ``` *Write a same pipeline with explicit inputs/outputs name specifications:* @@ -43,7 +34,7 @@ towhee.dc(your_embeddings) \ import towhee towhee.dc['vec'](your_embeddings) \ - .ann_insert.milvus['vec', 'results'](collection=collection) \ + .ann_insert.milvus['vec', 'results'](uri='tcp://localhost:19530/my_collection') \ .show() ``` @@ -74,9 +65,9 @@ Create the operator via the following factory method: **Parameters:** -***collection:*** *str* or *pymilvus.Collection* +***uri:*** *str* -The collection name or pymilvus.Collection in Milvus. +The URI of the Milvus collection, in the form `tcp://<host>:<port>/<collection>`, e.g. `tcp://localhost:19530/my_collection`.
# milvus.py as it stands after this change (git blobs 7c60bd2 -> 890e530),
# reconstructed from the diff hunks; lines outside the hunks are not shown.
import numpy as np
from towhee import register
from urllib.parse import urlsplit
from pymilvus import connections, Collection


@register(output_schema=['mr'])
class Milvus:
    """
    Milvus ANN index class.

    Wraps a Milvus collection and inserts the data it is called with.
    """

    def __init__(self, uri: str = None, host: str = 'localhost', port: int = 19530, collection: str = None):
        """
        Get an existing collection.

        Args:
            uri (`str`):
                Collection address such as 'tcp://localhost:19530/my_collection'.
                When given, it overrides `host`, `port` and `collection`.
            host (`str`):
                Milvus server host, used when `uri` is not given.
            port (`int`):
                Milvus server port, used when `uri` is not given.
            collection (`str`):
                Collection name. A non-string value is stored as-is and no
                connection is opened for it.

        Raises:
            ValueError: If `uri` is given but cannot be parsed.
        """
        self._uri = uri
        if uri:
            host, port, collection = self._parse_uri()
        self._host = host
        self._port = port

        # Only a collection *name* needs resolving; that requires a live connection.
        if isinstance(collection, str):
            self.connect()
            collection = Collection(collection)
        self._collection = collection

    def __call__(self, *data):
        """
        Insert data to Milvus.

        Args:
            *data:
                Values to insert. Each positional argument that is not already
                a list is wrapped in a single-element list before insertion.

        Returns:
            The value returned by the collection's `insert` call.
        """
        vectors = [v if isinstance(v, list) else [v] for v in data]
        return self._collection.insert(vectors)

    def _parse_uri(self):
        """
        Split `self._uri` into its host, port and collection name.

        Returns:
            A `(host, port, collection)` tuple, where `port` is an `int` so it
            has the same type as the constructor's default port.

        Raises:
            ValueError: If the uri lacks a host, a numeric port or a collection name.
        """
        error_msg = ('The input uri is not match: \'tcp://<host>:<port>/<collection>\', '
                     'such as \'tcp://localhost:19530/my_collection\'')
        try:
            results = urlsplit(self._uri)
            host, port = results.netloc.split(':')
            port = int(port)
        except ValueError as e:
            raise ValueError(error_msg) from e

        collection = results.path.strip('/')
        # An empty host or collection would only fail later with a less helpful error.
        if not host or not collection:
            raise ValueError(error_msg)
        return host, port, collection

    def connect(self):
        """
        Connect to Milvus at the stored host and port, unless a 'default'
        connection already exists.
        """
        if not connections.has_connection('default'):
            connections.connect(host=self._host, port=self._port)

    def disconnect(self):
        """
        Disconnect the 'default' Milvus connection.
        """
        connections.disconnect('default')