logo
Browse Source

update

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
f0e84bdffa
  1. 12
      README.md
  2. 6
      __init__.py
  3. 53
      faiss.py
  4. 40
      faiss_search.py
  5. 1
      requirements.txt

12
README.md

@ -26,7 +26,7 @@ Search embedding in [Faiss](https://github.com/facebookresearch/faiss), **please
from towhee.dc2 import pipe, ops from towhee.dc2 import pipe, ops
p = pipe.input('vec') \ p = pipe.input('vec') \
.flat_map('vec', 'rows', ops.ann_search.faiss(findex='index.bin')) \
.flat_map('vec', 'rows', ops.ann_search.faiss('./data_dir', 5)) \
.map('rows', ('id', 'score'), lambda x: (x[0], x[1])) \ .map('rows', ('id', 'score'), lambda x: (x[0], x[1])) \
.output('id', 'score') .output('id', 'score')
@ -41,7 +41,7 @@ p(<your-vector>)
Create the operator via the following factory method: Create the operator via the following factory method:
***ann_search.faiss(findex)***
***ops.ann_search.faiss_index('./data_dir', 5)***
<br /> <br />
@ -49,9 +49,9 @@ Create the operator via the following factory method:
**Parameters:** **Parameters:**
***findex:*** *str* or *faiss.INDEX*
***data_dir:*** *str*
The path to faiss index file or faiss index.
The path to faiss index and meta data.
<br /> <br />
@ -62,9 +62,9 @@ The path to faiss index file or faiss index.
**Parameters:** **Parameters:**
***query:*** *list*
***query:*** *ndarray*
Query embeddings in Faiss
Query embedding in Faiss
<br /> <br />

6
__init__.py

@ -1,4 +1,4 @@
from .faiss import Faiss
from .faiss_search import FaissSearch
def faiss(*args, **kwargs):
return Faiss(*args, **kwargs)
def faiss_index(*args, **kwargs):
return FaissSearch(*args, **kwargs)

53
faiss.py

@ -1,53 +0,0 @@
import numpy as np
from pathlib import Path
import faiss
from towhee import register
from towhee.operator import PyOperator, SharedType
from towhee.utils.thirdparty.faiss_utils import KVStorage
# from towhee.functional.entity import Entity
@register(output_schema=['result'])
class Faiss(PyOperator):
"""
Search for embedding vectors in Faiss. Note that the index has data before searching,
refer to DataCollection Mixin `to_faiss`.
Args:
findex (`str` or `faiss.INDEX`):
The path to faiss index file(defaults to './index.bin') or faiss index.
kwargs
The kwargs with index.search, refer to https://github.com/facebookresearch/faiss/wiki. And the parameter `k` defaults to 10.
"""
def __init__(self, findex, **kwargs):
self.faiss_index = findex
self.kwargs = kwargs
self.kv_storage = None
if isinstance(findex, str):
kv_file = findex.strip('./').replace('.', '_kv.')
index_file = Path(findex)
self.faiss_index = faiss.read_index(str(index_file))
if Path(kv_file).exists():
self.kv_storage = KVStorage(kv_file)
def __call__(self, query: list):
if 'k' not in self.kwargs:
self.kwargs['k'] = 10
query = np.array([query])
scores, ids = self.faiss_index.search(query, **self.kwargs)
ids = ids[0].tolist()
result = []
for i in range(len(ids)):
if self.kv_storage is not None:
k = self.kv_storage.get(ids[i])
else:
k = ids[i]
result.append([k, scores[0][i]])
# result.append(Entity(**{'key': k, 'score': scores[0][i]}))
return result
@property
def shared_type(self):
return SharedType.NotShareable

40
faiss_search.py

@ -0,0 +1,40 @@
import os
import pickle
import faiss
from towhee.operator import PyOperator, SharedType
class Meta:
def __init__(self):
self._id = 0
self._data = {}
def add(self, data):
self._id += 1
self._data[self._id] = data
return self._id
class FaissSearch(PyOperator):
"""
Search for embedding vectors in Faiss.
Only support the data insert by ops.ann_insert.faiss_index()
"""
def __init__(self, data_dir: str, top_k: int = 5):
self._top_k = top_k
self._index_file = os.path.join(data_dir, 'faiss.index')
self._meta_file = os.path.join(data_dir, 'meta.bin')
self._index = faiss.read_index(self._index)
with open(self._meta_file, 'rb') as f:
self._meta = pickle.load(f)
def __call__(self, query: 'ndarray'):
np_data = query.reshape(1, -1)
dist, ids = self._index.search(np_data, self._top_k)
ids = [[int(i)] for i in ids[0]]
meta_data = [self._meta_file[int(i)] for i in ids]
return list(zip(ids, dist[0], meta_data))
@property
def shared_type(self):
return SharedType.Shareable

1
requirements.txt

@ -1,3 +1,2 @@
faiss-cpu faiss-cpu
numpy numpy
towhee
Loading…
Cancel
Save