logo
Browse Source

Support device

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
e8eb11b362
  1. 6
      panns.py

6
panns.py

@ -43,9 +43,13 @@ class Panns(NNOperator):
weights_path: str = None,
framework: str = 'pytorch',
sample_rate: int = 32000,
device: str = None,
topk: int = 5):
super().__init__(framework=framework)
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if device:
self.device = device
else:
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
self.sample_rate = sample_rate
self.topk = topk
# checkpoint_path=None will download model weights with default url

Loading…
Cancel
Save