#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
 FAISS-based index components for dense retriever
"""

import logging
import pickle
from typing import List, Tuple

import faiss
import numpy as np

logger = logging.getLogger()


class DenseIndexer(object):

    def __init__(self, buffer_size: int = 50000):
        self.buffer_size = buffer_size
        self.index_id_to_db_id = []
        self.index = None

    def index_data(self, data: List[Tuple[object, np.array]]):
        raise NotImplementedError

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        raise NotImplementedError

    def serialize(self, file: str):
        logger.info('Serializing index to %s', file)

        index_file = file + '.index.dpr'
        meta_file = file + '.index_meta.dpr'

        faiss.write_index(self.index, index_file)
        with open(meta_file, mode='wb') as f:
            pickle.dump(self.index_id_to_db_id, f)

    def deserialize_from(self, file: str):
        logger.info('Loading index from %s', file)

        index_file = file + '.index.dpr'
        meta_file = file + '.index_meta.dpr'

        self.index = faiss.read_index(index_file)
        logger.info('Loaded index of type %s and size %d', type(self.index), self.index.ntotal)

        with open(meta_file, "rb") as reader:
            self.index_id_to_db_id = pickle.load(reader)
        assert len(self.index_id_to_db_id) == self.index.ntotal, \
            'Deserialized index_id_to_db_id should match faiss index size'

    def _update_id_mapping(self, db_ids: List):
        self.index_id_to_db_id.extend(db_ids)


class DenseFlatIndexer(DenseIndexer):

    def __init__(self, vector_sz: int, buffer_size: int = 50000):
        super(DenseFlatIndexer, self).__init__(buffer_size=buffer_size)
        self.index = faiss.IndexFlatIP(vector_sz)

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)
        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i:i + self.buffer_size]]
            vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]
            vectors = np.concatenate(vectors, axis=0)
            self._update_id_mapping(db_ids)
            self.index.add(vectors)

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info('Total data indexed %d', indexed_cnt)

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        scores, indexes = self.index.search(query_vectors, top_docs)
        # convert internal (positional) faiss ids to external db ids
        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result
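
# Illustrative sketch (not part of the original DPR code): the HNSW indexer below converts
# maximum-inner-product search into L2 nearest-neighbour search by appending one auxiliary
# dimension. With phi = max_d ||d||^2, each document d is stored as d' = [d, sqrt(phi - ||d||^2)]
# and each query q is searched as q' = [q, 0], so that
#   ||q' - d'||^2 = ||q||^2 + phi - 2 * <q, d>,
# which, for a fixed query, decreases monotonically as the inner product grows. The hypothetical
# helper below only checks that the two rankings agree on random data; it is not used by the indexers.
def _demo_ip_to_l2_ranking_equivalence(num_docs: int = 100, dim: int = 8) -> bool:
    rng = np.random.RandomState(0)
    docs = rng.randn(num_docs, dim).astype('float32')
    query = rng.randn(dim).astype('float32')

    # augment documents with sqrt(phi - ||d||^2) and the query with 0
    doc_norms = (docs ** 2).sum(axis=1)
    phi = doc_norms.max()
    docs_aug = np.hstack([docs, np.sqrt(phi - doc_norms).reshape(-1, 1)])
    query_aug = np.hstack([query, [0.0]])

    # ranking by descending inner product vs. ranking by ascending L2 distance
    ip_ranking = np.argsort(-docs.dot(query))
    l2_ranking = np.argsort(((docs_aug - query_aug) ** 2).sum(axis=1))
    return bool(np.array_equal(ip_ranking, l2_ranking))
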
class DenseHNSWFlatIndexer(DenseIndexer):
    """
     Efficient index for retrieval.
     Note: default settings are for high accuracy but also high RAM usage
    """

    def __init__(self, vector_sz: int, buffer_size: int = 50000, store_n: int = 512,
                 ef_search: int = 128, ef_construction: int = 200):
        super(DenseHNSWFlatIndexer, self).__init__(buffer_size=buffer_size)

        # IndexHNSWFlat supports L2 similarity only,
        # so we have to apply DOT -> L2 similarity space conversion with the help of an extra dimension
        index = faiss.IndexHNSWFlat(vector_sz + 1, store_n)
        index.hnsw.efSearch = ef_search
        index.hnsw.efConstruction = ef_construction
        self.index = index
        self.phi = 0

    def index_data(self, data: List[Tuple[object, np.array]]):
        n = len(data)

        # max norm is required before putting all vectors in the index to convert inner product similarity to L2
        if self.phi > 0:
            raise RuntimeError('DPR HNSWF index needs to index all data at once, '
                               'results will be unpredictable otherwise.')
        phi = 0
        for i, item in enumerate(data):
            id, doc_vector = item
            norms = (doc_vector ** 2).sum()
            phi = max(phi, norms)
        logger.info('HNSWF DotProduct -> L2 space phi={}'.format(phi))
        # remember phi so that a second index_data() call raises instead of silently mixing norm scales
        self.phi = phi

        # indexing in batches is beneficial for many faiss index types
        for i in range(0, n, self.buffer_size):
            db_ids = [t[0] for t in data[i:i + self.buffer_size]]
            vectors = [np.reshape(t[1], (1, -1)) for t in data[i:i + self.buffer_size]]

            # append the auxiliary dimension sqrt(phi - ||d||^2) to each document vector
            norms = [(doc_vector ** 2).sum() for doc_vector in vectors]
            aux_dims = [np.sqrt(phi - norm) for norm in norms]
            hnsw_vectors = [np.hstack((doc_vector, aux_dims[i].reshape(-1, 1))) for i, doc_vector in
                            enumerate(vectors)]
            hnsw_vectors = np.concatenate(hnsw_vectors, axis=0)

            self._update_id_mapping(db_ids)
            self.index.add(hnsw_vectors)
            logger.info('data indexed %d', len(self.index_id_to_db_id))

        indexed_cnt = len(self.index_id_to_db_id)
        logger.info('Total data indexed %d', indexed_cnt)

    def search_knn(self, query_vectors: np.array, top_docs: int) -> List[Tuple[List[object], List[float]]]:
        # queries get a zero auxiliary dimension, so L2 ranking matches inner-product ranking
        aux_dim = np.zeros(len(query_vectors), dtype='float32')
        query_hnsw_vectors = np.hstack((query_vectors, aux_dim.reshape(-1, 1)))
        # logger.info('query_hnsw_vectors %s', query_hnsw_vectors.shape)
        scores, indexes = self.index.search(query_hnsw_vectors, top_docs)
        # convert internal (positional) faiss ids to external db ids
        db_ids = [[self.index_id_to_db_id[i] for i in query_top_idxs] for query_top_idxs in indexes]
        result = [(db_ids[i], scores[i]) for i in range(len(db_ids))]
        return result

    def deserialize_from(self, file: str):
        super(DenseHNSWFlatIndexer, self).deserialize_from(file)
        # to trigger the runtime error on subsequent indexing
        self.phi = 1
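

# Minimal usage sketch (illustrative only; the output path, toy vector data, and vector size
# are assumptions, not part of the original code): build a flat inner-product index over random
# "passage" vectors, query it, and round-trip it through serialize/deserialize_from.
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)

    vector_sz = 768  # DPR encoders typically emit 768-d vectors; adjust to your encoder
    rng = np.random.RandomState(0)
    passages = [('passage_%d' % i, rng.randn(vector_sz).astype('float32')) for i in range(1000)]

    indexer = DenseFlatIndexer(vector_sz)
    indexer.index_data(passages)

    queries = rng.randn(2, vector_sz).astype('float32')
    for db_ids, scores in indexer.search_knn(queries, top_docs=5):
        logger.info('top ids=%s scores=%s', db_ids, scores)

    # writes <prefix>.index.dpr and <prefix>.index_meta.dpr next to the given prefix
    indexer.serialize('/tmp/dpr_demo')
    restored = DenseFlatIndexer(vector_sz)
    restored.deserialize_from('/tmp/dpr_demo')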