logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

53 lines
1.8 KiB

import numpy as np
from pathlib import Path
import faiss
from towhee import register
from towhee.operator import PyOperator, SharedType
from towhee.utils.thirdparty.faiss_utils import KVStorage
# from towhee.functional.entity import Entity
@register(output_schema=['result'])
class Faiss(PyOperator):
    """
    Search for embedding vectors in Faiss.

    The index must already contain data before searching; refer to the
    DataCollection Mixin `to_faiss` for building one.

    Args:
        findex (`str`, `pathlib.Path` or `faiss.Index`):
            The path to a faiss index file (e.g. './index.bin') or an
            already-loaded faiss index object. When a path is given, a
            key-value file named like the index with '_kv' appended to the
            stem (e.g. './index_kv.bin') is loaded if it exists, and search
            hits are mapped through it.
        kwargs:
            Extra keyword arguments forwarded to `index.search`, refer to
            https://github.com/facebookresearch/faiss/wiki. The parameter `k`
            (number of results per query) defaults to 10.
    """

    def __init__(self, findex, **kwargs):
        self.faiss_index = findex
        self.kwargs = kwargs
        # Default number of neighbours; set once here rather than on every call.
        self.kwargs.setdefault('k', 10)
        self.kv_storage = None
        if isinstance(findex, (str, Path)):
            self.faiss_index = faiss.read_index(str(Path(findex)))
            kv_file = self._locate_kv_file(findex)
            if kv_file is not None:
                self.kv_storage = KVStorage(kv_file)

    @staticmethod
    def _locate_kv_file(findex):
        """
        Return the path (as `str`) of the KV file paired with index path `findex`, or None.

        The KV file is expected next to the index with '_kv' appended to the
        file stem, e.g. 'data/index.bin' -> 'data/index_kv.bin'.
        """
        index_path = Path(findex)
        candidates = [
            index_path.with_name(index_path.stem + '_kv' + index_path.suffix),
            # Legacy derivation from earlier versions of this operator, kept as
            # a fallback so previously written KV files are still found.
            # NOTE: str.strip('./') strips *characters*, so this form mangles
            # absolute and '../' paths; it is only a fallback.
            Path(str(findex).strip('./').replace('.', '_kv.')),
        ]
        for candidate in candidates:
            # Never treat the index file itself as the KV store (happens with
            # the legacy rule when the index name has no '.').
            if candidate != index_path and candidate.exists():
                return str(candidate)
        return None

    def __call__(self, query: list):
        """
        Search the index for the nearest neighbours of a single query vector.

        Args:
            query (`list`):
                One embedding vector as a 1-D sequence of numbers.

        Returns:
            A list of `[key, score]` pairs in the order faiss returns them.
            `key` is the value from the KV storage when one was loaded,
            otherwise the integer faiss id. Padding entries (id -1), which
            faiss emits when the index holds fewer than `k` vectors, are
            omitted.
        """
        # faiss expects a C-contiguous float32 matrix of shape (n_queries, dim).
        query = np.ascontiguousarray(np.asarray(query, dtype=np.float32).reshape(1, -1))
        scores, ids = self.faiss_index.search(query, **self.kwargs)
        result = []
        for idx, score in zip(ids[0].tolist(), scores[0]):
            if idx < 0:
                # -1 marks "no result" padding, not a real vector id.
                continue
            key = self.kv_storage.get(idx) if self.kv_storage is not None else idx
            result.append([key, score])
        return result

    @property
    def shared_type(self):
        # One operator instance must not be shared across pipelines.
        return SharedType.NotShareable