@@ -84,10 +84,8 @@ class AutoTransformers(NNOperator):
         self.model_name = model_name

         if self.model_name:
-            self.accelerate_model = Model(
+            self.model = Model(
                 model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
-            self.model = self.accelerate_model.model
-            self.configs = self.model.config
             if tokenizer is None:
                 try:
                     self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -107,7 +105,7 @@ class AutoTransformers(NNOperator):
             log.error(f'Invalid input for the tokenizer: {self.model_name}')
             raise e
         try:
-            outs = self.accelerate_model(**inputs)
+            outs = self.model(**inputs)
         except Exception as e:
             log.error(f'Invalid input for the model: {self.model_name}')
             raise e
@@ -119,6 +117,10 @@ class AutoTransformers(NNOperator):
         vec = features.cpu().detach().numpy()
         return vec

+    @property
+    def _model(self):
+        return self.model.model
+
     def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
         if output_file == 'default':
             output_file = str(Path(__file__).parent)
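For context, a hedged usage sketch of the new _model property introduced above: self.model now holds the operator's Model wrapper, and self.model.model is the underlying transformers module, which _model exposes under one name for the ONNX export below. The constructor argument and the printed class name are assumptions for illustration, not taken from the diff.

# Hypothetical usage, not part of the diff
op = AutoTransformers('bert-base-uncased')  # assumed example checkpoint
wrapper = op.model                          # the operator's Model wrapper
raw = op._model                             # the underlying transformers model, i.e. op.model.model
print(type(raw).__name__)                   # e.g. 'BertModel' for this checkpoint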
|
@@ -152,14 +154,14 @@ class AutoTransformers(NNOperator):
             from transformers.onnx.features import FeaturesManager
             from transformers.onnx import export
             model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
-                self.model, feature='default')
-            onnx_config = model_onnx_config(self.configs)
+                self._model, feature='default')
+            onnx_config = model_onnx_config(self._model.config)
             # if os.path.isdir(output_file[:-5]):
             # shutil.rmtree(output_file[:-5])
             # print('********', Path(output_file))
             onnx_inputs, onnx_outputs = export(
                 self.tokenizer,
-                self.model,
+                self._model,
                 config=onnx_config,
                 opset=13,
                 output=Path(output_file)
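For reference, a minimal standalone sketch of the transformers.onnx export flow that this hunk delegates to, assuming a transformers release that still ships the transformers.onnx module; the checkpoint name and output path are placeholders, not taken from the diff.

from pathlib import Path

from transformers import AutoModel, AutoTokenizer
from transformers.onnx import export
from transformers.onnx.features import FeaturesManager

model_name = 'bert-base-uncased'  # assumed example checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)

# Look up the ONNX config registered for this architecture, then instantiate it.
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
    model, feature='default')
onnx_config = model_onnx_config(model.config)

# Trace and export; returns the input/output names recorded in the ONNX graph.
onnx_inputs, onnx_outputs = export(
    tokenizer, model, config=onnx_config, opset=13, output=Path('model.onnx'))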