diff --git a/omnivore.py b/omnivore.py index 584c919..87b79ed 100644 --- a/omnivore.py +++ b/omnivore.py @@ -41,12 +41,14 @@ class Omnivore(NNOperator): def __init__(self, model_name: str = 'omnivore_swinT', framework: str = 'pytorch', + input_type: str = 'video', skip_preprocess: bool = False, classmap: dict = None, topk: int = 5, ): super().__init__(framework=framework) self.model_name = model_name + self.input_type = input_type self.skip_preprocess = skip_preprocess self.topk = topk self.dataset_name = 'kinetics_400' @@ -99,11 +101,14 @@ class Omnivore(NNOperator): ) inputs = data.to(self.device)[None, ...] - outs = self.model(inputs,input_type="video") + feats = self.model.forward_features(inputs ,input_type = self.input_type) + features = feats.to('cpu').squeeze(0).detach().numpy() + + outs = self.model.head(feats) post_act = torch.nn.Softmax(dim=1) preds = post_act(outs) pred_scores, pred_classes = preds.topk(k=self.topk) labels = [self.classmap[int(i)] for i in pred_classes[0]] scores = [round(float(x), 5) for x in pred_scores[0]] - print(labels,scores) - return labels, scores + + return labels, scores, features