|
@ -38,11 +38,14 @@ class AutoTransformers(NNOperator): |
|
|
Which model to use for the embeddings. |
|
|
Which model to use for the embeddings. |
|
|
""" |
|
|
""" |
|
|
|
|
|
|
|
|
def __init__(self, model_name: str = "bert-base-uncased") -> None: |
|
|
|
|
|
|
|
|
def __init__(self, model_name: str = "bert-base-uncased", device=None) -> None: |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
|
|
|
if device is None: |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
|
|
self.device = device |
|
|
self.model_name = model_name |
|
|
self.model_name = model_name |
|
|
try: |
|
|
try: |
|
|
self.model = AutoModel.from_pretrained(model_name) |
|
|
|
|
|
|
|
|
self.model = AutoModel.from_pretrained(model_name).to(self.device) |
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
model_list = self.supported_model_names() |
|
|
model_list = self.supported_model_names() |
|
@ -59,7 +62,7 @@ class AutoTransformers(NNOperator): |
|
|
|
|
|
|
|
|
def __call__(self, txt: str) -> numpy.ndarray: |
|
|
def __call__(self, txt: str) -> numpy.ndarray: |
|
|
try: |
|
|
try: |
|
|
inputs = self.tokenizer(txt, return_tensors="pt") |
|
|
|
|
|
|
|
|
inputs = self.tokenizer(txt, return_tensors="pt").to(self.device) |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
|
|
log.error(f'Invalid input for the tokenizer: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
@ -73,7 +76,7 @@ class AutoTransformers(NNOperator): |
|
|
except Exception as e: |
|
|
except Exception as e: |
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
raise e |
|
|
raise e |
|
|
vec = features.detach().numpy() |
|
|
|
|
|
|
|
|
vec = features.cpu().detach().numpy() |
|
|
return vec |
|
|
return vec |
|
|
|
|
|
|
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|
def save_model(self, format: str = 'pytorch', path: str = 'default'): |
|
|