|
|
@ -73,19 +73,19 @@ class AutoTransformers(NNOperator): |
|
|
|
else: |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.model_name = self.map_model_names(model_name) |
|
|
|
if tokenizer: |
|
|
|
self.tokenizer = tokenizer |
|
|
|
else: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
|
|
|
if not self.tokenizer.pad_token: |
|
|
|
self.tokenizer.pad_token = '[PAD]' |
|
|
|
self.norm = norm |
|
|
|
self.checkpoint_path = checkpoint_path |
|
|
|
|
|
|
|
if self.model_name: |
|
|
|
model_list = self.supported_model_names() |
|
|
|
# model_list = self.supported_model_names() |
|
|
|
# assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
self.model = Model(self._model) |
|
|
|
if tokenizer: |
|
|
|
self.tokenizer = tokenizer |
|
|
|
else: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name) |
|
|
|
if not self.tokenizer.pad_token: |
|
|
|
self.tokenizer.pad_token = '[PAD]' |
|
|
|
else: |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
|
pass |
|
|
@ -205,14 +205,6 @@ class AutoTransformers(NNOperator): |
|
|
|
log.error(f'Unsupported format "{format}".') |
|
|
|
return Path(output_file).resolve() |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_formats(self): |
|
|
|
onnxes = self.supported_model_names(format='onnx') |
|
|
|
if self.model_name in onnxes: |
|
|
|
return ['onnx'] |
|
|
|
else: |
|
|
|
return ['pytorch'] |
|
|
|
|
|
|
|
@property |
|
|
|
def supported_formats(self): |
|
|
|
return ['onnx'] |
|
|
@ -220,7 +212,60 @@ class AutoTransformers(NNOperator): |
|
|
|
@staticmethod |
|
|
|
def supported_model_names(format: str = None): |
|
|
|
full_list = [ |
|
|
|
|
|
|
|
'sentence-t5-xxl', |
|
|
|
'sentence-t5-xl', |
|
|
|
'sentence-t5-large', |
|
|
|
'text-embedding-ada-002', |
|
|
|
'text-similarity-davinci-001', |
|
|
|
'text-similarity-babbage-001', |
|
|
|
'text-similarity-curie-001', |
|
|
|
'paraphrase-mpnet-base-v2', |
|
|
|
'text-similarity-ada-001', |
|
|
|
'gtr-t5-xxl', |
|
|
|
'gtr-t5-large', |
|
|
|
'gtr-t5-xl', |
|
|
|
'paraphrase-multilingual-mpnet-base-v2', |
|
|
|
'paraphrase-distilroberta-base-v2', |
|
|
|
'all-mpnet-base-v1', |
|
|
|
'all-roberta-large-v1', |
|
|
|
'all-mpnet-base-v2', |
|
|
|
'all-MiniLM-L12-v2', |
|
|
|
'all-distilroberta-v1', |
|
|
|
'all-MiniLM-L12-v1', |
|
|
|
'gtr-t5-base', |
|
|
|
'paraphrase-multilingual-MiniLM-L12-v2', |
|
|
|
'paraphrase-MiniLM-L12-v2', |
|
|
|
'all-MiniLM-L6-v1', |
|
|
|
'paraphrase-TinyBERT-L6-v2', |
|
|
|
'all-MiniLM-L6-v2', |
|
|
|
'paraphrase-albert-small-v2', |
|
|
|
'multi-qa-mpnet-base-cos-v1', |
|
|
|
'paraphrase-MiniLM-L3-v2', |
|
|
|
'multi-qa-distilbert-cos-v1', |
|
|
|
'msmarco-distilbert-base-v4', |
|
|
|
'multi-qa-mpnet-base-dot-v1', |
|
|
|
'msmarco-distilbert-base-tas-b', |
|
|
|
'distiluse-base-multilingual-cased-v2', |
|
|
|
'multi-qa-distilbert-dot-v1', |
|
|
|
'multi-qa-MiniLM-L6-cos-v1', |
|
|
|
'distiluse-base-multilingual-cased-v1', |
|
|
|
'msmarco-bert-base-dot-v5', |
|
|
|
'paraphrase-MiniLM-L6-v2', |
|
|
|
'multi-qa-MiniLM-L6-dot-v1', |
|
|
|
'msmarco-distilbert-dot-v5', |
|
|
|
'bert-base-nli-mean-tokens', |
|
|
|
'bert-large-uncased-whole-word-masking', |
|
|
|
'average_word_embeddings_komninos', |
|
|
|
'distilbert-base-uncased', |
|
|
|
'average_word_embeddings_glove.6B.300d', |
|
|
|
'dpr-ctx_encoder-multiset-base', |
|
|
|
'dpr-ctx_encoder-single-nq-base', |
|
|
|
'microsoft/deberta-xlarge', |
|
|
|
'facebook/bart-large', |
|
|
|
'bert-base-uncased', |
|
|
|
'microsoft/deberta-xlarge-mnli', |
|
|
|
'gpt2-xl', |
|
|
|
'bert-large-uncased' |
|
|
|
] |
|
|
|
full_list.sort() |
|
|
|
if format is None: |
|
|
@ -230,13 +275,11 @@ class AutoTransformers(NNOperator): |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
elif format == 'torchscript': |
|
|
|
to_remove = [ |
|
|
|
] |
|
|
|
to_remove = [] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
elif format == 'onnx': |
|
|
|
to_remove = [ |
|
|
|
] |
|
|
|
to_remove = [] |
|
|
|
assert set(to_remove).issubset(set(full_list)) |
|
|
|
model_list = list(set(full_list) - set(to_remove)) |
|
|
|
# todo: elif format == 'tensorrt': |
|
|
|