|
|
@ -13,6 +13,9 @@ |
|
|
|
# limitations under the License. |
|
|
|
|
|
|
|
import numpy |
|
|
|
import os |
|
|
|
import torch |
|
|
|
from pathlib import Path |
|
|
|
from transformers import AutoTokenizer, AutoModel |
|
|
|
|
|
|
|
from towhee.operator import NNOperator |
|
|
@ -40,6 +43,7 @@ class AutoTransformers(NNOperator): |
|
|
|
self.model_name = model_name |
|
|
|
try: |
|
|
|
self.model = AutoModel.from_pretrained(model_name) |
|
|
|
self.model.eval() |
|
|
|
except Exception as e: |
|
|
|
model_list = get_model_list() |
|
|
|
if model_name not in model_list: |
|
|
@ -65,13 +69,29 @@ class AutoTransformers(NNOperator): |
|
|
|
log.error(f'Invalid input for the model: {self.model_name}') |
|
|
|
raise e |
|
|
|
try: |
|
|
|
features = outs.last_hidden_state.squeeze(0) |
|
|
|
features = outs['last_hidden_state'].squeeze(0) |
|
|
|
except Exception as e: |
|
|
|
log.error(f'Fail to extract features by model: {self.model_name}') |
|
|
|
raise e |
|
|
|
vec = features.detach().numpy() |
|
|
|
return vec |
|
|
|
|
|
|
|
def save_model(self, jit: bool = True, destination: str = 'default'):
    """
    Save the wrapped model to disk, either as TorchScript or as a pickled module.

    Args:
        jit (`bool`):
            If True, trace the model with a sample tokenized input and save the
            traced module with `torch.jit.save`. If False, save the model
            object with `torch.save`.
        destination (`str`):
            Output file path. The value 'default' writes '<model_name>.pt' into
            the directory that contains this source file, with any '/' in the
            model name replaced by '-'.

    Raises:
        RuntimeError: If tracing or saving the TorchScript module fails.
    """
    if destination == 'default':
        # A model name may contain '/' (e.g. "org/name"). Joined verbatim it
        # would point into a subdirectory that does not exist, so the save
        # would fail; flatten it into a single file name instead.
        file_name = self.model_name.replace('/', '-') + '.pt'
        destination = os.path.join(str(Path(__file__).parent), file_name)

    if jit:
        # Example inputs are only needed for tracing, so build them here
        # rather than on every call.
        encoded = self.tokenizer('[CLS]', return_tensors='pt')
        example_inputs = tuple(encoded.values())
        try:
            # strict=False is kept from the original call.
            traced_model = torch.jit.trace(self.model, example_inputs, strict=False)
            torch.jit.save(traced_model, destination)
        except Exception as e:
            log.error(f'Fail to save as torchscript: {e}.')
            # Chain the cause so the original traceback stays attached.
            raise RuntimeError(f'Fail to save as torchscript: {e}.') from e
    else:
        torch.save(self.model, destination)
|
|
|
|
|
|
|
|
|
|
|
def get_model_list(): |
|
|
|
full_list = [ |
|
|
|