logo
Browse Source

Support tritonserve

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
efb25154aa
  1. 54
      isc.py
  2. 14
      test_onnx.py

54
isc.py

@ -23,6 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
from towhee import register from towhee import register
from towhee.models import isc from towhee.models import isc
# from towhee.dc2 import accelerate
import torch import torch
from torch import nn from torch import nn
@ -36,6 +37,20 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('isc_op') log = logging.getLogger('isc_op')
# @accelerate
class Model:
def __init__(self, timm_backbone, checkpoint_path, device):
self.device = device
self.backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False)
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device,
backbone=self.backbone, p=3.0, eval_p=1.0)
self.model.eval()
def __call__(self, x):
x = x.to(self.device)
return self.model(x)
@register(output_schema=['vec']) @register(output_schema=['vec'])
class Isc(NNOperator): class Isc(NNOperator):
""" """
@ -62,16 +77,13 @@ class Isc(NNOperator):
if checkpoint_path is None: if checkpoint_path is None:
checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth') checkpoint_path = os.path.join(str(Path(__file__).parent), 'checkpoints', timm_backbone + '.pth')
backbone = timm.create_model(timm_backbone, features_only=True, pretrained=False)
self.model = isc.create_model(pretrained=True, checkpoint_path=checkpoint_path, device=self.device,
backbone=backbone, p=3.0, eval_p=1.0)
self.model.eval()
self.model = Model(self.timm_backbone, checkpoint_path, self.device)
self.tfms = transforms.Compose([ self.tfms = transforms.Compose([
transforms.Resize((img_size, img_size)), transforms.Resize((img_size, img_size)),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize(mean=backbone.default_cfg['mean'],
std=backbone.default_cfg['std'])
transforms.Normalize(mean=self.backbone.default_cfg['mean'],
std=self.backbone.default_cfg['std'])
]) ])
def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]): def __call__(self, data: Union[List[towhee._types.Image], towhee._types.Image]):
@ -96,6 +108,14 @@ class Isc(NNOperator):
vecs = features.squeeze(0).detach().numpy() vecs = features.squeeze(0).detach().numpy()
return vecs return vecs
@property
def _model(self):
return self.model.model
@property
def backbone(self):
return self.model.backbone
def save_model(self, format: str = 'pytorch', path: str = 'default'): def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default': if path == 'default':
path = str(Path(__file__).parent) path = str(Path(__file__).parent)
@ -103,25 +123,28 @@ class Isc(NNOperator):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
name = self.timm_backbone.replace('/', '-') name = self.timm_backbone.replace('/', '-')
path = os.path.join(path, name) path = os.path.join(path, name)
if format in ['pytorch', 'torchscript']:
path = path + '.pt'
elif format == 'onnx':
path = path + '.onnx'
else:
raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand(1, 3, 224, 224) dummy_input = torch.rand(1, 3, 224, 224)
if format == 'pytorch': if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
torch.save(self._model, path)
elif format == 'torchscript': elif format == 'torchscript':
path = path + '.pt'
try: try:
try: try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception: except Exception:
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
torch.jit.save(jit_model, path) torch.jit.save(jit_model, path)
except Exception as e: except Exception as e:
log.error(f'Fail to save as torchscript: {e}.') log.error(f'Fail to save as torchscript: {e}.')
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx'
try: try:
torch.onnx.export(self.model,
torch.onnx.export(self._model,
dummy_input, dummy_input,
path, path,
input_names=['input_0'], input_names=['input_0'],
@ -139,12 +162,17 @@ class Isc(NNOperator):
# todo: elif format == 'tensorrt': # todo: elif format == 'tensorrt':
else: else:
log.error(f'Unsupported format "{format}".') log.error(f'Unsupported format "{format}".')
return path
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image): def convert_img(self, img: towhee._types.Image):
img = PILImage.fromarray(img.astype('uint8'), 'RGB') img = PILImage.fromarray(img.astype('uint8'), 'RGB')
return img return img
@property
def supported_formats(self):
return ['onnx']
# if __name__ == '__main__': # if __name__ == '__main__':
# from towhee import ops # from towhee import ops

14
test_onnx.py

@ -1,7 +1,6 @@
import onnx
from isc import Isc
import towhee
from towhee import ops from towhee import ops
import torch import torch
import numpy import numpy
import onnx import onnx
@ -46,17 +45,17 @@ for name in models:
onnx_path = f'saved/onnx/{saved_name}.onnx' onnx_path = f'saved/onnx/{saved_name}.onnx'
try: try:
op = Isc(timm_backbone=name, device='cpu')
op = ops.image_embedding.isc(timm_backbone=name, device='cpu').get_op()
except Exception as e: except Exception as e:
logger.error(f'Fail to load model {name}. Please check weights.') logger.error(f'Fail to load model {name}. Please check weights.')
data = torch.rand(1, 3, 224, 224)
data = torch.ones(1, 3, 224, 224)
if status: if status:
f.write(','.join(status) + '\n') f.write(','.join(status) + '\n')
status = [name] + ['fail'] * 5 status = [name] + ['fail'] * 5
try: try:
out1 = op.model(data).detach().numpy()
out1 = op.model(data).cpu().detach().numpy()
logger.info('OP LOADED.') logger.info('OP LOADED.')
status[1] = 'success' status[1] = 'success'
except Exception as e: except Exception as e:
@ -84,7 +83,8 @@ for name in models:
try: try:
sess = onnxruntime.InferenceSession(onnx_path, sess = onnxruntime.InferenceSession(onnx_path,
providers=onnxruntime.get_available_providers()) providers=onnxruntime.get_available_providers())
out2 = sess.run(None, input_feed={'input_0': data.detach().numpy()})
data = data.cpu().detach().numpy()
out2 = sess.run(None, input_feed={'input_0': data})
logger.info('ONNX WORKED.') logger.info('ONNX WORKED.')
status[4] = 'success' status[4] = 'success'
if numpy.allclose(out1, out2, atol=atol): if numpy.allclose(out1, out2, atol=atol):

Loading…
Cancel
Save