|
|
@ -40,11 +40,19 @@ class AutoTransformers(NNOperator): |
|
|
|
NLP embedding operator that uses the pretrained transformers model gathered by huggingface. |
|
|
|
Args: |
|
|
|
model_name (`str`): |
|
|
|
Which model to use for the embeddings. |
|
|
|
The model name to load a pretrained model from transformers. |
|
|
|
checkpoint_path (`str`): |
|
|
|
The local checkpoint path. |
|
|
|
tokenizer (`object`): |
|
|
|
The tokenizer to tokenize input text as model inputs. |
|
|
|
""" |
|
|
|
|
|
|
|
def __init__(self, model_name: str = None, device: str = None, pretrain_weights_path=None, |
|
|
|
load_pretrain_f=None, tokenizer=None) -> None: |
|
|
|
def __init__(self, |
|
|
|
model_name: str = None, |
|
|
|
checkpoint_path: str = None, |
|
|
|
tokenizer: object = None, |
|
|
|
device: str = None, |
|
|
|
): |
|
|
|
super().__init__() |
|
|
|
if device is None: |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
@ -52,29 +60,31 @@ class AutoTransformers(NNOperator): |
|
|
|
self.model_name = model_name |
|
|
|
|
|
|
|
if self.model_name: |
|
|
|
model_list = self.supported_model_names() |
|
|
|
assert model_name in model_list, f"Invalid model name: {model_name}. Supported model names: {model_list}" |
|
|
|
|
|
|
|
try: |
|
|
|
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
|
|
|
self.configs = self.model.config |
|
|
|
except Exception as e: |
|
|
|
model_list = self.supported_model_names() |
|
|
|
if model_name not in model_list: |
|
|
|
log.error(f"Invalid model name: {model_name}. Supported model names: {model_list}") |
|
|
|
else: |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
log.error(f"Fail to load model by name: {self.model_name}") |
|
|
|
raise e |
|
|
|
if pretrain_weights_path is not None: |
|
|
|
if load_pretrain_f is None: |
|
|
|
state_dict = torch.load(pretrain_weights_path, map_location='cpu') |
|
|
|
if checkpoint_path: |
|
|
|
try: |
|
|
|
state_dict = torch.load(checkpoint_path, map_location=self.device) |
|
|
|
self.model.load_state_dict(state_dict) |
|
|
|
else: |
|
|
|
self.model = load_pretrain_f(self.model, pretrain_weights_path) |
|
|
|
except Exception as e: |
|
|
|
log.error(f"Fail to load state dict from {checkpoint_path}: {e}") |
|
|
|
self.model.eval() |
|
|
|
|
|
|
|
if tokenizer is None: |
|
|
|
try: |
|
|
|
self.tokenizer = AutoTokenizer.from_pretrained(model_name) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to load tokenizer by name: {self.model_name}') |
|
|
|
log.error(f'Fail to load default tokenizer by name: {self.model_name}') |
|
|
|
raise e |
|
|
|
else: |
|
|
|
self.tokenizer = tokenizer |
|
|
|
else: |
|
|
|
log.warning('The operator is initialized without specified model.') |
|
|
|
pass |
|
|
|