diff --git a/README.md b/README.md index 1751bbb..0dd722c 100644 --- a/README.md +++ b/README.md @@ -128,6 +128,15 @@ Users customize LLM. Users customize prompt. +***ernie_api_key (str):*** + +ernie_api_key for ernie bot + +***ernie_secret_key (str):*** + +ernie_secret_key for ernie bot + +
diff --git a/eqa_search.py b/eqa_search.py index 1b37b74..1ccd9a2 100644 --- a/eqa_search.py +++ b/eqa_search.py @@ -40,6 +40,9 @@ class EnhancedQASearchConfig(BaseModel): openai_model: Optional[str] = 'gpt-3.5-turbo' dolly_model: Optional[str] = 'databricks/dolly-v2-3b' + ernie_api_key: Optional[str] = None + ernie_secret_key: Optional[str] = None + customize_llm: Optional[Any] = None customize_prompt: Optional[Any] = None @@ -86,6 +89,8 @@ def _get_llm_op(config): return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key) if config.llm_src.lower() == 'dolly': return ops.LLM.Dolly(model_name=config.dolly_model) + if config.llm_src.lower() == ' ernie': + return ops.LLM.Ernie(api_key=config.ernie_api_key, secret_key=config.ernie_secret_key) raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src))