
Add rerank config

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 committed 1 year ago
commit 47daa58a2c
2 changed files:
1. README.md (22 lines changed)
2. eqa_search.py (28 lines changed)

README.md

```diff
@@ -26,8 +26,8 @@ config.collection_name = 'chatbot'
 config.top_k = 5
 # If using zilliz cloud
-config.user = [zilliz-cloud-username]
-config.password = [zilliz-cloud-password]
+# config.user = [zilliz-cloud-username]
+# config.password = [zilliz-cloud-password]
 # OpenAI api key
 config.openai_api_key = [your-openai-api-key]
@@ -36,8 +36,8 @@ config.embedding_model = 'all-MiniLM-L6-v2'
 # Embedding model device
 config.embedding_device = -1
-# The threshold to filter milvus search result
-config.threshold = 0.5
+# Rerank the docs searched from knowledge base
+config.rerank = True
 # The llm model source, openai or dolly
 config.llm_src = 'openai'
```
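For context, a hedged sketch of how the new option slots into the config shown above; the `AutoConfig`/`AutoPipes` entry points and the `'eqa-search'` pipeline name are assumptions, not shown in this diff:

```python
# Sketch only: AutoConfig/AutoPipes and the 'eqa-search' name are assumptions.
from towhee import AutoConfig, AutoPipes

config = AutoConfig.load_config('eqa-search')
config.collection_name = 'chatbot'
config.openai_api_key = 'sk-...'  # placeholder
config.rerank = True              # new: rerank docs retrieved from milvus
config.threshold = 0.6            # with rerank=True, this is the rerank cutoff

pipe = AutoPipes.pipeline('eqa-search', config=config)
```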
```diff
@@ -99,11 +99,19 @@ The user name for [Cloud user](https://zilliz.com/cloud), defaults to `None`.
 The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`.
 
-### Configuration for Similarity Evaluation
+### Configuration for Rerank
 
-***threshold (Union[float, int]):***
+***rerank***: bool
 
-The threshold to filter the milvus search result.
+Whether to rerank the docs retrieved from the knowledge base, defaults to `False`. If set to `True`, the [rerank](https://towhee.io/towhee/rerank) operator is used.
+
+***rerank_model***: str
+
+The name of the rerank model; set it according to the [rerank](https://towhee.io/towhee/rerank) operator.
+
+***threshold***: Union[float, int]
+
+The threshold for rerank, defaults to 0.6. If `rerank` is `False`, it filters the milvus search result; otherwise filtering is done by the [rerank](https://towhee.io/towhee/rerank) operator.
 
 ### Configuration for LLM
```
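To make the `rerank`/`threshold` interaction concrete, here is a hedged sketch of the operator on its own. The `ops.rerank(model, threshold)` signature and the docs-in, docs-plus-scores-out convention are taken from the `eqa_search.py` diff below; invoking the operator standalone via `.get_op()` mirrors how that file loads other operators, but is otherwise an assumption:

```python
# Hedged sketch: signature taken from the pipeline code below; standalone
# invocation via .get_op() is an assumption.
from towhee import ops

rerank = ops.rerank('cross-encoder/ms-marco-MiniLM-L-6-v2', 0.6).get_op()
docs = ['Towhee is a framework for neural data processing.',
        'Completely unrelated text.']
kept_docs, scores = rerank('what is towhee?', docs)
# kept_docs: the docs whose cross-encoder score cleared the 0.6 threshold
```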

eqa_search.py

```diff
@@ -33,8 +33,6 @@ class EnhancedQASearchConfig(BaseModel):
     top_k: Optional[int] = 5
     user: Optional[str] = None
     password: Optional[str] = None
-    # config for similarity evaluation
-    threshold: Optional[Union[float, int]] = 0.6
     # config for llm
     llm_src: Optional[str] = 'openai'
     openai_model: Optional[str] = 'gpt-3.5-turbo'
@@ -44,7 +42,11 @@ class EnhancedQASearchConfig(BaseModel):
     ernie_secret_key: Optional[str] = None
     customize_llm: Optional[Any] = None
     customize_prompt: Optional[Any] = None
+    # config for rerank
+    rerank: Optional[bool] = False
+    rerank_model: Optional[str] = 'cross-encoder/ms-marco-MiniLM-L-6-v2'
+    threshold: Optional[Union[float, int]] = 0.6
 
 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
```
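The new default `rerank_model` is a sentence-transformers cross-encoder, which scores each (question, doc) pair jointly, unlike the bi-encoder embedding model used for the initial milvus retrieval. As an illustration of what it computes, independent of towhee's wrapper (how the operator maps these raw scores onto the 0.6 threshold, e.g. whether a sigmoid is applied first, is internal to the rerank operator):

```python
# Illustration of the default rerank model via sentence-transformers directly;
# this bypasses towhee and shows only the pairwise scoring a cross-encoder does.
from sentence_transformers import CrossEncoder

model = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')
scores = model.predict([
    ('what is milvus?', 'Milvus is an open-source vector database.'),
    ('what is milvus?', 'The weather is nice today.'),
])
# The on-topic pair receives a much higher relevance score than the off-topic one.
```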
```diff
@@ -131,15 +133,23 @@ def enhanced_qa_search_pipe(config):
         .map('embedding', 'result', search_milvus_op)
     )
 
-    # if config.similarity_evaluation:
-    if config.threshold:
+    if config.rerank:
+        p = (
+            p.map('result', 'docs', lambda x: [i[2] for i in x])
+            .map(('question', 'docs'), ('docs', 'score'), ops.rerank(config.rerank_model, config.threshold))
+            .map('docs', 'docs', lambda x: '\n'.join([i for i in x]))
+        )
+    elif config.threshold:
         sim_eval_op = _get_similarity_evaluation_op(config)
-        p = p.map('result', 'result', sim_eval_op)
+        p = (p.map('result', 'result', sim_eval_op)
+             .map('result', 'docs', lambda x: '\n'.join([i[2] for i in x]))
+        )
+    else:
+        p = p.map('result', 'docs', lambda x: '\n'.join([i[2] for i in x]))
 
     p = (
-        p.map('result', 'docs', lambda x: '\n'.join([i[2] for i in x]))
-        .map(('question', 'docs', 'history'), 'prompt', _get_prompt(config))
+        p.map(('question', 'docs', 'history'), 'prompt', _get_prompt(config))
         .map('prompt', 'answer', llm_op)
     )
     return p.output('answer')
```
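Note the resulting three-way branch: `rerank=True` takes the new `ops.rerank` path, a bare `threshold` falls back to the old similarity-evaluation filter, and otherwise the raw milvus hits are joined into `docs` unfiltered. A hedged usage sketch of the rewired pipeline follows; the `'question'`/`'history'` inputs come from the code above, while the `AutoPipes` entry point and the `'eqa-search'` name are assumptions:

```python
# Hedged usage sketch: input columns come from the pipeline above;
# the AutoPipes entry point and pipeline name are assumptions.
from towhee import AutoConfig, AutoPipes

config = AutoConfig.load_config('eqa-search')
config.rerank = True  # exercise the new ops.rerank branch
pipe = AutoPipes.pipeline('eqa-search', config=config)

answer = pipe('What is Towhee?', []).get()[0]  # second arg: empty chat history
```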
