logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

40 lines
1.1 KiB

import os
import pickle
import faiss
from towhee.operator import PyOperator, SharedType
class Meta:
    """Container that stores payloads under auto-incrementing integer ids.

    Ids start at 1 and grow by one with every call to :meth:`add`.
    The attribute names ``_id`` and ``_data`` are kept as-is so that
    previously pickled instances remain loadable.
    """

    def __init__(self):
        # Last id handed out; 0 means nothing has been added yet.
        self._id = 0
        # Maps id -> payload passed to add().
        self._data = {}

    def add(self, data):
        """Store ``data`` under the next free id and return that id."""
        self._id += 1
        self._data[self._id] = data
        return self._id
class FaissSearch(PyOperator):
    """
    Search for embedding vectors in Faiss.
    Only support the data insert by ops.ann_insert.faiss_index()

    Args:
        data_dir: Directory that contains the ``faiss.index`` and
            ``meta.bin`` files.
        top_k: Number of nearest neighbours requested per query.
    """

    def __init__(self, data_dir: str, top_k: int = 5):
        # NOTE(review): assumes PyOperator.__init__ takes no arguments -- confirm.
        super().__init__()
        self._top_k = top_k
        self._index_file = os.path.join(data_dir, 'faiss.index')
        self._meta_file = os.path.join(data_dir, 'meta.bin')
        # read_index expects the path of the serialized index on disk.
        self._index = faiss.read_index(self._index_file)
        # SECURITY: pickle.load can execute arbitrary code while loading;
        # only point data_dir at directories whose contents are trusted.
        with open(self._meta_file, 'rb') as f:
            self._meta = pickle.load(f)

    def __call__(self, query: 'ndarray'):
        """Return the nearest neighbours of ``query``.

        Args:
            query: Embedding vector; it is reshaped to a single row
                before being handed to the index.

        Returns:
            A list of ``([id], distance, meta)`` tuples in the order
            produced by the index search. Slots that the index filled
            with a negative id (no neighbour found) are skipped, so the
            list may hold fewer than ``top_k`` entries.
        """
        np_data = query.reshape(1, -1)
        dist, ids = self._index.search(np_data, self._top_k)

        # The unpickled metadata is presumably either an object exposing
        # its id -> payload mapping as ``_data`` or that mapping itself
        # (verify against the writer side); support both.
        records = getattr(self._meta, '_data', self._meta)

        results = []
        for raw_id, distance in zip(ids[0], dist[0]):
            hit_id = int(raw_id)
            if hit_id < 0:
                # Faiss pads missing results with -1; there is no
                # metadata to look up for those slots.
                continue
            # Each id is reported as a one-element list.
            results.append(([hit_id], distance, records[hit_id]))
        return results

    @property
    def shared_type(self):
        """Sharing mode reported for this operator: ``SharedType.Shareable``."""
        return SharedType.Shareable