From 219e9a8c45825da2e8b6961e473c1e365a7f6362 Mon Sep 17 00:00:00 2001 From: gexy5 Date: Tue, 14 Jun 2022 15:25:56 +0800 Subject: [PATCH] add omnivore Signed-off-by: gexy5 --- omnivore.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/omnivore.py b/omnivore.py index 87b79ed..4759a75 100644 --- a/omnivore.py +++ b/omnivore.py @@ -101,10 +101,10 @@ class Omnivore(NNOperator): ) inputs = data.to(self.device)[None, ...] - feats = self.model.forward_features(inputs ,input_type = self.input_type) + feats = self.model.forward_features(inputs, input_type = self.input_type) features = feats.to('cpu').squeeze(0).detach().numpy() - outs = self.model.head(feats) + outs = self.model.head(feats, input_type = self.input_type) post_act = torch.nn.Softmax(dim=1) preds = post_act(outs) pred_scores, pred_classes = preds.topk(k=self.topk)