import os
import json

import faiss
from towhee.operator import PyOperator, SharedType


class Meta:
    """Id -> metadata mapping persisted as a JSON file.

    Ids are handed out sequentially by :meth:`add`, each one greater than the
    previous counter value.
    """

    def __init__(self, id=0, data=None):
        # ``id`` is the last id handed out; ``data`` maps ids to metadata.
        self._id = id
        self._data = data if data is not None else {}

    def add(self, data):
        """Store ``data`` under the next sequential id and return that id."""
        self._id += 1
        self._data[self._id] = data
        return self._id

    def get(self, i):
        """Return the metadata stored under key ``i``.

        Note: JSON object keys are always strings, so for an instance built
        by :meth:`load` the keys are ``str`` and callers must pass ``str(i)``.

        Raises:
            KeyError: if ``i`` is not a stored key.
        """
        return self._data[i]

    @classmethod
    def load(cls, f_path):
        """Build a ``Meta`` from a JSON file previously written by :meth:`save`."""
        # Explicit encoding so reading does not depend on the platform locale.
        with open(f_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls(data['id'], data['data'])

    def save(self, f_path):
        """Write the id counter and the id -> metadata mapping to ``f_path`` as JSON."""
        with open(f_path, 'w', encoding='utf-8') as f:
            json.dump({'id': self._id, 'data': self._data}, f)


class FaissSearch(PyOperator):
    """
    Search for embedding vectors in Faiss. Only supports data inserted by ops.ann_insert.faiss_index().

    Args:
        data_dir: directory containing ``faiss.index`` and ``meta.bin``.
        top_k: maximum number of neighbours returned per query.
    """

    def __init__(self, data_dir: str, top_k: int = 5):
        self._top_k = top_k
        self._index_file = os.path.join(data_dir, 'faiss.index')
        self._meta_file = os.path.join(data_dir, 'meta.bin')
        self._index = faiss.read_index(self._index_file)
        self._meta = Meta.load(self._meta_file)

    def __call__(self, query: 'ndarray'):
        """Search the index for ``query``.

        ``query`` is reshaped into a single row before searching.

        Returns:
            A list of up to ``top_k`` ``(id, distance, meta)`` tuples, in the
            order returned by the index. Fewer tuples are returned when the
            index does not have ``top_k`` results for this query.
        """
        np_data = query.reshape(1, -1)
        dist, ids = self._index.search(np_data, self._top_k)
        results = []
        for raw_id, distance in zip(ids[0], dist[0]):
            idx = int(raw_id)
            # Faiss pads unfilled result slots with id -1 (e.g. when the index
            # holds fewer than top_k vectors); they have no metadata entry.
            if idx < 0:
                continue
            # Meta is loaded from JSON, so its keys are strings.
            results.append((idx, distance, self._meta.get(str(idx))))
        return results

    @property
    def shared_type(self):
        return SharedType.Shareable