From 47daa58a2c2e22ca8185b439aa8992f3940bc776 Mon Sep 17 00:00:00 2001 From: shiyu22 Date: Wed, 21 Jun 2023 16:37:19 +0800 Subject: [PATCH] Add rerank config Signed-off-by: shiyu22 --- README.md | 22 +++++++++++++++------- eqa_search.py | 28 +++++++++++++++++++--------- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 0dd722c..24c1f92 100644 --- a/README.md +++ b/README.md @@ -26,8 +26,8 @@ config.collection_name = 'chatbot' config.top_k = 5 # If using zilliz cloud -config.user = [zilliz-cloud-username] -config.password = [zilliz-cloud-password] +# config.user = [zilliz-cloud-username] +# config.password = [zilliz-cloud-password] # OpenAI api key config.openai_api_key = [your-openai-api-key] @@ -36,8 +36,8 @@ config.embedding_model = 'all-MiniLM-L6-v2' # Embedding model device config.embedding_device = -1 -# The threshold to filter milvus search result -config.threshold = 0.5 +# Rerank the docs searched from knowledge base +config.rerank = True # The llm model source, openai or dolly config.llm_src = 'openai' @@ -99,11 +99,19 @@ The user name for [Cloud user](https://zilliz.com/cloud), defaults to `None`. The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`. -### Configuration for Similarity Evaluation +### Configuration for Rerank -***threshold (Union[float, int]):*** +***rerank***: bool -The threshold to filter the milvus search result. +Whether to rerank the docs retrieved from the knowledge base, defaults to `False`. If set to `True`, it will use the [rerank](https://towhee.io/towhee/rerank) operator. + +***rerank_model***: str + +The name of the rerank model; you can set it according to the [rerank](https://towhee.io/towhee/rerank) operator. + +***threshold***: Union[float, int] + +The filtering threshold, defaults to 0.6. If `rerank` is `False`, it is used to filter the Milvus search result; otherwise the docs are filtered by the [rerank](https://towhee.io/towhee/rerank) operator. 
### Configuration for LLM diff --git a/eqa_search.py b/eqa_search.py index 7cd3843..1343243 100644 --- a/eqa_search.py +++ b/eqa_search.py @@ -33,8 +33,6 @@ class EnhancedQASearchConfig(BaseModel): top_k: Optional[int] = 5 user: Optional[str] = None password: Optional[str] = None - # config for similarity evaluation - threshold: Optional[Union[float, int]] = 0.6 # config for llm llm_src: Optional[str] = 'openai' openai_model: Optional[str] = 'gpt-3.5-turbo' @@ -44,7 +42,11 @@ class EnhancedQASearchConfig(BaseModel): ernie_secret_key: Optional[str] = None customize_llm: Optional[Any] = None - customize_prompt: Optional[Any] = None + customize_prompt: Optional[Any] = None + # config for rerank + rerank: Optional[bool] = False + rerank_model: Optional[str] = 'cross-encoder/ms-marco-MiniLM-L-6-v2' + threshold: Optional[Union[float, int]] = 0.6 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() @@ -131,15 +133,23 @@ def enhanced_qa_search_pipe(config): .map('embedding', 'result', search_milvus_op) ) - # if config.similarity_evaluation: - if config.threshold: + if config.rerank: + p = ( + p.map('result', 'docs', lambda x: [i[2] for i in x]) + .map(('question', 'docs'), ('docs', 'score'), ops.rerank(config.rerank_model, config.threshold)) + .map('docs', 'docs', lambda x: '\n'.join([i for i in x])) + ) + elif config.threshold: sim_eval_op = _get_similarity_evaluation_op(config) - p = p.map('result', 'result', sim_eval_op) + p = (p.map('result', 'result', sim_eval_op) + .map('result', 'docs', lambda x: '\n'.join([i[2] for i in x])) + ) + else: + p = p.map('result', 'docs', lambda x: '\n'.join([i[2] for i in x])) p = ( - p.map('result', 'docs', lambda x: '\n'.join([i[2] for i in x])) - .map(('question', 'docs', 'history'), 'prompt', _get_prompt(config)) - .map('prompt', 'answer', llm_op) + p.map(('question', 'docs', 'history'), 'prompt', _get_prompt(config)) + .map('prompt', 'answer', llm_op) ) return p.output('answer')