#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
FAISS-based index components for dense retriever
"""
import logging
import pickle
from typing import List, Tuple
import faiss
import numpy as np
logger = logging.getLogger()


class DenseIndexer(object):

    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size
        self.index_id_to_db_id = []
        self.index = None

    def index_data(self, data: List[Tuple[object, np.array]]):
        raise NotImplementedError

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        raise NotImplementedError

    def serialize(self, file: str):
        logger.info('Serializing index to %s', file)
        index_file = file + '.index.dpr'
        meta_file = file + '.index_meta.dpr'
        faiss.write_index(self.index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, file: str):
        logger.info('Loading index from %s', file)
        index_file = file + '.index.dpr'
        meta_file = file + '.index_meta.dpr'
        self.index = faiss.read_index(index_file)
        logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal)
        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(self.index_id_to_db_id) == self.index.ntotal, \
            'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)


class DenseFlatIndexer(DenseIndexer):

    def __init__(self, vector_sz: int, buffer_size: int = 50000):
        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
        self.index = faiss.IndexFlatIP(vector_sz)

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)
        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i:i + self.buffer_size]]
            vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
            vectors = np.concatenate(vectors, axis=0)
            self._update_id_mapping(db_ids)
            self.index.add(vectors)

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info('Total data indexed %d', indexed_cnt)

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        scores, indexes = self.index.search(query_vectors, top_docs)
        # convert to external ids
        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result


class DenseHNSWFlatIndexer(DenseIndexer):
    """
    Efficient index for retrieval. Note: default settings are for high accuracy but also high RAM usage
    """

    def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512,
                 ef_search: int = 128, ef_construction: int = 200):
        super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)
        # IndexHNSWFlat supports L2 similarity only,
        # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension
        index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
        index.hnsw.efSearch = ef_search
        index.hnsw.efConstruction = ef_construction
        self.index = index
        self.phi = 0
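
    # Why the extra dimension works: if phi >= max ||p||^2 over all passage vectors
    # and each passage p is stored as [p ; sqrt(phi - ||p||^2)] while a query q is
    # searched as [q ; 0], the squared L2 distance between them is
    #     ||q - p||^2 + (phi - ||p||^2) = ||q||^2 + phi - 2 * <q, p>,
    # which, for a fixed query, decreases as the inner product <q, p> grows.
    # Nearest neighbours under L2 in the augmented space are therefore exactly
    # the maximum inner product neighbours in the original space.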

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)

        # max norm is required before putting all vectors in the index to convert inner product similarity to L2
        if self.phi > 0:
            raise RuntimeError('DPR HNSWF index needs to index all data at once, '
                               'results will be unpredictable otherwise.')
        phi = 0
        for i, item in enumerate(data):
            id, doc_vector = item
            norms = (doc_vector ** 2).sum()
            phi = max(phi, norms)
        logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi))
        self.phi = phi  # remember the max norm so that a second index_data() call fails fast (see check above)

        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i:i + self.buffer_size]]
            vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]

            norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
            aux_dims = [np.sqrt(phi - norm) for norm in norms]
            hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in
                            enumerate(vectors)]
            hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)

            self._update_id_mapping(db_ids)
            self.index.add(hnsw_vectors)
            logger.info('data indexed %d', len(self.index_id_to_db_id))

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info('Total data indexed %d', indexed_cnt)

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        aux_dim = np.zeros(len(query_vectors), dtype='float32')
        query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
        # logger.info('query_hnsw_vectors %s', query_hnsw_vectors.shape)
        scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
        # convert to external ids
        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result

    def deserialize_from(self, file: str):
        super(DenseHNSWFlatIndexer, self).deserialize_from(file)
        # to trigger warning on subsequent indexing
        self.phi = 1
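

# A minimal usage sketch (illustrative only, not part of the original module): it builds
# a small DenseFlatIndexer over random float32 vectors, runs a search, and round-trips
# the index through serialize() / deserialize_from(). The vector size (768), the number
# of documents and the '/tmp/dpr_flat_example' prefix are arbitrary choices for this example.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    vector_sz = 768
    indexer = DenseFlatIndexer(vector_sz)

    # (db_id, vector) pairs; faiss expects float32 vectors
    data = [('doc%d' % i, np.random.rand(vector_sz).astype('float32')) for i in range(100)]
    indexer.index_data(data)

    queries = np.random.rand(2, vector_sz).astype('float32')
    for ids, scores in indexer.search_knn(queries, top_docs=5):
        logger.info('top ids: %s scores: %s', ids, scores)

    # serialize() writes <prefix>.index.dpr and <prefix>.index_meta.dpr
    indexer.serialize('/tmp/dpr_flat_example')
    restored = DenseFlatIndexer(vector_sz)
    restored.deserialize_from('/tmp/dpr_flat_example')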