|
|
@ -1,19 +1,31 @@ |
|
|
|
import os |
|
|
|
import pickle |
|
|
|
import json |
|
|
|
import faiss |
|
|
|
from towhee.operator import PyOperator, SharedType |
|
|
|
|
|
|
|
|
|
|
|
class Meta: |
|
|
|
def __init__(self): |
|
|
|
self._id = 0 |
|
|
|
self._data = {} |
|
|
|
def __init__(self, id=0, data=None): |
|
|
|
self._id = id |
|
|
|
self._data = data if data is not None else {} |
|
|
|
|
|
|
|
def add(self, data): |
|
|
|
self._id += 1 |
|
|
|
self._data[self._id] = data |
|
|
|
return self._id |
|
|
|
|
|
|
|
def get(self, i): |
|
|
|
return self._data[i] |
|
|
|
|
|
|
|
@classmethod |
|
|
|
def load(cls, f_path): |
|
|
|
with open(f_path, 'r') as f: |
|
|
|
data = json.load(f) |
|
|
|
return cls(data['id'], data['data']) |
|
|
|
|
|
|
|
def save(self, f_path): |
|
|
|
with open(f_path, 'w') as f: |
|
|
|
json.dump({'id': self._id, 'data': self._data}, f) |
|
|
|
|
|
|
|
|
|
|
|
class FaissSearch(PyOperator): |
|
|
|
""" |
|
|
@ -24,15 +36,14 @@ class FaissSearch(PyOperator): |
|
|
|
self._top_k = top_k |
|
|
|
self._index_file = os.path.join(data_dir, 'faiss.index') |
|
|
|
self._meta_file = os.path.join(data_dir, 'meta.bin') |
|
|
|
self._index = faiss.read_index(self._index) |
|
|
|
with open(self._meta_file, 'rb') as f: |
|
|
|
self._meta = pickle.load(f) |
|
|
|
self._index = faiss.read_index(self._index_file) |
|
|
|
self._meta = Meta.load(self._meta_file) |
|
|
|
|
|
|
|
def __call__(self, query: 'ndarray'): |
|
|
|
np_data = query.reshape(1, -1) |
|
|
|
dist, ids = self._index.search(np_data, self._top_k) |
|
|
|
ids = [[int(i)] for i in ids[0]] |
|
|
|
meta_data = [self._meta_file[int(i)] for i in ids] |
|
|
|
ids = [int(i) for i in ids[0]] |
|
|
|
meta_data = [self._meta.get(str(i)) for i in ids] |
|
|
|
return list(zip(ids, dist[0], meta_data)) |
|
|
|
|
|
|
|
@property |
|
|
|