|  |  | @ -1,6 +1,7 @@ | 
			
		
	
		
			
				
					|  |  |  | import numpy as np | 
			
		
	
		
			
				
					|  |  |  | from towhee import register | 
			
		
	
		
			
				
					|  |  |  | from pymilvus import Collection | 
			
		
	
		
			
				
					|  |  |  | from urllib.parse import urlsplit | 
			
		
	
		
			
				
					|  |  |  | from pymilvus import connections, Collection | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  | @register(output_schema=['mr']) | 
			
		
	
	
		
			
				
					|  |  | @ -9,15 +10,22 @@ class Milvus: | 
			
		
	
		
			
				
					|  |  |  |     Milvus ANN index class. | 
			
		
	
		
			
				
					|  |  |  |     """ | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, collection): | 
			
		
	
		
			
				
					|  |  |  |     def __init__(self, uri: str = None, host: str = 'localhost', port: int = 19530, collection: str = None): | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         Get an existing collection. | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         self._uri = uri | 
			
		
	
		
			
				
					|  |  |  |         if uri: | 
			
		
	
		
			
				
					|  |  |  |             host, port, collection = self._parse_uri() | 
			
		
	
		
			
				
					|  |  |  |         self._host = host | 
			
		
	
		
			
				
					|  |  |  |         self._port = port | 
			
		
	
		
			
				
					|  |  |  |          | 
			
		
	
		
			
				
					|  |  |  |         if isinstance(collection, str): | 
			
		
	
		
			
				
					|  |  |  |             self.connect() | 
			
		
	
		
			
				
					|  |  |  |             collection = Collection(collection) | 
			
		
	
		
			
				
					|  |  |  |         self._collection = collection | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def __call__(self, data): | 
			
		
	
		
			
				
					|  |  |  |     def __call__(self, *data): | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         Insert data to Milvus. | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
	
		
			
				
					|  |  | @ -30,9 +38,26 @@ class Milvus: | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |         """ | 
			
		
	
		
			
				
					|  |  |  |         vectors = [] | 
			
		
	
		
			
				
					|  |  |  |         if isinstance(data, np.ndarray): | 
			
		
	
		
			
				
					|  |  |  |             data = [data] | 
			
		
	
		
			
				
					|  |  |  |         for v in data: | 
			
		
	
		
			
				
					|  |  |  |             vectors.append(v if isinstance(v, list) else [v]) | 
			
		
	
		
			
				
					|  |  |  |         mr = self._collection.insert(vectors) | 
			
		
	
		
			
				
					|  |  |  |         return mr | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     def _parse_uri(self): | 
			
		
	
		
			
				
					|  |  |  |         try: | 
			
		
	
		
			
				
					|  |  |  |             results = urlsplit(self._uri) | 
			
		
	
		
			
				
					|  |  |  |             host, port = results.netloc.split(':') | 
			
		
	
		
			
				
					|  |  |  |             collection = results.path.strip('/') | 
			
		
	
		
			
				
					|  |  |  |             return host, port, collection | 
			
		
	
		
			
				
					|  |  |  |         except ValueError as e: | 
			
		
	
		
			
				
					|  |  |  |             raise ValueError('The input uri is not match: \'tcp://<milvus-host>:<milvus-port>/<collection-name>\', ' | 
			
		
	
		
			
				
					|  |  |  |                              'such as \'tcp://localhost:19530/my_collection\'') from e | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  |      | 
			
		
	
		
			
				
					|  |  |  |     def connect(self): | 
			
		
	
		
			
				
					|  |  |  |         if not connections.has_connection('default'): | 
			
		
	
		
			
				
					|  |  |  |             connections.connect(host=self._host, port=self._port) | 
			
		
	
		
			
				
					|  |  |  |              | 
			
		
	
		
			
				
					|  |  |  | 
 | 
			
		
	
		
			
				
					|  |  |  |     def disconnect(self): | 
			
		
	
		
			
				
					|  |  |  |         connections.disconnect('default') | 
			
		
	
	
		
			
				
					|  |  | 
 |