faiss-index
copied
5 changed files with 49 additions and 63 deletions
@ -1,4 +1,4 @@ |
|||
from .faiss import Faiss |
|||
from .faiss_search import FaissSearch |
|||
|
|||
def faiss(*args, **kwargs): |
|||
return Faiss(*args, **kwargs) |
|||
def faiss_index(*args, **kwargs): |
|||
return FaissSearch(*args, **kwargs) |
|||
|
@ -1,53 +0,0 @@ |
|||
import numpy as np |
|||
from pathlib import Path |
|||
import faiss |
|||
from towhee import register |
|||
from towhee.operator import PyOperator, SharedType |
|||
from towhee.utils.thirdparty.faiss_utils import KVStorage |
|||
# from towhee.functional.entity import Entity |
|||
|
|||
|
|||
@register(output_schema=['result']) |
|||
class Faiss(PyOperator): |
|||
""" |
|||
Search for embedding vectors in Faiss. Note that the index has data before searching, |
|||
refer to DataCollection Mixin `to_faiss`. |
|||
|
|||
Args: |
|||
findex (`str` or `faiss.INDEX`): |
|||
The path to faiss index file(defaults to './index.bin') or faiss index. |
|||
kwargs |
|||
The kwargs with index.search, refer to https://github.com/facebookresearch/faiss/wiki. And the parameter `k` defaults to 10. |
|||
""" |
|||
def __init__(self, findex, **kwargs): |
|||
self.faiss_index = findex |
|||
self.kwargs = kwargs |
|||
self.kv_storage = None |
|||
if isinstance(findex, str): |
|||
kv_file = findex.strip('./').replace('.', '_kv.') |
|||
index_file = Path(findex) |
|||
self.faiss_index = faiss.read_index(str(index_file)) |
|||
if Path(kv_file).exists(): |
|||
self.kv_storage = KVStorage(kv_file) |
|||
|
|||
def __call__(self, query: list): |
|||
if 'k' not in self.kwargs: |
|||
self.kwargs['k'] = 10 |
|||
|
|||
query = np.array([query]) |
|||
scores, ids = self.faiss_index.search(query, **self.kwargs) |
|||
|
|||
ids = ids[0].tolist() |
|||
result = [] |
|||
for i in range(len(ids)): |
|||
if self.kv_storage is not None: |
|||
k = self.kv_storage.get(ids[i]) |
|||
else: |
|||
k = ids[i] |
|||
result.append([k, scores[0][i]]) |
|||
# result.append(Entity(**{'key': k, 'score': scores[0][i]})) |
|||
return result |
|||
|
|||
@property |
|||
def shared_type(self): |
|||
return SharedType.NotShareable |
@ -0,0 +1,40 @@ |
|||
import os |
|||
import pickle |
|||
import faiss |
|||
from towhee.operator import PyOperator, SharedType |
|||
|
|||
|
|||
class Meta: |
|||
def __init__(self): |
|||
self._id = 0 |
|||
self._data = {} |
|||
|
|||
def add(self, data): |
|||
self._id += 1 |
|||
self._data[self._id] = data |
|||
return self._id |
|||
|
|||
|
|||
class FaissSearch(PyOperator): |
|||
""" |
|||
Search for embedding vectors in Faiss. |
|||
Only support the data insert by ops.ann_insert.faiss_index() |
|||
""" |
|||
def __init__(self, data_dir: str, top_k: int = 5): |
|||
self._top_k = top_k |
|||
self._index_file = os.path.join(data_dir, 'faiss.index') |
|||
self._meta_file = os.path.join(data_dir, 'meta.bin') |
|||
self._index = faiss.read_index(self._index) |
|||
with open(self._meta_file, 'rb') as f: |
|||
self._meta = pickle.load(f) |
|||
|
|||
def __call__(self, query: 'ndarray'): |
|||
np_data = query.reshape(1, -1) |
|||
dist, ids = self._index.search(np_data, self._top_k) |
|||
ids = [[int(i)] for i in ids[0]] |
|||
meta_data = [self._meta_file[int(i)] for i in ids] |
|||
return list(zip(ids, dist[0], meta_data)) |
|||
|
|||
@property |
|||
def shared_type(self): |
|||
return SharedType.Shareable |
@ -1,3 +1,2 @@ |
|||
faiss-cpu |
|||
numpy |
|||
towhee |
Loading…
Reference in new issue