Browse Source
Support device
Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
1 changed files with
5 additions and
1 deletions
-
panns.py
|
|
@ -43,9 +43,13 @@ class Panns(NNOperator): |
|
|
|
weights_path: str = None, |
|
|
|
framework: str = 'pytorch', |
|
|
|
sample_rate: int = 32000, |
|
|
|
device: str = None, |
|
|
|
topk: int = 5): |
|
|
|
super().__init__(framework=framework) |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
if device: |
|
|
|
self.device = device |
|
|
|
else: |
|
|
|
self.device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
|
|
self.sample_rate = sample_rate |
|
|
|
self.topk = topk |
|
|
|
# checkpoint_path=None will download model weights with default url |
|
|
|