diff --git a/panns.py b/panns.py index 07e8877..b6af5c0 100644 --- a/panns.py +++ b/panns.py @@ -43,9 +43,13 @@ class Panns(NNOperator): weights_path: str = None, framework: str = 'pytorch', sample_rate: int = 32000, + device: str = None, topk: int = 5): super().__init__(framework=framework) - self.device = "cuda" if torch.cuda.is_available() else "cpu" + if device: + self.device = device + else: + self.device = 'cuda' if torch.cuda.is_available() else 'cpu' self.sample_rate = sample_rate self.topk = topk # checkpoint_path=None will download model weights with default url