
Apply pydantic

Signed-off-by: Kaiyuan Hu <kaiyuan.hu@zilliz.com>
Branch: main
Kaiyuan Hu, 2 years ago
commit b960409196

Changed files:
  1. README.md      (14 lines changed)
  2. eqa_search.py  (42 lines changed)

README.md  (14 lines changed)

@@ -41,9 +41,10 @@ config.threshold = 0.5
 # The llm model source, openai or dolly
 config.llm_src = 'openai'
-# The llm model name
-config.llm_model = 'gpt-3.5-turbo'
+# The openai model name
+config.openai_model = 'gpt-3.5-turbo'
+# The dolly model name
+# config.dolly_model = 'databricks/dolly-v2-12b'
 p = AutoPipes.pipeline('eqa-search', config=config)
 res = p('What is towhee?', [])
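
For comparison, a minimal sketch of switching the pipeline to the Dolly backend with the new fields. It assumes the `config` object is obtained the same way as earlier in the README (e.g. via `AutoConfig.load_config('eqa-search')`); adjust to however your copy of the README builds it.

```python
from towhee import AutoPipes, AutoConfig

config = AutoConfig.load_config('eqa-search')
config.llm_src = 'dolly'
# dolly_model already defaults to 'databricks/dolly-v2-12b'; override only if needed
config.dolly_model = 'databricks/dolly-v2-12b'

p = AutoPipes.pipeline('eqa-search', config=config)
res = p('What is towhee?', [])
```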
@@ -111,10 +112,13 @@ The threshold to filter the milvus search result.
 The llm model source, `openai` or `dolly`, defaults to `openai`.
-***llm_model (str):***
-The llm model name, defaults to `gpt-3.5-turbo` for `openai`, `databricks/dolly-v2-12b` for `dolly`.
+***openai_model (str):***
+The openai model name, defaults to `gpt-3.5-turbo`.
+***dolly_model (str):***
+The dolly model name, defaults to `databricks/dolly-v2-12b`.
 <br />
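
For example, only the model field matching `llm_src` is used; the other is ignored. The model id below is just a placeholder, not a value the pipeline prescribes:

```python
config.llm_src = 'openai'
config.openai_model = 'gpt-4'  # placeholder model id passed through to the OpenAI op
# config.dolly_model is ignored while llm_src is 'openai'
```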

eqa_search.py  (42 lines changed)

@@ -11,32 +11,34 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from typing import Optional, Union
 from towhee import ops, pipe, AutoPipes, AutoConfig
+from pydantic import BaseModel
 @AutoConfig.register
-class EnhancedQASearchConfig:
+class EnhancedQASearchConfig(BaseModel):
     """
     Config of pipeline
     """
-    def __init__(self):
-        # config for sentence_embedding
-        self.embedding_model = 'all-MiniLM-L6-v2'
-        self.openai_api_key = None
-        self.embedding_device = -1
-        # config for search_milvus
-        self.host = '127.0.0.1'
-        self.port = '19530'
-        self.collection_name = 'chatbot'
-        self.top_k = 5
-        self.user = None
-        self.password = None
-        # config for similarity evaluation
-        self.threshold = 0.6
-        # config for llm
-        self.llm_src = 'openai'
-        self.llm_model = 'gpt-3.5-turbo' if self.llm_src.lower() == 'openai' else 'databricks/dolly-v2-12b'
+    # config for sentence_embedding
+    embedding_model: Optional[str] = 'all-MiniLM-L6-v2'
+    openai_api_key: Optional[str] = None
+    embedding_device: Optional[int] = -1
+    # config for search_milvus
+    host: Optional[str] = '127.0.0.1'
+    port: Optional[str] = '19530'
+    collection_name: Optional[str] = 'chatbot'
+    top_k: Optional[int] = 5
+    user: Optional[str] = None
+    password: Optional[str] = None
+    # config for similarity evaluation
+    threshold: Optional[Union[float, int]] = 0.6
+    # config for llm
+    llm_src: Optional[str] = 'openai'
+    openai_model: Optional[str] = 'gpt-3.5-turbo'
+    dolly_model: Optional[str] = 'databricks/dolly-v2-12b'
 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
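
With the pydantic rewrite, defaults come from the class-level field declarations and keyword construction goes through pydantic validation. A small sketch of how the class behaves; direct instantiation here is purely illustrative (in the pipeline the config normally arrives via AutoConfig, and the `@AutoConfig.register` decorator is assumed to return the class unchanged):

```python
# Unspecified fields keep the declared defaults.
cfg = EnhancedQASearchConfig(llm_src='dolly', top_k=10, threshold=0.5)
assert cfg.dolly_model == 'databricks/dolly-v2-12b'
assert cfg.openai_model == 'gpt-3.5-turbo'

# The README mutates attributes after loading the config; note that a plain
# BaseModel does not re-validate on assignment unless validate_assignment
# is enabled in the model config.
cfg.llm_src = 'openai'
```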
@@ -76,9 +78,9 @@ def _get_similarity_evaluation_op(config):
 def _get_llm_op(config):
     if config.llm_src.lower() == 'openai':
-        return ops.LLM.OpenAI(model_name=config.llm_model, api_key=config.openai_api_key)
+        return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key)
     if config.llm_src.lower() == 'dolly':
-        return ops.LLM.Dolly(model_name=config.llm_model)
+        return ops.LLM.Dolly(model_name=config.dolly_model)
     raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src))
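
A quick sketch of how the dispatch above resolves with the renamed fields; the config construction and API key value are for illustration only:

```python
cfg = EnhancedQASearchConfig(llm_src='dolly')
op = _get_llm_op(cfg)   # ops.LLM.Dolly(model_name='databricks/dolly-v2-12b')

cfg = EnhancedQASearchConfig(llm_src='OpenAI', openai_api_key='sk-placeholder')
op = _get_llm_op(cfg)   # matching is case-insensitive: ops.LLM.OpenAI(model_name='gpt-3.5-turbo', ...)

cfg = EnhancedQASearchConfig(llm_src='llama')
_get_llm_op(cfg)        # raises RuntimeError: Unknown llm source: [llama]
```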
