From 430ef4ffc28c4a02bae7d44fa362a313bf74325f Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Fri, 28 Apr 2023 17:58:42 +0800 Subject: [PATCH] update Signed-off-by: junjie.jiang --- faiss_search.py | 31 +++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/faiss_search.py b/faiss_search.py index c5b001e..469ab82 100644 --- a/faiss_search.py +++ b/faiss_search.py @@ -1,19 +1,31 @@ import os -import pickle +import json import faiss from towhee.operator import PyOperator, SharedType - class Meta: - def __init__(self): - self._id = 0 - self._data = {} + def __init__(self, id=0, data=None): + self._id = id + self._data = data if data is not None else {} def add(self, data): self._id += 1 self._data[self._id] = data return self._id + def get(self, i): + return self._data[i] + + @classmethod + def load(cls, f_path): + with open(f_path, 'r') as f: + data = json.load(f) + return cls(data['id'], data['data']) + + def save(self, f_path): + with open(f_path, 'w') as f: + json.dump({'id': self._id, 'data': self._data}, f) + class FaissSearch(PyOperator): """ @@ -24,15 +36,14 @@ class FaissSearch(PyOperator): self._top_k = top_k self._index_file = os.path.join(data_dir, 'faiss.index') self._meta_file = os.path.join(data_dir, 'meta.bin') - self._index = faiss.read_index(self._index) - with open(self._meta_file, 'rb') as f: - self._meta = pickle.load(f) + self._index = faiss.read_index(self._index_file) + self._meta = Meta.load(self._meta_file) def __call__(self, query: 'ndarray'): np_data = query.reshape(1, -1) dist, ids = self._index.search(np_data, self._top_k) - ids = [[int(i)] for i in ids[0]] - meta_data = [self._meta_file[int(i)] for i in ids] + ids = [int(i) for i in ids[0]] + meta_data = [self._meta.get(str(i)) for i in ids] return list(zip(ids, dist[0], meta_data)) @property