diff --git a/osschat_insert.py b/osschat_insert.py index fc454a9..cb38028 100644 --- a/osschat_insert.py +++ b/osschat_insert.py @@ -61,11 +61,10 @@ def _get_embedding_op(config): if config.embedding_model in _hf_models: return True, ops.sentence_embedding.transformers(model_name=config.embedding_model, device=device) - if config.embedding_model in _sbert_models: - return True, ops.sentence_embedding.sbert(model_name=config.embedding_model, device=device) - if config.embedding_model in _openai_models: + elif config.embedding_model in _openai_models: return False, ops.sentence_embedding.openai(model_name=config.embedding_model, api_key=config.openai_api_key) - raise RuntimeError('Unknown model: [%s], only support: %s' % (config.embedding_model, _hf_models + _sbert_models + _openai_models)) + else: + return True, ops.sentence_embedding.sbert(model_name=config.embedding_model, device=device) def data_loader(path): if path.endswith('pdf'):