logo
Browse Source

Support TritonServe

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
ec4aee0a13
  1. 10
      README.md
  2. 78
      nn_fingerprint.py
  3. 2
      test_onnx.py

10
README.md

@ -110,4 +110,12 @@ Accepted formats: 'pytorch', 'torchscript, 'onnx', 'tensorrt' (in progress)
*path: str* *path: str*
Path to save model, defaults to 'default'. Path to save model, defaults to 'default'.
The default path is under 'saved' in the same directory of operator cache.
The default path is under 'saved' in the same directory of operator cache.
```python
from towhee import ops
op = ops.audio_embedding.nnfp(device='cpu').get_op()
op.save_model('onnx', 'test.onnx')
```
PosixPath('/Home/.towhee/operators/audio-embedding/nnfp/main/test.onnx')

78
nn_fingerprint.py

@ -28,6 +28,7 @@ from towhee import register
from towhee.types.audio_frame import AudioFrame from towhee.types.audio_frame import AudioFrame
from towhee.models.nnfp import NNFp from towhee.models.nnfp import NNFp
from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec
# from towhee.dc2 import accelerate
from .configs import default_params, hop25_params, distill_params from .configs import default_params, hop25_params, distill_params
@ -35,23 +36,11 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('nnfp_op') log = logging.getLogger('nnfp_op')
# @accelerate
class Model: class Model:
def __init__(self, model_name, device='cpu', model_path=None):
def __init__(self, params, device='cpu', model_path=None):
self.device = device self.device = device
if model_name == 'nnfp_default':
self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params == distill_params
else:
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
log.info('Loading model...') log.info('Loading model...')
if model_path is None:
path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
try: try:
state_dict = torch.jit.load(model_path, map_location=self.device) state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception: except Exception:
@ -59,18 +48,18 @@ class Model:
if isinstance(state_dict, torch.nn.Module): if isinstance(state_dict, torch.nn.Module):
self.model = state_dict self.model = state_dict
else: else:
dim = self.params['dim']
h = self.params['h']
u = self.params['u']
f_bin = self.params['n_mels']
n_seg = int(self.params['segment_size'] * self.params['sample_rate'])
t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length']
dim = params['dim']
h = params['h']
u = params['u']
f_bin = params['n_mels']
n_seg = int(params['segment_size'] * params['sample_rate'])
t = (n_seg + params['hop_length'] - 1) // params['hop_length']
log.info('Creating model with parameters...') log.info('Creating model with parameters...')
self.model = NNFp( self.model = NNFp(
dim=dim, h=h, u=u, dim=dim, h=h, u=u,
in_f=f_bin, in_t=t, in_f=f_bin, in_t=t,
fuller=self.params['fuller'],
activation=self.params['activation']
fuller=params['fuller'],
activation=params['activation']
).to(self.device) ).to(self.device)
self.model.load_state_dict(state_dict) self.model.load_state_dict(state_dict)
self.model.eval() self.model.eval()
@ -96,19 +85,34 @@ class NNFingerprint(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu' device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device self.device = device
self.model_name = model_name self.model_name = model_name
self.accelerate_model = Model(model_name=model_name, device=self.device, model_path=model_path)
self.model = self.accelerate_model.model
self.params = self.accelerate_model.params
if model_name == 'nnfp_default':
self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params == distill_params
else:
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
if model_path is None:
path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
self.model = Model(params=self.params, device=self.device, model_path=model_path)
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray: def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data) audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device: if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device) audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape) # print(audio_tensors.shape)
features = self.accelerate_model(audio_tensors)
features = self.model(audio_tensors)
outs = features.detach().cpu().numpy() outs = features.detach().cpu().numpy()
return outs return outs
@property
def _model(self):
return self.model.model
def preprocess(self, frames: Union[str, List[AudioFrame]]): def preprocess(self, frames: Union[str, List[AudioFrame]]):
if isinstance(frames, str): if isinstance(frames, str):
audio, sr = torchaudio.load(frames) audio, sr = torchaudio.load(frames)
@ -176,6 +180,10 @@ class NNFingerprint(NNOperator):
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype) log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
return wav.astype(dtype) return wav.astype(dtype)
@property
def supported_formats(self):
return ['onnx']
def save_model(self, format: str = 'pytorch', path: str = 'default'): def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default': if path == 'default':
path = str(Path(__file__).parent) path = str(Path(__file__).parent)
@ -183,30 +191,33 @@ class NNFingerprint(NNOperator):
os.makedirs(path, exist_ok=True) os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-') name = self.model_name.replace('/', '-')
path = os.path.join(path, name) path = os.path.join(path, name)
if format in ['torchscript', 'pytorch']:
path = path + '.pt'
elif format == 'onnx':
path = path + '.onnx'
else:
raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand( dummy_input = torch.rand(
(1,) + (self.params['n_mels'], self.params['u']) (1,) + (self.params['n_mels'], self.params['u'])
).to(self.device) ).to(self.device)
if format == 'pytorch': if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
torch.save(self._model, path)
elif format == 'torchscript': elif format == 'torchscript':
path = path + '.pt'
try: try:
try: try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception: except Exception:
log.warning( log.warning(
'Failed to directly export as torchscript.' 'Failed to directly export as torchscript.'
'Using dummy input in shape of %s now.', dummy_input.shape) 'Using dummy input in shape of %s now.', dummy_input.shape)
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
torch.jit.save(jit_model, path) torch.jit.save(jit_model, path)
except Exception as e: except Exception as e:
log.error('Fail to save as torchscript: %s.', e) log.error('Fail to save as torchscript: %s.', e)
raise RuntimeError(f'Fail to save as torchscript: {e}.') raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx': elif format == 'onnx':
path = path + '.onnx'
try: try:
torch.onnx.export(self.model,
torch.onnx.export(self._model,
dummy_input, dummy_input,
path, path,
input_names=['input'], input_names=['input'],
@ -223,3 +234,4 @@ class NNFingerprint(NNOperator):
# todo: elif format == 'tensorrt': # todo: elif format == 'tensorrt':
else: else:
log.error(f'Unsupported format "{format}".') log.error(f'Unsupported format "{format}".')
return Path(path).resolve()

2
test_onnx.py

@ -60,7 +60,7 @@ for name in models:
logger.error(f'FAIL TO LOAD OP: {e}') logger.error(f'FAIL TO LOAD OP: {e}')
continue continue
try: try:
op.save_model(format='onnx')
op.save_model(format='onnx', path=onnx_path)
logger.info('ONNX SAVED.') logger.info('ONNX SAVED.')
status[2] = 'success' status[2] = 'success'
except Exception as e: except Exception as e:

Loading…
Cancel
Save