logo
Browse Source

Update

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
2604d10a69
  1. 9
      nn_fingerprint.py

9
nn_fingerprint.py

@ -46,9 +46,12 @@ class NNFingerprint(NNOperator):
params: dict = None, params: dict = None,
model_path: str = None, model_path: str = None,
framework: str = 'pytorch', framework: str = 'pytorch',
device: str = None
): ):
super().__init__(framework=framework) super().__init__(framework=framework)
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
if device is None:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.device = device
if params is None: if params is None:
self.params = default_params self.params = default_params
else: else:
@ -87,7 +90,9 @@ class NNFingerprint(NNOperator):
log.info('Model is loaded.') log.info('Model is loaded.')
def __call__(self, data: List[AudioFrame]) -> numpy.ndarray: def __call__(self, data: List[AudioFrame]) -> numpy.ndarray:
audio_tensors = self.preprocess(data).to(self.device)
audio_tensors = self.preprocess(data)
if audio_tensors.device != self.device:
audio_tensors = audio_tensors.to(self.device)
# print(audio_tensors.shape) # print(audio_tensors.shape)
if isinstance(self.model, onnxruntime.InferenceSession): if isinstance(self.model, onnxruntime.InferenceSession):
audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \ audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \

Loading…
Cancel
Save