logo
Browse Source

Support TritonServe

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 1 year ago
parent
commit
ec4aee0a13
  1. 10
      README.md
  2. 78
      nn_fingerprint.py
  3. 2
      test_onnx.py

10
README.md

@ -110,4 +110,12 @@ Accepted formats: 'pytorch', 'torchscript, 'onnx', 'tensorrt' (in progress)
*path: str*
Path to save model, defaults to 'default'.
The default path is under 'saved' in the same directory of operator cache.
The default path is under 'saved' in the same directory of operator cache.
```python
from towhee import ops
op = ops.audio_embedding.nnfp(device='cpu').get_op()
op.save_model('onnx', 'test.onnx')
```
PosixPath('/Home/.towhee/operators/audio-embedding/nnfp/main/test.onnx')

78
nn_fingerprint.py

@ -28,6 +28,7 @@ from towhee import register
from towhee.types.audio_frame import AudioFrame
from towhee.models.nnfp import NNFp
from towhee.models.utils.audio_preprocess import preprocess_wav, MelSpec
# from towhee.dc2 import accelerate
from .configs import default_params, hop25_params, distill_params
@ -35,23 +36,11 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('nnfp_op')
# @accelerate
class Model:
def __init__(self, model_name, device='cpu', model_path=None):
def __init__(self, params, device='cpu', model_path=None):
self.device = device
if model_name == 'nnfp_default':
self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params == distill_params
else:
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
log.info('Loading model...')
if model_path is None:
path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
try:
state_dict = torch.jit.load(model_path, map_location=self.device)
except Exception:
@ -59,18 +48,18 @@ class Model:
if isinstance(state_dict, torch.nn.Module):
self.model = state_dict
else:
dim = self.params['dim']
h = self.params['h']
u = self.params['u']
f_bin = self.params['n_mels']
n_seg = int(self.params['segment_size'] * self.params['sample_rate'])
t = (n_seg + self.params['hop_length'] - 1) // self.params['hop_length']
dim = params['dim']
h = params['h']
u = params['u']
f_bin = params['n_mels']
n_seg = int(params['segment_size'] * params['sample_rate'])
t = (n_seg + params['hop_length'] - 1) // params['hop_length']
log.info('Creating model with parameters...')
self.model = NNFp(
dim=dim, h=h, u=u,
in_f=f_bin, in_t=t,
fuller=self.params['fuller'],
activation=self.params['activation']
fuller=params['fuller'],
activation=params['activation']
).to(self.device)
self.model.load_state_dict(state_dict)
self.model.eval()
@ -96,19 +85,34 @@ class NNFingerprint(NNOperator):
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
self.model_name = model_name
self.accelerate_model = Model(model_name=model_name, device=self.device, model_path=model_path)
self.model = self.accelerate_model.model
self.params = self.accelerate_model.params
if model_name == 'nnfp_default':
self.params = default_params
elif model_name == 'nnfp_hop25':
self.params = hop25_params
elif model_name == 'nnfp_distill':
self.params == distill_params
else:
raise ValueError('Invalid model name. Accept value from ["nnfp_default", "nnfp_hop25", "nnfp_distill"]')
if model_path is None:
path = str(Path(__file__).parent)
model_path = os.path.join(path, 'saved_model', 'nnfp_fma.pt')
self.model = Model(params=self.params, device=self.device, model_path=model_path)
def __call__(self, data: Union[str, List[AudioFrame]]) -> numpy.ndarray:
audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape)
features = self.accelerate_model(audio_tensors)
features = self.model(audio_tensors)
outs = features.detach().cpu().numpy()
return outs
@property
def _model(self):
return self.model.model
def preprocess(self, frames: Union[str, List[AudioFrame]]):
if isinstance(frames, str):
audio, sr = torchaudio.load(frames)
@ -176,6 +180,10 @@ class NNFingerprint(NNOperator):
log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
return wav.astype(dtype)
@property
def supported_formats(self):
return ['onnx']
def save_model(self, format: str = 'pytorch', path: str = 'default'):
if path == 'default':
path = str(Path(__file__).parent)
@ -183,30 +191,33 @@ class NNFingerprint(NNOperator):
os.makedirs(path, exist_ok=True)
name = self.model_name.replace('/', '-')
path = os.path.join(path, name)
if format in ['torchscript', 'pytorch']:
path = path + '.pt'
elif format == 'onnx':
path = path + '.onnx'
else:
raise ValueError(f'Invalid format {format}.')
dummy_input = torch.rand(
(1,) + (self.params['n_mels'], self.params['u'])
).to(self.device)
if format == 'pytorch':
path = path + '.pt'
torch.save(self.model, path)
torch.save(self._model, path)
elif format == 'torchscript':
path = path + '.pt'
try:
try:
jit_model = torch.jit.script(self.model)
jit_model = torch.jit.script(self._model)
except Exception:
log.warning(
'Failed to directly export as torchscript.'
'Using dummy input in shape of %s now.', dummy_input.shape)
jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
jit_model = torch.jit.trace(self._model, dummy_input, strict=False)
torch.jit.save(jit_model, path)
except Exception as e:
log.error('Fail to save as torchscript: %s.', e)
raise RuntimeError(f'Fail to save as torchscript: {e}.')
elif format == 'onnx':
path = path + '.onnx'
try:
torch.onnx.export(self.model,
torch.onnx.export(self._model,
dummy_input,
path,
input_names=['input'],
@ -223,3 +234,4 @@ class NNFingerprint(NNOperator):
# todo: elif format == 'tensorrt':
else:
log.error(f'Unsupported format "{format}".')
return Path(path).resolve()

2
test_onnx.py

@ -60,7 +60,7 @@ for name in models:
logger.error(f'FAIL TO LOAD OP: {e}')
continue
try:
op.save_model(format='onnx')
op.save_model(format='onnx', path=onnx_path)
logger.info('ONNX SAVED.')
status[2] = 'success'
except Exception as e:

Loading…
Cancel
Save