faiss-index
copied
5 changed files with 49 additions and 63 deletions
@ -1,4 +1,4 @@ |
|||||
from .faiss import Faiss |
|
||||
|
from .faiss_search import FaissSearch |
||||
|
|
||||
def faiss(*args, **kwargs):
    """Create a ``Faiss`` operator, forwarding all arguments to its constructor."""
    operator = Faiss(*args, **kwargs)
    return operator
|
||||
|
def faiss_index(*args, **kwargs):
    """Create a ``FaissSearch`` operator, forwarding all arguments to its constructor."""
    operator = FaissSearch(*args, **kwargs)
    return operator
||||
|
@ -1,53 +0,0 @@ |
|||||
import numpy as np |
|
||||
from pathlib import Path |
|
||||
import faiss |
|
||||
from towhee import register |
|
||||
from towhee.operator import PyOperator, SharedType |
|
||||
from towhee.utils.thirdparty.faiss_utils import KVStorage |
|
||||
# from towhee.functional.entity import Entity |
|
||||
|
|
||||
|
|
||||
@register(output_schema=['result'])
class Faiss(PyOperator):
    """
    Search for embedding vectors in Faiss. Note that the index must already contain
    data before searching, refer to DataCollection Mixin `to_faiss`.

    Args:
        findex (`str` or `faiss.INDEX`):
            The path to a faiss index file (e.g. './index.bin') or a faiss index object.
            When a path is given, a key-value file stored next to it
            ('<stem>_kv<suffix>', e.g. 'index_kv.bin') is loaded if it exists.
        kwargs
            The kwargs passed to index.search, refer to
            https://github.com/facebookresearch/faiss/wiki.
            The parameter `k` defaults to 10.
    """
    def __init__(self, findex, **kwargs):
        self.faiss_index = findex
        self.kwargs = kwargs
        self.kv_storage = None
        if isinstance(findex, str):
            self.faiss_index = faiss.read_index(str(Path(findex)))
            kv_file = self._find_kv_file(findex)
            if kv_file is not None:
                self.kv_storage = KVStorage(kv_file)

    @staticmethod
    def _find_kv_file(findex):
        """Return the path of the key-value file that belongs to ``findex``, or None."""
        index_file = Path(findex)
        candidates = [
            # Legacy derivation, tried first so every file the previous lookup
            # found is still found. It is unreliable for anything but
            # './name.ext': strip('./') removes a *character set* (so '../' and a
            # leading '/' are eaten) and replace() rewrites every dot in the path.
            findex.strip('./').replace('.', '_kv.'),
            # Proper sibling of the index file: same directory, '<stem>_kv<suffix>'.
            str(index_file.with_name(index_file.stem + '_kv' + index_file.suffix)),
        ]
        for candidate in candidates:
            candidate_path = Path(candidate)
            # For an extension-less index the legacy name is the index file
            # itself; never hand the index to KVStorage.
            if candidate_path != index_file and candidate_path.exists():
                return candidate
        return None

    def __call__(self, query: list):
        """
        Search the index for the neighbours of a single query vector.

        Args:
            query (`list`):
                One embedding vector.

        Returns:
            A list of `[key, score]` pairs in the order returned by index.search.
            `key` is the raw faiss id, or the value `kv_storage.get(id)` returns
            when a key-value file was loaded.
        """
        self.kwargs.setdefault('k', 10)

        # index.search expects a batch, so wrap the single vector in one row.
        query = np.array([query])
        scores, ids = self.faiss_index.search(query, **self.kwargs)

        result = []
        for idx, score in zip(ids[0].tolist(), scores[0]):
            key = idx if self.kv_storage is None else self.kv_storage.get(idx)
            result.append([key, score])
        return result

    @property
    def shared_type(self):
        return SharedType.NotShareable
|
@ -0,0 +1,40 @@ |
|||||
|
import os |
||||
|
import pickle |
||||
|
import faiss |
||||
|
from towhee.operator import PyOperator, SharedType |
||||
|
|
||||
|
|
||||
|
class Meta:
    """In-memory record table keyed by an auto-incrementing integer id."""

    def __init__(self):
        # Last id handed out; the first record added receives id 1.
        self._id = 0
        # Maps every assigned id to the record stored under it.
        self._data = {}

    def add(self, data):
        """Store ``data`` under the next free id and return that id."""
        next_id = self._id + 1
        self._id = next_id
        self._data[next_id] = data
        return next_id
||||
|
|
||||
|
|
||||
|
class FaissSearch(PyOperator):
    """
    Search for embedding vectors in Faiss.

    Only supports data inserted by ops.ann_insert.faiss_index().

    Args:
        data_dir (`str`):
            Directory containing the index file 'faiss.index' and the pickled
            metadata file 'meta.bin'.
        top_k (`int`):
            Number of nearest neighbours requested per query, defaults to 5.
    """
    def __init__(self, data_dir: str, top_k: int = 5):
        self._top_k = top_k
        self._index_file = os.path.join(data_dir, 'faiss.index')
        self._meta_file = os.path.join(data_dir, 'meta.bin')
        # Read from the index *path*; `self._index` only exists once this returns.
        self._index = faiss.read_index(self._index_file)
        # SECURITY: pickle.load can execute arbitrary code; `data_dir` must only
        # ever point at trusted data.
        with open(self._meta_file, 'rb') as f:
            self._meta = pickle.load(f)

    def __call__(self, query: 'ndarray'):
        """
        Search the index for the neighbours of a single query vector.

        Args:
            query (`ndarray`):
                One embedding vector; it is reshaped to a single-row batch.

        Returns:
            A list of `([id], distance, meta)` tuples in the order returned by
            index.search. Fewer than `top_k` tuples are returned when the index
            cannot supply that many neighbours.
        """
        np_data = query.reshape(1, -1)
        dist, ids = self._index.search(np_data, self._top_k)

        # The metadata is expected to be a `Meta` instance, whose records live in
        # `_data`; a plain mapping is accepted as well.
        # NOTE(review): confirm against what ops.ann_insert.faiss_index() pickles.
        records = getattr(self._meta, '_data', self._meta)

        result = []
        for raw_id, score in zip(ids[0], dist[0]):
            item_id = int(raw_id)
            if item_id < 0:
                # Faiss pads the result with -1 labels when it finds fewer than
                # k neighbours; those slots are not real hits.
                continue
            result.append(([item_id], score, records[item_id]))
        return result

    @property
    def shared_type(self):
        return SharedType.Shareable
@ -1,3 +1,2 @@ |
|||||
faiss-cpu |
faiss-cpu |
||||
numpy |
numpy |
||||
towhee |
|
Loading…
Reference in new issue