faiss-index
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
51 lines
1.5 KiB
51 lines
1.5 KiB
import os
|
|
import json
|
|
import faiss
|
|
from towhee.operator import PyOperator, SharedType
|
|
|
|
class Meta:
    """Sequential id -> payload store, persisted to disk as JSON.

    ``add`` hands out ids 1, 2, 3, ... (continuing from the ``id`` passed to
    the constructor). Keys are held internally as ``str`` because JSON object
    keys are always strings; normalizing them means ``get(1)`` and
    ``get('1')`` behave the same both before and after a ``save``/``load``
    round trip.
    """

    def __init__(self, id=0, data=None):
        """Create a store.

        Args:
            id: The last id already handed out; the next ``add`` returns
                ``id + 1``. (The name shadows the builtin but is kept so that
                keyword callers keep working.)
            data: Optional initial mapping of id -> payload. It is copied with
                its keys converted to ``str``, so later ``add`` calls never
                mutate the caller's dict.
        """
        self._id = id
        # str keys: JSON round trips turn int keys into str, so store them that
        # way from the start to keep lookups consistent.
        self._data = {str(k): v for k, v in data.items()} if data is not None else {}

    def add(self, data):
        """Store ``data`` under the next sequential id and return that id (an int)."""
        self._id += 1
        self._data[str(self._id)] = data
        return self._id

    def get(self, i):
        """Return the payload stored under id ``i`` (int or str form).

        Raises:
            KeyError: if nothing is stored under ``i``.
        """
        return self._data[str(i)]

    @classmethod
    def load(cls, f_path):
        """Build a store from a JSON file of the form ``{"id": ..., "data": {...}}``."""
        with open(f_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
        return cls(data['id'], data['data'])

    def save(self, f_path):
        """Write the store to ``f_path`` as ``{"id": ..., "data": {...}}`` JSON."""
        with open(f_path, 'w', encoding='utf-8') as f:
            json.dump({'id': self._id, 'data': self._data}, f)
|
|
|
|
|
|
class FaissSearch(PyOperator):
    """
    Search for embedding vectors in a Faiss index.

    Only supports data inserted by ``ops.ann_insert.faiss_index()``: this
    operator reads ``faiss.index`` (via ``faiss.read_index``) and ``meta.bin``
    (a JSON file loaded by ``Meta.load``) from ``data_dir``.

    Args:
        data_dir: Directory containing ``faiss.index`` and ``meta.bin``.
        top_k: Maximum number of neighbours returned per query.
    """

    def __init__(self, data_dir: str, top_k: int = 5):
        self._top_k = top_k
        self._index_file = os.path.join(data_dir, 'faiss.index')
        self._meta_file = os.path.join(data_dir, 'meta.bin')
        self._index = faiss.read_index(self._index_file)
        self._meta = Meta.load(self._meta_file)

    def __call__(self, query: 'ndarray'):
        """Return up to ``top_k`` ``(id, distance, meta)`` tuples for ``query``.

        ``query`` is reshaped into a single row vector before searching.
        Faiss pads the labels with ``-1`` when the index holds fewer than
        ``top_k`` vectors. Those padding slots are skipped rather than looked
        up in the metadata, where they would raise ``KeyError``.
        """
        np_data = query.reshape(1, -1)
        dist, ids = self._index.search(np_data, self._top_k)
        results = []
        for label, distance in zip(ids[0], dist[0]):
            label = int(label)
            if label < 0:  # padding: fewer than top_k hits available
                continue
            results.append((label, distance, self._meta.get(str(label))))
        return results

    @property
    def shared_type(self):
        """Report ``SharedType.Shareable`` to the towhee runtime."""
        return SharedType.Shareable
|