|
|
@ -11,32 +11,34 @@ |
|
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|
|
|
# See the License for the specific language governing permissions and |
|
|
|
# limitations under the License. |
|
|
|
from typing import Optional, Union |
|
|
|
|
|
|
|
from towhee import ops, pipe, AutoPipes, AutoConfig |
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
|
|
|
@AutoConfig.register |
|
|
|
class EnhancedQASearchConfig: |
|
|
|
class EnhancedQASearchConfig(BaseModel): |
|
|
|
""" |
|
|
|
Config of pipeline |
|
|
|
""" |
|
|
|
def __init__(self): |
|
|
|
# config for sentence_embedding |
|
|
|
self.embedding_model = 'all-MiniLM-L6-v2' |
|
|
|
self.openai_api_key = None |
|
|
|
self.embedding_device = -1 |
|
|
|
# config for search_milvus |
|
|
|
self.host = '127.0.0.1' |
|
|
|
self.port = '19530' |
|
|
|
self.collection_name = 'chatbot' |
|
|
|
self.top_k = 5 |
|
|
|
self.user = None |
|
|
|
self.password = None |
|
|
|
# config for similarity evaluation |
|
|
|
self.threshold = 0.6 |
|
|
|
# config for llm |
|
|
|
self.llm_src = 'openai' |
|
|
|
self.llm_model = 'gpt-3.5-turbo' if self.llm_src.lower() == 'openai' else 'databricks/dolly-v2-12b' |
|
|
|
# config for sentence_embedding |
|
|
|
embedding_model: Optional[str] = 'all-MiniLM-L6-v2' |
|
|
|
openai_api_key: Optional[str] = None |
|
|
|
embedding_device: Optional[int] = -1 |
|
|
|
# config for search_milvus |
|
|
|
host: Optional[str] = '127.0.0.1' |
|
|
|
port: Optional[str] = '19530' |
|
|
|
collection_name: Optional[str] = 'chatbot' |
|
|
|
top_k: Optional[int] = 5 |
|
|
|
user: Optional[str] = None |
|
|
|
password: Optional[str] = None |
|
|
|
# config for similarity evaluation |
|
|
|
threshold: Optional[Union[float, int]] = 0.6 |
|
|
|
# config for llm |
|
|
|
llm_src: Optional[str] = 'openai' |
|
|
|
openai_model: Optional[str] = 'gpt-3.5-turbo' |
|
|
|
dolly_model: Optional[str] = 'databricks/dolly-v2-12b' |
|
|
|
|
|
|
|
|
|
|
|
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() |
|
|
@ -76,9 +78,9 @@ def _get_similarity_evaluation_op(config): |
|
|
|
|
|
|
|
def _get_llm_op(config): |
|
|
|
if config.llm_src.lower() == 'openai': |
|
|
|
return ops.LLM.OpenAI(model_name=config.llm_model, api_key=config.openai_api_key) |
|
|
|
return ops.LLM.OpenAI(model_name=config.openai_model, api_key=config.openai_api_key) |
|
|
|
if config.llm_src.lower() == 'dolly': |
|
|
|
return ops.LLM.Dolly(model_name=config.llm_model) |
|
|
|
return ops.LLM.Dolly(model_name=config.dolly_model) |
|
|
|
|
|
|
|
raise RuntimeError('Unknown llm source: [%s], only support \'openai\' and \'dolly\'' % (config.llm_src)) |
|
|
|
|
|
|
|