logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

60 lines
2.0 KiB

import numpy as np
from pathlib import Path
import faiss
from towhee import register
from towhee.utils.faiss_utils import KVStorage
from towhee.functional.entity import Entity
@register(output_schema=['result'])
class Faiss:
    """
    Search for embedding vectors in Faiss. The index must already contain data before
    searching, refer to DataCollection Mixin `to_faiss`.

    Args:
        findex (`str` or `faiss.INDEX`):
            The path to a faiss index file, or an already loaded faiss index.
            When a path is given, the index is read from it and a companion key/value
            file named `<stem>_kv<suffix>` next to it (e.g. `faiss_kv.index` for
            `faiss.index`) is opened if it exists.
        kwargs
            The kwargs passed to index.search, refer to
            https://github.com/facebookresearch/faiss/wiki. The parameter `k` defaults to 10.

    Examples:

    >>> import towhee
    >>> res = (
    ...     towhee.glob['path']('./*.jpg')
    ...         .image_decode['path', 'img']()
    ...         .image_embedding.timm['img', 'vec'](model_name='resnet50')
    ...         .faiss_search['vec', 'results'](findex='./faiss/faiss.index')
    ...         .to_list()
    ... )
    [<Entity dict_keys(['path', 'img', 'vec', 'results'])>,
     <Entity dict_keys(['path', 'img', 'vec', 'results'])>]
    """

    def __init__(self, findex, **kwargs):
        self.faiss_index = findex
        self.kwargs = kwargs
        # Stays None when no key/value file is found; ids are then returned as-is.
        self.kv_storage = None
        if isinstance(findex, str):
            self.faiss_index = faiss.read_index(str(Path(findex)))
            kv_file = self._find_kv_file(findex)
            if kv_file is not None:
                self.kv_storage = KVStorage(kv_file)

    @staticmethod
    def _find_kv_file(findex):
        """Return the path of the key/value file that accompanies `findex`, or None."""
        index_path = Path(findex)
        # Only the file name is rewritten, so parent directories (including '..',
        # absolute prefixes and directory names that contain dots) are preserved.
        sibling = index_path.with_name(index_path.stem + '_kv' + index_path.suffix)
        # Legacy derivation: str.strip('./') removes any leading/trailing '.' and '/'
        # characters and replace() rewrites every dot. It is kept only as a fallback
        # so key/value files found under the old scheme are still picked up.
        legacy = Path(findex.strip('./').replace('.', '_kv.'))
        for candidate in (sibling, legacy):
            # Never mistake the index file itself for its key/value store
            # (the legacy derivation yields the index path when it has no dot).
            if candidate != index_path and candidate.exists():
                return str(candidate)
        return None

    def __call__(self, query: list):
        """
        Search the index for the neighbours of a single `query` vector.

        Returns:
            A list of Entity objects with fields `key` (the value looked up in the
            key/value storage when one is loaded, otherwise the raw faiss id) and
            `score` (the value faiss reports for that hit).
        """
        self.kwargs.setdefault('k', 10)
        # faiss searches a batch of vectors, so wrap the single query in a batch of one.
        scores, ids = self.faiss_index.search(np.array([query]), **self.kwargs)
        result = []
        # NOTE(review): per the faiss docs, missing neighbours are reported with id -1;
        # such ids are passed through unchanged here.
        for idx, score in zip(ids[0].tolist(), scores[0]):
            key = idx if self.kv_storage is None else self.kv_storage.get(idx)
            result.append(Entity(key=key, score=score))
        return result