Browse Source
Update
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
7 additions and
2 deletions
-
nn_fingerprint.py
|
|
@ -46,9 +46,12 @@ class NNFingerprint(NNOperator): |
|
|
|
params: dict = None, |
|
|
|
model_path: str = None, |
|
|
|
framework: str = 'pytorch', |
|
|
|
device: str = None |
|
|
|
): |
|
|
|
super().__init__(framework=framework) |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
if device is None: |
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.device = device |
|
|
|
if params is None: |
|
|
|
self.params = default_params |
|
|
|
else: |
|
|
@ -87,7 +90,9 @@ class NNFingerprint(NNOperator): |
|
|
|
log.info('Model is loaded.') |
|
|
|
|
|
|
|
def __call__(self, data: List[AudioFrame]) -> numpy.ndarray: |
|
|
|
audio_tensors = self.preprocess(data).to(self.device) |
|
|
|
audio_tensors = self.preprocess(data) |
|
|
|
if audio_tensors.device != self.device: |
|
|
|
audio_tensors = audio_tensors.to(self.device) |
|
|
|
# print(audio_tensors.shape) |
|
|
|
if isinstance(self.model, onnxruntime.InferenceSession): |
|
|
|
audio_numpy = audio_tensors.detach().cpu().numpy() if audio_tensors.requires_grad \ |
|
|
|