
Allow init parameters & test save_model

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
Branch: main
Jael Gu, 2 years ago
commit 061e46b82e

Changed files:
  1. __init__.py        (4 lines changed)
  2. nn_fingerprint.py  (52 lines changed)
  3. test.py            (35 lines added, new file)

__init__.py (4 lines changed)

@@ -15,5 +15,5 @@
 from .nn_fingerprint import NNFingerprint


-def nnfp():
-    return NNFingerprint()
+def nnfp(**kwargs):
+    return NNFingerprint(**kwargs)
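Note: with the factory now forwarding **kwargs, any argument accepted by NNFingerprint.__init__ can be set at construction time. A minimal usage sketch (the checkpoint path is illustrative; it matches the file that test.py below writes out):

    from towhee import ops

    # Keyword arguments are passed straight through to NNFingerprint(**kwargs),
    # so a custom checkpoint or framework can be chosen without editing the operator.
    op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt')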

nn_fingerprint.py (52 lines changed)

@@ -46,7 +46,7 @@ class NNFingerprint(NNOperator):
                  checkpoint_path: str = None,
                  framework: str = 'pytorch'):
         super().__init__(framework=framework)
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
         if params is None:
             self.params = default_params
         else:
@@ -72,12 +72,16 @@ class NNFingerprint(NNOperator):
             path = str(Path(__file__).parent)
             checkpoint_path = os.path.join(path, 'saved_model', 'pfann_fma_m.pt')
         state_dict = torch.load(checkpoint_path, map_location=self.device)
-        self.model.load_state_dict(state_dict)
+        if isinstance(state_dict, torch.nn.Module):
+            self.model = state_dict
+        else:
+            self.model.load_state_dict(state_dict)
         self.model.eval()
         log.info('Model is loaded.')

     def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
         audio_tensors = self.preprocess(data).to(self.device)
+        # print(audio_tensors.shape)
         features = self.model(audio_tensors)
         return features.detach().cpu().numpy()
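Note: torch.load simply returns whatever object was pickled, so a checkpoint may be either a complete nn.Module (as produced by torch.save(model, path), which is what save_model(format='pytorch') below does) or a bare state dict. The new isinstance branch accepts both. A minimal sketch of the same branching, using a stand-in network:

    import torch

    net = torch.nn.Linear(4, 2)      # stand-in for the fingerprint network
    checkpoint = net.state_dict()    # could equally be a whole serialized nn.Module

    # Mirror of the branch above: accept either form and end up with a usable model.
    if isinstance(checkpoint, torch.nn.Module):
        model = checkpoint
    else:
        model = torch.nn.Linear(4, 2)
        model.load_state_dict(checkpoint)
    model.eval()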
@@ -90,7 +94,6 @@ class NNFingerprint(NNOperator):
         else:
             audio = numpy.hstack(frames)
             audio = audio[None, :]
-
         audio = self.int2float(audio)
         if sr != self.params['sample_rate']:
@@ -133,3 +136,48 @@ class NNFingerprint(NNOperator):
         else:
             log.warning('Converting float dtype from %s to %s.', wav.dtype, dtype)
         return wav.astype(dtype)
+
+    def save_model(self, format: str='pytorch', path: str = 'default'):
+        if path == 'default':
+            path = str(Path(__file__).parent)
+            path = os.path.join(path, 'saved', format)
+            os.makedirs(path, exist_ok=True)
+            name = 'nnfp'
+            path = os.path.join(path, name)
+        dummy_input = torch.rand(
+            (1,) + (self.params['n_mels'], self.params['u'])
+        )
+        if format == 'pytorch':
+            path = path + '.pt'
+            torch.save(self.model, path)
+        elif format == 'torchscript':
+            path = path + '.pt'
+            try:
+                try:
+                    jit_model = torch.jit.script(self.model)
+                except Exception:
+                    jit_model = torch.jit.trace(self.model, dummy_input, strict=False)
+                torch.jit.save(jit_model, path)
+            except Exception as e:
+                log.error(f'Fail to save as torchscript: {e}.')
+                raise RuntimeError(f'Fail to save as torchscript: {e}.')
+        elif format == 'onnx':
+            path = path + '.onnx'
+            try:
+                torch.onnx.export(self.model,
+                                  dummy_input,
+                                  path,
+                                  input_names=['input'],
+                                  output_names=['output'],
+                                  opset_version=12,
+                                  do_constant_folding=True,
+                                  dynamic_axes={'input': {0: 'batch_size'},
+                                                'output': {0: 'batch_size'}
+                                                }
+                                  )
+            except Exception as e:
+                log.error(f'Fail to save as onnx: {e}.')
+                raise RuntimeError(f'Fail to save as onnx: {e}.')
+        # todo: elif format == 'tensorrt':
+        else:
+            log.error(f'Unsupported format "{format}".')
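Note: a usage sketch of the new method, mirroring test.py below; with the default path argument the artifacts are written under saved/<format>/ next to the operator file:

    from towhee import ops

    op = ops.audio_embedding.nnfp()
    op.get_op().save_model(format='pytorch')      # -> saved/pytorch/nnfp.pt
    op.get_op().save_model(format='torchscript')  # -> saved/torchscript/nnfp.pt
    op.get_op().save_model(format='onnx')         # -> saved/onnx/nnfp.onnx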

test.py (35 lines added, new file)

@@ -0,0 +1,35 @@
+from towhee import ops
+import torch
+import numpy
+import onnx
+import onnxruntime
+
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+
+# decode = ops.audio_decode.ffmpeg()
+# audio = [x[0] for x in decode('path/to/audio.wav')]
+audio = torch.rand(10, 256, 32)
+
+op = ops.audio_embedding.nnfp()
+out0 = op.get_op().model(audio)
+# print(out0)
+
+op.get_op().save_model(format='pytorch')
+op = ops.audio_embedding.nnfp(checkpoint_path='./saved/pytorch/nnfp.pt')
+out1 = op.get_op().model(audio)
+assert((out0 == out1).all())
+
+op.get_op().save_model(format='onnx')
+op = ops.audio_embedding.nnfp()
+onnx_model = onnx.load('./saved/onnx/nnfp.onnx')
+onnx.checker.check_model(onnx_model)
+
+ort_session = onnxruntime.InferenceSession('./saved/onnx/nnfp.onnx')
+ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(audio)}
+ort_outs = ort_session.run(None, ort_inputs)
+out2 = ort_outs[0]
+# print(out2)
+assert(numpy.allclose(to_numpy(out0), out2, rtol=1e-03, atol=1e-05))
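Note: the test exercises the pytorch and onnx paths; the torchscript artifact could be sanity-checked the same way. A sketch (not part of this commit; it assumes save_model(format='torchscript') has been called and reuses audio, out0 and to_numpy from test.py above):

    # Load the scripted/traced model and compare against the original outputs.
    jit_model = torch.jit.load('./saved/torchscript/nnfp.pt')
    jit_model.eval()
    out3 = jit_model(audio)
    assert numpy.allclose(to_numpy(out0), to_numpy(out3), rtol=1e-03, atol=1e-05)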