logo
Browse Source

Support TritonServe

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
a8b3af6f10
  1. 6
      test_onnx.py
  2. 47
      timm_image.py

6
test_onnx.py

@ -11,8 +11,8 @@ import logging
import platform import platform
import psutil import psutil
models = TimmImage.supported_model_names()[:2]
# models = ['resnet50']
# models = TimmImage.supported_model_names()[:2]
models = ['resnet50']
atol = 1e-3 atol = 1e-3
log_path = 'timm_onnx.log' log_path = 'timm_onnx.log'
@ -57,7 +57,7 @@ for name in models:
status = [name] + ['fail'] * 5 status = [name] + ['fail'] * 5
try: try:
out1 = op.model.forward_features(data).detach().numpy()
out1 = op.accelerate_model(data).detach().numpy()
logger.info('OP LOADED.') logger.info('OP LOADED.')
status[1] = 'success' status[1] = 'success'
except Exception as e: except Exception as e:

47
timm_image.py

@ -23,6 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
from towhee import register from towhee import register
from towhee.types import Image from towhee.types import Image
from towhee.dc2 import accelerate
import torch import torch
from torch import nn from torch import nn
@ -40,6 +41,17 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('timm_op') log = logging.getLogger('timm_op')
@accelerate
class Model:
def __init__(self, model_name, device, num_classes):
self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(device)
def __call__(self, x: torch.Tensor):
return self.model.forward_features(x)
@register(output_schema=['vec']) @register(output_schema=['vec'])
class TimmImage(NNOperator): class TimmImage(NNOperator):
""" """
@ -65,10 +77,12 @@ class TimmImage(NNOperator):
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
if self.model_name: if self.model_name:
self.model = create_model(self.model_name, pretrained=True, num_classes=num_classes)
self.model.eval()
self.model.to(self.device)
self.accelerate_model = Model(
model_name=model_name,
device=self.device,
num_classes=num_classes
)
self.model = self.accelerate_model.model
self.config = resolve_data_config({}, model=self.model) self.config = resolve_data_config({}, model=self.model)
self.tfms = create_transform(**self.config) self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
@ -88,7 +102,7 @@ class TimmImage(NNOperator):
img_list.append(img) img_list.append(img)
inputs = torch.stack(img_list) inputs = torch.stack(img_list)
inputs = inputs.to(self.device) inputs = inputs.to(self.device)
features = self.model.forward_features(inputs)
features = self.accelerate_model(inputs)
if features.dim() == 4: if features.dim() == 4:
global_pool = nn.AdaptiveAvgPool2d(1).to(self.device) global_pool = nn.AdaptiveAvgPool2d(1).to(self.device)
features = global_pool(features) features = global_pool(features)
@ -111,12 +125,16 @@ class TimmImage(NNOperator):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-') name = self.model_name.replace('/', '-')
path = os.path.join(path, name) path = os.path.join(path, name)
if format in ['pytorch', 'torchscript']:
path = path + '.pt'
elif format == 'onnx':
path = path + '.onnx'
else:
raise AttributeError(f'Invalid format {format}.')
dummy_input = torch.rand((1,) + self.config['input_size']) dummy_input = torch.rand((1,) + self.config['input_size'])
if format == 'pytorch': if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path) torch.save(self.model, path)
elif format == 'torchscript': elif format == 'torchscript':
path = path + '.pt'
try: try:
try: try:
jit_model = torch.jit.script(self.model) jit_model = torch.jit.script(self.model)
@ -127,7 +145,6 @@ class TimmImage(NNOperator):
log.error(f'Fail to save as torchscript: {e}.') log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx'
self.model.forward = self.model.forward_features self.model.forward = self.model.forward_features
try: try:
torch.onnx.export(self.model, torch.onnx.export(self.model,
@ -227,11 +244,9 @@ class TimmImage(NNOperator):
log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".') log.error(f'Invalid format "{format}". Currently supported formats: "pytorch".')
return model_list return model_list
def input_schema(self):
return [(Image, (-1, -1, 3))]
def output_schema(self):
image = Image(numpy.random.randn(480, 480, 3), "RGB")
ret = self(image)
data_type = type(ret.reshape(-1)[0])
return [(data_type, ret.shape)]
@property
def supported_formats(self):
if self.model_name in self.supported_model_names(format='onnx'):
return ['onnx']
else:
return []

Loading…
Cancel
Save