logo
Browse Source

Debug to support TritonServer

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
c947c6ea84
  1. 2
      test_onnx.py
  2. 22
      timm_image.py

2
test_onnx.py

@ -57,7 +57,7 @@ for name in models:
status = [name] + ['fail'] * 5
try:
out1 = op.accelerate_model(data).detach().numpy()
out1 = op.model(data).detach().numpy()
logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e:

22
timm_image.py

@ -45,6 +45,7 @@ log = logging.getLogger('timm_op')
class Model:
def __init__(self, model_name, device, num_classes):
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.config = resolve_data_config({}, model=self.model)
self.model.eval()
self.model.to(device)
@ -77,13 +78,12 @@ class TimmImage(NNOperator):
self.device = device
self.model_name = model_name
if self.model_name:
self.accelerate_model = Model(
self.model = Model(
model_name=model_name,
device=self.device,
num_classes=num_classes
)
self.model = self.accelerate_model.model
self.config = resolve_data_config({}, model=self.model)
self.config = self.model.config
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess
else:
@ -102,7 +102,7 @@ class TimmImage(NNOperator):
img_list.append(img)
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
features = self.accelerate_model(inputs)
features = self.model(inputs)
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features)
@ -113,6 +113,10 @@ class TimmImage(NNOperator):
vecs = features.squeeze(0).detach().numpy()
return vecs
@property
def _model(self):
return self.model.model
@arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image):
img = PILImage.fromarray(img.astype('uint8'), 'RGB')
@ -133,21 +137,21 @@ class TimmImage(NNOperator):
raise AttributeError(f'Invalid format {format}.')
dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch':
torch.save(self.model, path)
torch.save(self._model, path)
elif format == 'torchscript':
try:
try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception:
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
self.model.forward = self.model.forward_features
self._model.forward = self._model.forward_features
try:
torch.onnx.export(self.model,
torch.onnx.export(self._model,
dummy_input,
path,
input_names=['input_0'],

Loading…
Cancel
Save