logo
Browse Source

update

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
430ef4ffc2
  1. 31
      faiss_search.py

31
faiss_search.py

@ -1,19 +1,31 @@
import os
import pickle
import json
import faiss
from towhee.operator import PyOperator, SharedType
class Meta:
def __init__(self):
self._id = 0
self._data = {}
def __init__(self, id=0, data=None):
self._id = id
self._data = data if data is not None else {}
def add(self, data):
self._id += 1
self._data[self._id] = data
return self._id
def get(self, i):
return self._data[i]
@classmethod
def load(cls, f_path):
with open(f_path, 'r') as f:
data = json.load(f)
return cls(data['id'], data['data'])
def save(self, f_path):
with open(f_path, 'w') as f:
json.dump({'id': self._id, 'data': self._data}, f)
class FaissSearch(PyOperator):
"""
@ -24,15 +36,14 @@ class FaissSearch(PyOperator):
self._top_k = top_k
self._index_file = os.path.join(data_dir, 'faiss.index')
self._meta_file = os.path.join(data_dir, 'meta.bin')
self._index = faiss.read_index(self._index)
with open(self._meta_file, 'rb') as f:
self._meta = pickle.load(f)
self._index = faiss.read_index(self._index_file)
self._meta = Meta.load(self._meta_file)
def __call__(self, query: 'ndarray'):
np_data = query.reshape(1, -1)
dist, ids = self._index.search(np_data, self._top_k)
ids = [[int(i)] for i in ids[0]]
meta_data = [self._meta_file[int(i)] for i in ids]
ids = [int(i) for i in ids[0]]
meta_data = [self._meta.get(str(i)) for i in ids]
return list(zip(ids, dist[0], meta_data))
@property

Loading…
Cancel
Save