logo
Browse Source

Support TritonServe

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
ec4aee0a13
  1. 8
      README.md
  2. 78
      nn_fingerprint.py
  3. 2
      test_onnx.py

8
README.md

@ -111,3 +111,11 @@ Accepted formats: 'pytorch', 'torchscript, 'onnx', 'tensorrt' (in progress)
Path to save model, defaults to 'default'.
The default path is under 'saved' in the same directory of operator cache.
```python
from towhee import ops
op = ops.audio_embedding.nnfp(device='cpu').get_op()
op.save_model('onnx', 'test.onnx')
```
PosixPath('/Home/.towhee/operators/audio-embedding/nnfp/main/test.onnx')

78
nn_fingerprint.py

@ -28,6 +28,7 @@ from towhee import register
from towhee.types.audio_frame import AudioFrame
from towhee.models.nnfp import NNFp
from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec
# from towhee.dc2 import accelerate
from .configs import default_params, hop25_params, distill_params
@ -35,23 +36,11 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('nnfp_op')
# @accelerate
class Model:
def __init__(self, model_name, device='cpu', model_path=None):
def __init__(self, params, device='cpu', model_path=None):
self.device = device
if model_name == 'nnfp_default':
self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params == distill_params
else:
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
log.info('Loading model...')
if model_path is None:
path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
try:
state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception:
@ -59,18 +48,18 @@ class Model:
if isinstance(state_dict, torch.nn.Module):
self.model = state_dict
else:
dim = self.params['dim']
h = self.params['h']
u = self.params['u']
f_bin = self.params['n_mels']
n_seg = int(self.params['segment_size'] * self.params['sample_rate'])
t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length']
dim = params['dim']
h = params['h']
u = params['u']
f_bin = params['n_mels']
n_seg = int(params['segment_size'] * params['sample_rate'])
t = (n_seg + params['hop_length'] - 1) // params['hop_length']
log.info('Creating model with parameters...')
self.model = NNFp(
dim=dim, h=h, u=u,
in_f=f_bin, in_t=t,
fuller=self.params['fuller'],
activation=self.params['activation']
fuller=params['fuller'],
activation=params['activation']
).to(self.device)
self.model.load_state_dict(state_dict)
self.model.eval()
@ -96,19 +85,34 @@ class NNFingerprint(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.accelerate_model = Model(model_name=model_name, device=self.device, model_path=model_path)
self.model = self.accelerate_model.model
self.params = self.accelerate_model.params
if model_name == 'nnfp_default':
self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params == distill_params
else:
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
if model_path is None:
path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
self.model = Model(params=self.params, device=self.device, model_path=model_path)
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape)
features = self.accelerate_model(audio_tensors)
features = self.model(audio_tensors)
outs = features.detach().cpu().numpy()
return outs
@property
def _model(self):
return self.model.model
def preprocess(self, frames: Union[str, List[AudioFrame]]):
if isinstance(frames, str):
audio, sr = torchaudio.load(frames)
@ -176,6 +180,10 @@ class NNFingerprint(NNOperator):
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
return wav.astype(dtype)
@property
def supported_formats(self):
return ['onnx']
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
@ -183,30 +191,33 @@ class NNFingerprint(NNOperator):
os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
if format in ['torchscript', 'pytorch']:
path = path + '.pt'
elif format == 'onnx':
path = path + '.onnx'
else:
raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand(
(1,) + (self.params['n_mels'], self.params['u'])
).to(self.device)
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
torch.save(self._model, path)
elif format == 'torchscript':
path = path + '.pt'
try:
try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception:
log.warning(
'Failed to directly export as torchscript.'
'Using dummy input in shape of %s now.', dummy_input.shape)
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error('Fail to save as torchscript: %s.', e)
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
try:
torch.onnx.export(self.model,
torch.onnx.export(self._model,
dummy_input,
path,
input_names=['input'],
@ -223,3 +234,4 @@ class NNFingerprint(NNOperator):
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')
return Path(path).resolve()

2
test_onnx.py

@ -60,7 +60,7 @@ for name in models:
logger.error(f'FAIL TO LOAD OP: {e}')
continue
try:
op.save_model(format='onnx')
op.save_model(format='onnx', path=onnx_path)
logger.info('ONNX SAVED.')
status[2] = 'success'
except Exception as e:

Loading…
Cancel
Save