logo
Browse Source

Debug for TritonServer

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
7fa55ba6eb
  1. 16
      auto_transformers.py

16
auto_transformers.py

@ -84,10 +84,8 @@ class AutoTransformers(NNOperator):
self.model_name = model_name
if self.model_name:
self.accelerate_model = Model(
self.model = Model(
model_name=self.model_name, device=self.device, checkpoint_path=checkpoint_path)
self.model = self.accelerate_model.model
self.configs = self.model.config
if tokenizer is None:
try:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@ -107,7 +105,7 @@ class AutoTransformers(NNOperator):
log.error(f'Invalid input for the tokenizer: {self.model_name}')
raise e
try:
outs = self.accelerate_model(**inputs)
outs = self.model(**inputs)
except Exception as e:
log.error(f'Invalid input for the model: {self.model_name}')
raise e
@ -119,6 +117,10 @@ class AutoTransformers(NNOperator):
vec = features.cpu().detach().numpy()
return vec
@property
def _model(self):
# Expose the inner `.model` attribute of the object stored in `self.model`.
# Within this class the ONNX export path passes this value (rather than
# `self.model` itself) to `FeaturesManager.check_supported_model_or_raise`
# and `export`, and reads its `.config` to build the ONNX config.
return self.model.model
def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
if output_file == 'default':
output_file = str(Path(__file__).parent)
@ -152,14 +154,14 @@ class AutoTransformers(NNOperator):
from transformers.onnx.features import FeaturesManager
from transformers.onnx import export
model_kind, model_onnx_config = FeaturesManager.check_supported_model_or_raise(
self.model, feature='default')
onnx_config = model_onnx_config(self.configs)
self._model, feature='default')
onnx_config = model_onnx_config(self._model.config)
# if os.path.isdir(output_file[:-5]):
# shutil.rmtree(output_file[:-5])
# print('********', Path(output_file))
onnx_inputs, onnx_outputs = export(
self.tokenizer,
self.model,
self._model,
config=onnx_config,
opset=13,
output=Path(output_file)

Loading…
Cancel
Save