logo
Browse Source

Debug to support TritonServer

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
c947c6ea84
  1. 2
      test_onnx.py
  2. 22
      timm_image.py

2
test_onnx.py

@ -57,7 +57,7 @@ for name in models:
status = [name] + ['fail'] * 5 status = [name] + ['fail'] * 5
try: try:
out1 = op.accelerate_model(data).detach().numpy()
out1 = op.model(data).detach().numpy()
logger.info('OP LOADED.') logger.info('OP LOADED.')
status[1] = 'success' status[1] = 'success'
except Exception as e: except Exception as e:

22
timm_image.py

@ -45,6 +45,7 @@ log = logging.getLogger('timm_op')
class Model: class Model:
def __init__(self, model_name, device, num_classes): def __init__(self, model_name, device, num_classes):
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.config = resolve_data_config({}, model=self.model)
self.model.eval() self.model.eval()
self.model.to(device) self.model.to(device)
@ -77,13 +78,12 @@ class TimmImage(NNOperator):
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
if self.model_name: if self.model_name:
self.accelerate_model = Model(
self.model = Model(
model_name=model_name, model_name=model_name,
device=self.device, device=self.device,
num_classes=num_classes num_classes=num_classes
) )
self.model = self.accelerate_model.model
self.config = resolve_data_config({}, model=self.model)
self.config = self.model.config
self.tfms = create_transform(**self.config) self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
else: else:
@ -102,7 +102,7 @@ class TimmImage(NNOperator):
img_list.append(img) img_list.append(img)
inputs = torch.stack(img_list) inputs = torch.stack(img_list)
inputs = inputs.to(self.device) inputs = inputs.to(self.device)
features = self.accelerate_model(inputs)
features = self.model(inputs)
if features.dim() == 4: if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features) features = global_pool(features)
@ -113,6 +113,10 @@ class TimmImage(NNOperator):
vecs = features.squeeze(0).detach().numpy() vecs = features.squeeze(0).detach().numpy()
return vecs return vecs
@property
def _model(self):
return self.model.model
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image): def convert_img(self, img: towhee._types.Image):
img = PILImage.fromarray(img.astype('uint8'), 'RGB') img = PILImage.fromarray(img.astype('uint8'), 'RGB')
@ -133,21 +137,21 @@ class TimmImage(NNOperator):
raise AttributeError(f'Invalid format {format}.') raise AttributeError(f'Invalid format {format}.')
dummy_input = torch.rand((1,) + self.config['input_size']) dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch': if format == 'pytorch':
torch.save(self.model, path)
torch.save(self._model, path)
elif format == 'torchscript': elif format == 'torchscript':
try: try:
try: try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception: except Exception:
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
torch.jit.save(jit_model, path) torch.jit.save(jit_model, path)
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.') log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
self.model.forward = self.model.forward_features
self._model.forward = self._model.forward_features
try: try:
torch.onnx.export(self.model,
torch.onnx.export(self._model,
dummy_input, dummy_input,
path, path,
input_names=['input_0'], input_names=['input_0'],

Loading…
Cancel
Save