|
@@ -41,6 +41,7 @@ class EnhancedQASearchConfig(BaseModel):
     dolly_model: Optional[str] = 'databricks/dolly-v2-3b'
     customize_llm: Optional[Any] = None
+    customize_prompt: Optional[Any] = None
 
 
 _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
 
@@ -89,6 +90,12 @@ def _get_llm_op(config):
     raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src))
 
 
+def _get_prompt(config):
+    if config.customize_prompt:
+        return config.customize_prompt
+    return ops.prompt.question_answer(llm_name=config.llm_src)
+
+
 @AutoPipes.register
 def enhanced_qa_search_pipe(config):
     allow_triton, sentence_embedding_op = _get_embedding_op(config)
@@ -126,7 +133,7 @@ def enhanced_qa_search_pipe(config):
 
     p = (
         p.map('result', 'docs', lambda x: '\n'.join([i[2] for i in x]))
-        .map(('question', 'docs', 'history'), 'prompt', ops.prompt.question_answer(llm_name=config.llm_src))
+        .map(('question', 'docs', 'history'), 'prompt', _get_prompt(config))
         .map('prompt', 'answer', llm_op)
     )
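
Usage sketch for the new customize_prompt option. Assumptions not confirmed by this patch: the pipeline is obtained through towhee's AutoPipes/AutoConfig, the registration name 'enhanced_qa_search' is illustrative, and the chat-style message format returned by the custom prompt is only one possibility (it must match whatever the configured llm_op consumes).

from towhee import AutoConfig, AutoPipes

# Hypothetical prompt builder: it receives the three columns the pipeline
# maps into the prompt step ('question', 'docs', 'history') and returns the
# prompt passed on to llm_op. A chat-style message list is assumed here.
def my_prompt(question, docs, history):
    return [{'role': 'user',
             'content': 'Context:\n%s\n\nQuestion: %s' % (docs, question)}]

conf = AutoConfig.load_config('enhanced_qa_search')  # pipeline name is an assumption
conf.customize_prompt = my_prompt  # when unset, falls back to ops.prompt.question_answer
pipe = AutoPipes.pipeline('enhanced_qa_search', config=conf)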
|
|
|
|
|
|
|
|