logo
Browse Source

Support TritonServe

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
a8b3af6f10
  1. 6
      test_onnx.py
  2. 47
      timm_image.py

6
test_onnx.py

@ -11,8 +11,8 @@ import logging
import platform
import psutil
models = TimmImage.supported_model_names()[:2]
# models = ['resnet50']
# models = TimmImage.supported_model_names()[:2]
models = ['resnet50']
atol = 1e-3
log_path = 'timm_onnx.log'
@ -57,7 +57,7 @@ for name in models:
status = [name] + ['fail'] * 5
try:
out1 = op.model.forward_features(data).detach().numpy()
out1 = op.accelerate_model(data).detach().numpy()
logger.info('OP LOADED.')
status[1] = 'success'
except Exception as e:

47
timm_image.py

@ -23,6 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from towhee.types import Image
from towhee.dc2 import accelerate
import torch
from torch import nn
@ -40,6 +41,17 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('timm_op')
@accelerate
class Model:
def __init__(self, model_name, device, num_classes):
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(device)
def __call__(self, x: torch.Tensor):
return self.model.forward_features(x)
@register(output_schema=['vec'])
class TimmImage(NNOperator):
"""
@ -65,10 +77,12 @@ class TimmImage(NNOperator):
self.device = device
self.model_name = model_name
if self.model_name:
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(self.device)
self.accelerate_model = Model(
model_name=model_name,
device=self.device,
num_classes=num_classes
)
self.model = self.accelerate_model.model
self.config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess
@ -88,7 +102,7 @@ class TimmImage(NNOperator):
img_list.append(img)
inputs = torch.stack(img_list)
inputs = inputs.to(self.device)
features = self.model.forward_features(inputs)
features = self.accelerate_model(inputs)
if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features)
@ -111,12 +125,16 @@ class TimmImage(NNOperator):
os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
if format in ['pytorch', 'torchscript']:
path = path + '.pt'
elif format == 'onnx':
path = path + '.onnx'
else:
raise AttributeError(f'Invalid format {format}.')
dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
elif format == 'torchscript':
path = path + '.pt'
try:
try:
jit_model = torch.jit.script(self.model)
@ -127,7 +145,6 @@ class TimmImage(NNOperator):
log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
self.model.forward = self.model.forward_features
try:
torch.onnx.export(self.model,
@ -227,11 +244,9 @@ class TimmImage(NNOperator):
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
return model_list
def input_schema(self):
return [(Image, (-1, -1, 3))]
def output_schema(self):
image = Image(numpy.random.randn(480, 480, 3), "RGB")
ret = self(image)
data_type = type(ret.reshape(-1)[0])
return [(data_type, ret.shape)]
@property
def supported_formats(self):
if self.model_name in self.supported_model_names(format='onnx'):
return ['onnx']
else:
return []

Loading…
Cancel
Save