From f0e84bdffa75dcecec5a1bce5b6ebd8fca44be55 Mon Sep 17 00:00:00 2001 From: "junjie.jiang" Date: Fri, 28 Apr 2023 16:05:56 +0800 Subject: [PATCH] update Signed-off-by: junjie.jiang --- README.md | 12 +++++------ __init__.py | 6 +++--- faiss.py | 53 ------------------------------------------------ faiss_search.py | 40 ++++++++++++++++++++++++++++++++++++ requirements.txt | 1 - 5 files changed, 49 insertions(+), 63 deletions(-) delete mode 100644 faiss.py create mode 100644 faiss_search.py diff --git a/README.md b/README.md index fe6f087..d3f8bc1 100644 --- a/README.md +++ b/README.md @@ -26,7 +26,7 @@ Search embedding in [Faiss](https://github.com/facebookresearch/faiss), **please from towhee.dc2 import pipe, ops p = pipe.input('vec') \ - .flat_map('vec', 'rows', ops.ann_search.faiss(findex='index.bin')) \ + .flat_map('vec', 'rows', ops.ann_search.faiss_index('./data_dir', 5)) \ .map('rows', ('id', 'score'), lambda x: (x[0], x[1])) \ .output('id', 'score') @@ -41,7 +41,7 @@ p() Create the operator via the following factory method: -***ann_search.faiss(findex)*** +***ops.ann_search.faiss_index('./data_dir', 5)***
@@ -49,9 +49,9 @@ Create the operator via the following factory method: **Parameters:** -***findex:*** *str* or *faiss.INDEX* +***data_dir:*** *str* -The path to faiss index file or faiss index. +The path to the directory that contains the Faiss index and meta data.
@@ -62,9 +62,9 @@ The path to faiss index file or faiss index. **Parameters:** -***query:*** *list* +***query:*** *ndarray* -Query embeddings in Faiss +Query embedding in Faiss
diff --git a/__init__.py b/__init__.py
index 684db96..ac2ff80 100644
--- a/__init__.py
+++ b/__init__.py
@@ -1,4 +1,4 @@
-from .faiss import Faiss
+from .faiss_search import FaissSearch
 
-def faiss(*args, **kwargs):
-    return Faiss(*args, **kwargs)
\ No newline at end of file
+def faiss_index(*args, **kwargs):
+    return FaissSearch(*args, **kwargs)
diff --git a/faiss.py b/faiss.py
deleted file mode 100644
index eeac3fe..0000000
--- a/faiss.py
+++ /dev/null
@@ -1,53 +0,0 @@
-import numpy as np
-from pathlib import Path
-import faiss
-from towhee import register
-from towhee.operator import PyOperator, SharedType
-from towhee.utils.thirdparty.faiss_utils import KVStorage
-# from towhee.functional.entity import Entity
-
-
-@register(output_schema=['result'])
-class Faiss(PyOperator):
-    """
-    Search for embedding vectors in Faiss. Note that the index has data before searching,
-    refer to DataCollection Mixin `to_faiss`.
-
-    Args:
-        findex (`str` or `faiss.INDEX`):
-            The path to faiss index file(defaults to './index.bin') or faiss index.
-        kwargs
-            The kwargs with index.search, refer to https://github.com/facebookresearch/faiss/wiki. And the parameter `k` defaults to 10.
-    """
-    def __init__(self, findex, **kwargs):
-        self.faiss_index = findex
-        self.kwargs = kwargs
-        self.kv_storage = None
-        if isinstance(findex, str):
-            kv_file = findex.strip('./').replace('.', '_kv.')
-            index_file = Path(findex)
-            self.faiss_index = faiss.read_index(str(index_file))
-            if Path(kv_file).exists():
-                self.kv_storage = KVStorage(kv_file)
-
-    def __call__(self, query: list):
-        if 'k' not in self.kwargs:
-            self.kwargs['k'] = 10
-
-        query = np.array([query])
-        scores, ids = self.faiss_index.search(query, **self.kwargs)
-
-        ids = ids[0].tolist()
-        result = []
-        for i in range(len(ids)):
-            if self.kv_storage is not None:
-                k = self.kv_storage.get(ids[i])
-            else:
-                k = ids[i]
-            result.append([k, scores[0][i]])
-            # result.append(Entity(**{'key': k, 'score': scores[0][i]}))
-        return result
-
-    @property
-    def shared_type(self):
-        return SharedType.NotShareable
diff --git a/faiss_search.py b/faiss_search.py
new file mode 100644
index 0000000..c5b001e
--- /dev/null
+++ b/faiss_search.py
@@ -0,0 +1,58 @@
+import os
+import pickle
+import faiss
+from towhee.operator import PyOperator, SharedType
+
+
+class Meta:
+    """
+    Auto-incrementing id -> data store.
+    """
+    def __init__(self):
+        self._id = 0
+        self._data = {}
+
+    def add(self, data):
+        """
+        Store `data` under the next id (ids start at 1) and return that id.
+        """
+        self._id += 1
+        self._data[self._id] = data
+        return self._id
+
+
+class FaissSearch(PyOperator):
+    """
+    Search for embedding vectors in Faiss.
+    Only supports data inserted by ops.ann_insert.faiss_index().
+
+    Args:
+        data_dir (`str`):
+            Directory that holds the index file `faiss.index` and the pickled meta file `meta.bin`.
+        top_k (`int`):
+            Number of nearest neighbours requested per query, defaults to 5.
+    """
+    def __init__(self, data_dir: str, top_k: int = 5):
+        self._top_k = top_k
+        self._index_file = os.path.join(data_dir, 'faiss.index')
+        self._meta_file = os.path.join(data_dir, 'meta.bin')
+        # Read from the index path; `self._index` does not exist until this call returns.
+        self._index = faiss.read_index(self._index_file)
+        # NOTE(review): pickle.load can run arbitrary code; only use a trusted data_dir.
+        with open(self._meta_file, 'rb') as f:
+            self._meta = pickle.load(f)
+
+    def __call__(self, query: 'ndarray'):
+        """
+        Return a list of (id, distance, meta) tuples for one query embedding.
+        """
+        np_data = query.reshape(1, -1)
+        dist, ids = self._index.search(np_data, self._top_k)
+        # Faiss pads the result with id -1 when fewer than top_k neighbours exist.
+        hits = [(int(i), d) for i, d in zip(ids[0], dist[0]) if i != -1]
+        # NOTE(review): assumes the unpickled meta object supports lookup by id -- confirm.
+        return [(i, d, self._meta[i]) for i, d in hits]
+
+    @property
+    def shared_type(self):
+        return SharedType.Shareable
diff --git a/requirements.txt b/requirements.txt
index 2b48261..c3ff7f0 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,3 +1,2 @@
 faiss-cpu
 numpy
-towhee
\ No newline at end of file