|
@ -12,30 +12,32 @@ |
|
|
# See the License for the specific language governing permissions and |
|
|
# See the License for the specific language governing permissions and |
|
|
# limitations under the License. |
|
|
# limitations under the License. |
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Optional, Any |
|
|
|
|
|
from pydantic import BaseModel |
|
|
|
|
|
|
|
|
from towhee import ops, pipe, AutoPipes, AutoConfig |
|
|
from towhee import ops, pipe, AutoPipes, AutoConfig |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@AutoConfig.register |
|
|
@AutoConfig.register |
|
|
class EnhancedQAInsertConfig: |
|
|
|
|
|
|
|
|
class EnhancedQAInsertConfig(BaseModel): |
|
|
""" |
|
|
""" |
|
|
Config of pipeline |
|
|
Config of pipeline |
|
|
""" |
|
|
""" |
|
|
def __init__(self): |
|
|
|
|
|
# config for text_splitter |
|
|
|
|
|
self.type = 'RecursiveCharacter' |
|
|
|
|
|
self.chunk_size = 300 |
|
|
|
|
|
self.splitter_kwargs = {} |
|
|
|
|
|
# config for sentence_embedding |
|
|
|
|
|
self.model = 'all-MiniLM-L6-v2' |
|
|
|
|
|
self.openai_api_key = None |
|
|
|
|
|
self.device = -1 |
|
|
|
|
|
# config for insert_milvus |
|
|
|
|
|
self.host = '127.0.0.1' |
|
|
|
|
|
self.port = '19530' |
|
|
|
|
|
self.collection_name = 'chatbot' |
|
|
|
|
|
self.user = None |
|
|
|
|
|
self.password = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# config for text_splitter |
|
|
|
|
|
type: Optional[str] = 'RecursiveCharacter' |
|
|
|
|
|
chunk_size: int = 300 |
|
|
|
|
|
splitter_kwargs: Optional[Dict[str, Any]] = {} |
|
|
|
|
|
# config for sentence_embedding |
|
|
|
|
|
model: Optional[str] = 'all-MiniLM-L6-v2' |
|
|
|
|
|
openai_api_key: Optional[str] = None |
|
|
|
|
|
device: Optional[int] = -1 |
|
|
|
|
|
# config for insert_milvus |
|
|
|
|
|
host: Optional[str] = '127.0.0.1' |
|
|
|
|
|
port: Optional[str] = '19530' |
|
|
|
|
|
collection_name: Optional[str] = 'chatbot' |
|
|
|
|
|
user: Optional[str] = None |
|
|
|
|
|
password: Optional[str] = None |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() |
|
|
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() |
|
|
_sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names() |
|
|
_sbert_models = ops.sentence_embedding.sbert().get_op().supported_model_names() |
|
@ -56,7 +58,7 @@ def _get_embedding_op(config): |
|
|
return True, ops.sentence_embedding.sbert(model_name=config.model, device=device) |
|
|
return True, ops.sentence_embedding.sbert(model_name=config.model, device=device) |
|
|
if config.model in _openai_models: |
|
|
if config.model in _openai_models: |
|
|
return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key) |
|
|
return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key) |
|
|
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models)) |
|
|
|
|
|
|
|
|
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _sbert_models + _openai_models)) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|