faiss-index
copied
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
61 lines
2.0 KiB
61 lines
2.0 KiB
3 years ago
|
import numpy as np
|
||
|
from pathlib import Path
|
||
|
import faiss
|
||
|
from towhee import register
|
||
|
from towhee.utils.faiss_utils import KVStorage
|
||
|
from towhee.functional.entity import Entity
|
||
|
|
||
|
|
||
|
@register(output_schema=['result'])
class Faiss:
    """
    Search for embedding vectors in Faiss. Note that the index must already contain data
    before searching, refer to DataCollection Mixin `to_faiss`.

    Args:
        findex (`str` or `faiss.INDEX`):
            The path to a faiss index file, or an already loaded faiss index.
            When a path is given, a key/value sidecar file named `<stem>_kv<suffix>`
            is looked up as well and, if present, used to translate ids into keys.
        kwargs
            The kwargs with index.search, refer to https://github.com/facebookresearch/faiss/wiki. And the parameter `k` defaults to 10.

    Returns:
        A list of `Entity` objects, each with a `key` (the stored key when a sidecar
        file was found, otherwise the raw faiss id) and a `score`.

    Examples:

    >>> import towhee
    >>> res = (
    ...     towhee.glob['path']('./*.jpg')
    ...           .image_decode['path', 'img']()
    ...           .image_embedding.timm['img', 'vec'](model_name='resnet50')
    ...           .faiss_search['vec', 'results'](findex='./faiss/faiss.index')
    ...           .to_list()
    ... )
    [<Entity dict_keys(['path', 'img', 'vec', 'results'])>,
     <Entity dict_keys(['path', 'img', 'vec', 'results'])>]
    """

    def __init__(self, findex, **kwargs):
        self.faiss_index = findex
        self.kwargs = kwargs
        # Resolve the default once instead of re-checking on every call.
        self.kwargs.setdefault('k', 10)
        self.kv_storage = None
        if isinstance(findex, str):
            self.faiss_index = faiss.read_index(str(Path(findex)))
            kv_file = self._find_kv_file(findex)
            if kv_file is not None:
                self.kv_storage = KVStorage(str(kv_file))

    @staticmethod
    def _find_kv_file(findex):
        """Return the path of the key/value sidecar for index path `findex`, or None."""
        index_file = Path(findex)
        candidates = []

        # Previous naming scheme, kept first so sidecars that were found before are
        # still found. It strips any leading/trailing '.' and '/' characters and
        # rewrites every dot, so it is only right for simple './dir/name.ext' paths.
        legacy = Path(findex.strip('./').replace('.', '_kv.'))
        if legacy != index_file:
            # Without a dot in the name the legacy rule yields the index file itself.
            candidates.append(legacy)

        # Sidecar next to the index file: '<stem>_kv<suffix>' in the same directory.
        # This also works for absolute paths, '../' paths and dotted directory names.
        candidates.append(index_file.with_name(index_file.stem + '_kv' + index_file.suffix))

        for candidate in candidates:
            if candidate.exists():
                return candidate
        return None

    def __call__(self, query: list):
        # faiss searches batches, so wrap the single query vector in a batch of one.
        batch = np.array([query])
        scores, ids = self.faiss_index.search(batch, **self.kwargs)

        result = []
        for idx, score in zip(ids[0].tolist(), scores[0]):
            if idx == -1:
                # faiss pads the row with -1 when fewer than k neighbours exist.
                continue
            key = idx if self.kv_storage is None else self.kv_storage.get(idx)
            result.append(Entity(key=key, score=score))
        return result
|