diff --git a/README.md b/README.md index d744793..d2742c1 100644 --- a/README.md +++ b/README.md @@ -20,10 +20,18 @@ from towhee import AutoPipes, AutoConfig config = AutoConfig.load_config('eqa-search') +config.openai_api_key = [your-openai-api-key] config.collection_name = 'chatbot' +# The llm model source, openai or dolly +config.llm_src = 'openai' +# The llm model name +config.llm_model = 'gpt-3.5-turbo' +# The threshold to filter milvus search result +config.threshold = 0.5 + p = AutoPipes.pipeline('eqa-search', config=config) -res = p('https://raw.githubusercontent.com/towhee-io/towhee/main/README.md') +res = p('https://github.com/towhee-io/towhee/blob/main/README.md', []) ``` @@ -76,6 +84,24 @@ The user name for [Cloud user](https://zilliz.com/cloud), defaults to `None`. The user password for [Cloud user](https://zilliz.com/cloud), defaults to `None`. +### Configuration for Similarity Evaluation + +***threshold (Union[float, int]):*** + +The threshold to filter the milvus search result. + + +### Configuration for LLM + +***llm_src (str):*** + +The llm model source, `openai` or `dolly`, defaults to `openai`. + +***llm_model (str):*** + +The llm model name, defaults to `gpt-3.5-turbo` for `openai`, `databricks/dolly-v2-12b` for `dolly`. + +
diff --git a/eqa_search.py b/eqa_search.py index f4d897e..e5e27f9 100644 --- a/eqa_search.py +++ b/eqa_search.py @@ -22,7 +22,7 @@ class EnhancedQASearchConfig: """ def __init__(self): # config for sentence_embedding - self.model = 'all-MiniLM-L6-v2' + self.embedding_model = 'all-MiniLM-L6-v2' self.openai_api_key = None self.embedding_device = -1 # config for search_milvus @@ -35,7 +35,8 @@ class EnhancedQASearchConfig: # config for similarity evaluation self.threshold = 0.6 # config for llm - self.llm_model = 'openai' + self.llm_src = 'openai' + self.llm_model = 'gpt-3.5-turbo' if self.llm_src.lower() == 'openai' else 'databricks/dolly-v2-12b' _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() @@ -51,22 +52,22 @@ def _get_embedding_op(config): else: device = config.embedding_device - if config.model in _hf_models: + if config.embedding_model in _hf_models: return True, ops.sentence_embedding.transformers( - model_name=config.model, device=device + model_name=config.embedding_model, device=device ) - if config.model in _sbert_models: + if config.embedding_model in _sbert_models: return True, ops.sentence_embedding.sbert( - model_name=config.model, device=device + model_name=config.embedding_model, device=device ) - if config.model in _openai_models: + if config.embedding_model in _openai_models: return False, ops.sentence_embedding.openai( - model_name=config.model, api_key=config.openai_api_key + model_name=config.embedding_model, api_key=config.openai_api_key ) - raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models)) + raise RuntimeError('Unknown model: [%s], only support: %s' % (config.embedding_model, _hf_models + _openai_models)) def _get_similarity_evaluation_op(config): @@ -74,12 +75,12 @@ def _get_similarity_evaluation_op(config): def _get_llm_op(config): - if config.llm_model.lower() == 'openai': - return ops.LLM.OpenAI(api_key=config.openai_api_key) - if config.llm_model.lower() == 'dolly': - return ops.LLM.Dolly() + if config.llm_src.lower() == 'openai': + return ops.LLM.OpenAI(model_name=config.llm_model, api_key=config.openai_api_key) + if config.llm_src.lower() == 'dolly': + return ops.LLM.Dolly(model_name=config.llm_model) - raise RuntimeError('Unknown llm model: [%s], only support \'openai\' and \'dolly\'' % (config.model)) + raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src)) @AutoPipes.register @@ -116,10 +117,10 @@ def enhanced_qa_search_pipe(config): if config.threshold: sim_eval_op = _get_similarity_evaluation_op(config) p = p.map('result', 'result', sim_eval_op) - + p = ( p.map('result', 'docs', lambda x:[i[2] for i in x]) - .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_model)) + .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_src)) .map('prompt', 'answer', llm_op) )