diff --git a/omnivore.py b/omnivore.py index 4759a75..ebebc72 100644 --- a/omnivore.py +++ b/omnivore.py @@ -101,7 +101,7 @@ class Omnivore(NNOperator): ) inputs = data.to(self.device)[None, ...] - feats = self.model.forward_features(inputs, input_type = self.input_type) + feats = self.model.forward_features(inputs) features = feats.to('cpu').squeeze(0).detach().numpy() outs = self.model.head(feats, input_type = self.input_type)