logo
Browse Source

Update with basemodel

Signed-off-by: shiyu22 <shiyu.chen@zilliz.com>
main
shiyu22 2 years ago
parent
commit
6c9acffb38
  1. 30
      eqa_insert.py

30
eqa_insert.py

@ -12,29 +12,31 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from typing import Dict, Optional, Any
from pydantic import BaseModel
from towhee import ops, pipe, AutoPipes, AutoConfig from towhee import ops, pipe, AutoPipes, AutoConfig
@AutoConfig.register @AutoConfig.register
class EnhancedQAInsertConfig:
class EnhancedQAInsertConfig(BaseModel):
""" """
Config of pipeline Config of pipeline
""" """
def __init__(self):
# config for text_splitter # config for text_splitter
self.type = 'RecursiveCharacter'
self.chunk_size = 300
self.splitter_kwargs = {}
type: Optional[str] = 'RecursiveCharacter'
chunk_size: int = 300
splitter_kwargs: Optional[Dict[str, Any]] = {}
# config for sentence_embedding # config for sentence_embedding
self.model = 'all-MiniLM-L6-v2'
self.openai_api_key = None
self.device = -1
model: Optional[str] = 'all-MiniLM-L6-v2'
openai_api_key: Optional[str] = None
device: Optional[int] = -1
# config for insert_milvus # config for insert_milvus
self.host = '127.0.0.1'
self.port = '19530'
self.collection_name = 'chatbot'
self.user = None
self.password = None
host: Optional[str] = '127.0.0.1'
port: Optional[str] = '19530'
collection_name: Optional[str] = 'chatbot'
user: Optional[str] = None
password: Optional[str] = None
_hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names() _hf_models = ops.sentence_embedding.transformers().get_op().supported_model_names()
@ -56,7 +58,7 @@ def _get_embedding_op(config):
return True, ops.sentence_embedding.sbert(model_name=config.model, device=device) return True, ops.sentence_embedding.sbert(model_name=config.model, device=device)
if config.model in _openai_models: if config.model in _openai_models:
return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key) return False, ops.sentence_embedding.openai(model_name=config.model, api_key=config.openai_api_key)
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _openai_models))
raise RuntimeError('Unknown model: [%s], only support: %s' % (config.model, _hf_models + _sbert_models + _openai_models))

Loading…
Cancel
Save