faiss-index
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
53 lines
1.8 KiB
53 lines
1.8 KiB
import numpy as np
|
|
from pathlib import Path
|
|
import faiss
|
|
from towhee import register
|
|
from towhee.operator import PyOperator, SharedType
|
|
from towhee.utils.thirdparty.faiss_utils import KVStorage
|
|
# from towhee.functional.entity import Entity
|
|
|
|
|
|
@register(output_schema=['result'])
class Faiss(PyOperator):
    """
    Search for embedding vectors in Faiss. Note that the index must already contain
    data before searching; refer to the DataCollection mixin `to_faiss`.

    Args:
        findex (`str` or `faiss.Index`):
            The path to a faiss index file (defaults to './index.bin'), or an
            already loaded faiss index object.
        kwargs:
            Extra keyword arguments forwarded to `index.search`, refer to
            https://github.com/facebookresearch/faiss/wiki. The parameter `k`
            (number of neighbors to return) defaults to 10.
    """

    def __init__(self, findex='./index.bin', **kwargs):
        self.faiss_index = findex
        self.kwargs = kwargs
        # Default the neighbor count once, instead of re-checking and mutating
        # kwargs on every call.
        self.kwargs.setdefault('k', 10)
        self.kv_storage = None

        if isinstance(findex, str):
            # NOTE(review): the sidecar name is derived exactly as before so it keeps
            # pairing with whatever writer produced it. Be aware of its quirks:
            # `strip('./')` removes *all* leading/trailing '.' and '/' characters (an
            # absolute path such as '/data/index.bin' becomes the relative
            # 'data/index.bin'), and `replace` rewrites *every* '.', not only the
            # extension's. Change it only together with the writer.
            kv_file = findex.strip('./').replace('.', '_kv.')
            self.faiss_index = faiss.read_index(str(Path(findex)))
            # The key-value sidecar is optional; without it raw faiss ids are returned.
            if Path(kv_file).exists():
                self.kv_storage = KVStorage(kv_file)

    def __call__(self, query: list):
        """
        Search the index for the nearest neighbors of a single query vector.

        Args:
            query: One embedding vector; its length must match the index dimension.

        Returns:
            A list of `[key, score]` pairs in the order returned by faiss. `key` is
            the value looked up in the key-value sidecar when one was loaded,
            otherwise the raw faiss id. Padding entries (id -1) are omitted, so fewer
            than `k` pairs may be returned.
        """
        # faiss searches float32 batches (uint8 for binary indexes). Cast explicitly
        # so float64 input, e.g. a list of Python floats, works on all faiss versions.
        dtype = np.uint8 if isinstance(self.faiss_index, faiss.IndexBinary) else np.float32
        batch = np.array([query], dtype=dtype)
        scores, ids = self.faiss_index.search(batch, **self.kwargs)

        result = []
        for faiss_id, score in zip(ids[0].tolist(), scores[0]):
            # faiss pads the output with id -1 when fewer than k neighbors exist;
            # those slots carry no real match.
            if faiss_id < 0:
                continue
            key = faiss_id if self.kv_storage is None else self.kv_storage.get(faiss_id)
            result.append([key, score])
        return result

    @property
    def shared_type(self):
        # Each pipeline gets its own operator instance (not shared).
        return SharedType.NotShareable
|