diff --git a/README.md b/README.md index 24671ac..800c7d2 100644 --- a/README.md +++ b/README.md @@ -41,9 +41,10 @@ config.threshold = 0.5 # The llm model source, openai or dolly config.llm_src = 'openai' -# The llm model name -config.llm_model = 'gpt-3.5-turbo' - +# The openai model name +config.openai_model = 'gpt-3.5-turbo' +# The dolly model name +# config.dolly_model = 'databricks/dolly-v2-12b' p = AutoPipes.pipeline('eqa-search', config=config) res = p('What is towhee?', []) @@ -111,10 +112,13 @@ The threshold to filter the milvus search result. The llm model source, `openai` or `dolly`, defaults to `openai`. -***llm_model (str):*** +***openai_model (str):*** + +The openai model name, defaults to `gpt-3.5-turbo`. -The llm model name, defaults to `gpt-3.5-turbo` for `openai`, `databricks/dolly-v2-12b` for `dolly`. +***dolly_model (str):*** +The dolly model name, defaults to `databricks/dolly-v2-12b`.
diff --git a/eqa_search.py b/eqa_search.py index 0786104..f2de589 100644 --- a/eqa_search.py +++ b/eqa_search.py @@ -11,32 +11,34 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from typing import Optional, Union from towhee import ops, pipe, AutoPipes, AutoConfig +from pydantic import BaseModel @AutoConfig.register -class EnhancedQASearchConfig: +class EnhancedQASearchConfig(BaseModel): """ Config of pipeline """ - def __init__(self): - # config for sentence_embedding - self.embedding_model = 'all-MiniLM-L6-v2' - self.openai_api_key = None - self.embedding_device = -1 - # config for search_milvus - self.host = '127.0.0.1' - self.port = '19530' - self.collection_name = 'chatbot' - self.top_k = 5 - self.user = None - self.password = None - # config for similarity evaluation - self.threshold = 0.6 - # config for llm - self.llm_src = 'openai' - self.llm_model = 'gpt-3.5-turbo' if self.llm_src.lower() == 'openai' else 'databricks/dolly-v2-12b' + # config for sentence_embedding + embedding_model: Optional[str] = 'all-MiniLM-L6-v2' + openai_api_key: Optional[str] = None + embedding_device: Optional[int] = -1 + # config for search_milvus + host: Optional[str] = '127.0.0.1' + port: Optional[str] = '19530' + collection_name: Optional[str] = 'chatbot' + top_k: Optional[int] = 5 + user: Optional[str] = None + password: Optional[str] = None + # config for similarity evaluation + threshold: Optional[Union[float, int]] = 0.6 + # config for llm + llm_src: Optional[str] = 'openai' + openai_model: Optional[str] = 'gpt-3.5-turbo' + dolly_model: Optional[str] = 'databricks/dolly-v2-12b' _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() @@ -76,9 +78,9 @@ def _get_similarity_evaluation_op(config): def _get_llm_op(config): if config.llm_src.lower() == 'openai': - return ops.LLM.OpenAI(model_name=config.llm_model, api_key=config.openai_api_key) + return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key) if config.llm_src.lower() == 'dolly': - return ops.LLM.Dolly(model_name=config.llm_model) + return ops.LLM.Dolly(model_name=config.dolly_model) raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src))