diff --git a/omnivore.py b/omnivore.py index fac3b53..1d2ab94 100644 --- a/omnivore.py +++ b/omnivore.py @@ -101,7 +101,7 @@ class Omnivore(NNOperator): outs = self.model(inputs) post_act = torch.nn.Softmax(dim=1) - preds = post_act(outs) + preds = post_act(outs,input_type="video") pred_scores, pred_classes = preds.topk(k=self.topk) labels = [self.classmap[int(i)] for i in pred_classes[0]] scores = [round(float(x), 5) for x in pred_scores[0]]