from pathlib import Path

import faiss
import numpy as np

from towhee import register
from towhee.operator import PyOperator, SharedType
from towhee.utils.thirdparty.faiss_utils import KVStorage


@register(output_schema=['result'])
class Faiss(PyOperator):
    """
    Search for embedding vectors in Faiss. Note that the index must already hold
    data before searching; refer to the DataCollection mixin `to_faiss`.

    Args:
        findex (`str`, `pathlib.Path` or `faiss.Index`):
            The path to a faiss index file (defaults to './index.bin'), or an
            already-constructed faiss index object.
        kwargs
            Keyword arguments forwarded to `index.search`, refer to
            https://github.com/facebookresearch/faiss/wiki. The parameter `k`
            defaults to 10.
    """

    def __init__(self, findex='./index.bin', **kwargs):
        self.faiss_index = findex
        self.kwargs = kwargs
        # Number of neighbours requested from faiss unless the caller chose one.
        self.kwargs.setdefault('k', 10)
        self.kv_storage = None
        if isinstance(findex, (str, Path)):
            index_path = Path(findex)
            self.faiss_index = faiss.read_index(str(index_path))
            kv_file = self._find_kv_file(str(findex))
            if kv_file is not None:
                # Maps faiss integer ids back to the user's original keys.
                self.kv_storage = KVStorage(str(kv_file))

    @staticmethod
    def _find_kv_file(findex):
        """Return the key-value storage file paired with index file `findex`, or None.

        The legacy name `findex.strip('./').replace('.', '_kv.')` is tried first so
        that files written under that convention keep loading. `str.strip` removes
        any leading/trailing '.' or '/' characters (it is not a prefix strip), and
        `replace` rewrites every dot in the path, so for absolute paths or dotted
        directory names the legacy name is not a sibling of the index. The proper
        sibling `<stem>_kv<suffix>` is therefore tried as a fallback.
        """
        index_path = Path(findex)
        legacy = Path(findex.strip('./').replace('.', '_kv.'))
        sibling = index_path.with_name(index_path.stem + '_kv' + index_path.suffix)
        for candidate in (legacy, sibling):
            # For a suffix-less path the legacy rule yields the index path itself;
            # never open the faiss index file as key-value storage.
            if candidate != index_path and candidate.exists():
                return candidate
        return None

    def __call__(self, query: list):
        """Search the index for the nearest neighbours of one query vector.

        Args:
            query: A single embedding vector.

        Returns:
            A list of `[key, score]` pairs in the order returned by faiss. `key` is
            the value stored in the key-value storage for the faiss id when such a
            storage was loaded, otherwise the raw faiss id. Faiss pads its result
            with id -1 when fewer than `k` neighbours exist; those placeholders are
            skipped, so fewer than `k` pairs may be returned.
        """
        # faiss searches a batch of vectors; wrap the single query as a batch of one.
        batch = np.array([query])
        scores, ids = self.faiss_index.search(batch, **self.kwargs)
        result = []
        for faiss_id, score in zip(ids[0].tolist(), scores[0]):
            if faiss_id < 0:
                continue
            if self.kv_storage is not None:
                key = self.kv_storage.get(faiss_id)
            else:
                key = faiss_id
            result.append([key, score])
        return result

    @property
    def shared_type(self):
        return SharedType.NotShareable