diff --git a/README.md b/README.md
index 800c7d2..90fcb3a 100644
--- a/README.md
+++ b/README.md
@@ -118,7 +118,11 @@
 
 The openai model name, defaults to `gpt-3.5-turbo`.
 
 ***dolly_model (str):***
 
-The dolly model name, defaults to `databricks/dolly-v2-12b`.
+The dolly model name, defaults to `databricks/dolly-v2-3b`.
+
+***customize_llm (Any):***
+
+A user-defined LLM operator; if provided, it is used in place of the built-in `llm_src` options.
 
diff --git a/eqa_search.py b/eqa_search.py
index d66c251..39f341d 100644
--- a/eqa_search.py
+++ b/eqa_search.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import Optional, Union, Any
 
 from towhee import ops, pipe, AutoPipes, AutoConfig
 from pydantic import BaseModel
@@ -38,7 +38,9 @@ class EnhancedQASearchConfig(BaseModel):
     # config for llm
     llm_src: Optional[str] = 'openai'
     openai_model: Optional[str] = 'gpt-3.5-turbo'
-    dolly_model: Optional[str] = 'databricks/dolly-v2-12b'
+    dolly_model: Optional[str] = 'databricks/dolly-v2-3b'
+
+    customize_llm: Optional[Any] = None
 
 
 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
@@ -77,6 +79,8 @@ def _get_similarity_evaluation_op(config):
 
 
 def _get_llm_op(config):
+    if config.customize_llm:
+        return config.customize_llm
     if config.llm_src.lower() == 'openai':
         return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key)
     if config.llm_src.lower() == 'dolly':
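
For context, here is a minimal usage sketch of the new `customize_llm` option. It assumes the pipeline is registered with AutoPipes under the name `'eqa-search'` and that any callable mapping the generated prompt messages to an answer string can stand in for the built-in OpenAI/Dolly operators; the `StubLLM` class below is hypothetical.

```python
from towhee import AutoConfig, AutoPipes


class StubLLM:
    """Hypothetical user-defined LLM: any callable that turns the pipeline's
    prompt messages into an answer string can be passed as customize_llm."""

    def __call__(self, messages):
        # Replace this with a call to your own model or inference service.
        return 'stub answer for: ' + str(messages)


# Load the default search config and plug in the custom LLM; because
# _get_llm_op checks customize_llm first, llm_src/openai_model/dolly_model
# are ignored when it is set.
config = AutoConfig.load_config('eqa-search')   # pipeline name assumed
config.customize_llm = StubLLM()

search_pipe = AutoPipes.pipeline('eqa-search', config=config)
# The question is passed as the pipeline input; the exact call signature
# follows the eqa-search pipeline's documentation.
answer = search_pipe('What is Towhee?')
```

Since `_get_llm_op` returns `config.customize_llm` before consulting `llm_src`, existing configs that do not set the new field keep their current behavior.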