logo
Browse Source

Update Readme

Signed-off-by: Kaiyuan Hu <kaiyuan.hu@zilliz.com>
main
Kaiyuan Hu 2 years ago
parent
commit
458a3f3a90
  1. 28
      README.md
  2. 31
      eqa_search.py

28
README.md

@ -20,10 +20,18 @@
from towhee import AutoPipes, AutoConfig from towhee import AutoPipes, AutoConfig
config = AutoConfig.load_config('eqa-search') config = AutoConfig.load_config('eqa-search')
config.openai_api_key = [your-openai-api-key]
config.collection_name = 'chatbot' config.collection_name = 'chatbot'
# The llm model source, openai or dolly
config.llm_src = 'openai'
# The llm model name
config.llm_model = 'gpt-3.5-turbo'
# The threshold to filter milvus search result
config.threshold = 0.5
p = AutoPipes.pipeline('eqa-search', config=config) p = AutoPipes.pipeline('eqa-search', config=config)
res = p('https://raw.githubusercontent.com/towhee-io/towhee/main/README.md')
res = p('https://github.com/towhee-io/towhee/blob/main/README.md', [])
``` ```
@ -76,6 +84,24 @@ The user name for [Cloud user](https://zilliz.com/cloud), defaults to `None`.
The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`. The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`.
### Configuration for Similarity Evaluation
***threshold (Union[float, int]):***
The threshold to filter the milvus search result.
### Configuration for LLM
***llm_src (str):***
The llm model source, `openai` or `dolly`, defaults to `openai`.
***llm_model (str):***
The llm model name, defaults to `gpt-3.5-turbo` for `openai`, `databricks/dolly-v2-12b` for `dolly`.
<br /> <br />

31
eqa_search.py

@ -22,7 +22,7 @@ class EnhancedQASearchConfig:
""" """
def __init__(self): def __init__(self):
# config for sentence_embedding # config for sentence_embedding
self.model = 'all-MiniLM-L6-v2'
self.embedding_model = 'all-MiniLM-L6-v2'
self.openai_api_key = None self.openai_api_key = None
self.embedding_device = -1 self.embedding_device = -1
# config for search_milvus # config for search_milvus
@ -35,7 +35,8 @@ class EnhancedQASearchConfig:
# config for similarity evaluation # config for similarity evaluation
self.threshold = 0.6 self.threshold = 0.6
# config for llm # config for llm
self.llm_model = 'openai'
self.llm_src = 'openai'
self.llm_model = 'gpt-3.5-turbo' if self.llm_src.lower() == 'openai' else 'databricks/dolly-v2-12b'
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
@ -51,22 +52,22 @@ def _get_embedding_op(config):
else: else:
device = config.embedding_device device = config.embedding_device
if config.model in _hf_models:
if config.embedding_model in _hf_models:
return True, ops.sentence_embedding.transformers( return True, ops.sentence_embedding.transformers(
model_name=config.model, device=device
model_name=config.embedding_model, device=device
) )
if config.model in _sbert_models:
if config.embedding_model in _sbert_models:
return True, ops.sentence_embedding.sbert( return True, ops.sentence_embedding.sbert(
model_name=config.model, device=device
model_name=config.embedding_model, device=device
) )
if config.model in _openai_models:
if config.embedding_model in _openai_models:
return False, ops.sentence_embedding.openai( return False, ops.sentence_embedding.openai(
model_name=config.model, api_key=config.openai_api_key
model_name=config.embedding_model, api_key=config.openai_api_key
) )
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models))
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.embedding_model, _hf_models + _openai_models))
def _get_similarity_evaluation_op(config): def _get_similarity_evaluation_op(config):
@ -74,12 +75,12 @@ def _get_similarity_evaluation_op(config):
def _get_llm_op(config): def _get_llm_op(config):
if config.llm_model.lower() == 'openai':
return ops.LLM.OpenAI(api_key=config.openai_api_key)
if config.llm_model.lower() == 'dolly':
return ops.LLM.Dolly()
if config.llm_src.lower() == 'openai':
return ops.LLM.OpenAI(model_name=config.llm_model, api_key=config.openai_api_key)
if config.llm_src.lower() == 'dolly':
return ops.LLM.Dolly(model_name=config.llm_model)
raise RuntimeError('Unknown llm model: [%s], only support \'openai\' and \'dolly\'' % (config.model))
raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src))
@AutoPipes.register @AutoPipes.register
@ -119,7 +120,7 @@ def enhanced_qa_search_pipe(config):
p = ( p = (
p.map('result', 'docs', lambda x:[i[2] for i in x]) p.map('result', 'docs', lambda x:[i[2] for i in x])
.map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_model))
.map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_src))
.map('prompt', 'answer', llm_op) .map('prompt', 'answer', llm_op)
) )

Loading…
Cancel
Save