From 496ae1cd523ce2cf145bcae8049ac45759c7379b Mon Sep 17 00:00:00 2001 From: gexy5 Date: Mon, 13 Jun 2022 11:01:20 +0800 Subject: [PATCH] modify Signed-off-by: gexy5 --- tsm.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tsm.py b/tsm.py index 5fc9bd8..edab8f4 100644 --- a/tsm.py +++ b/tsm.py @@ -96,7 +96,13 @@ class Tsm(NNOperator): inputs = data.to(self.device)[None, ...] feats = self.model.forward_features(inputs) - features = feats.to('cpu').squeeze(0).detach().numpy() + if self.model.reshape: + if self.model.is_shift and self.model.temporal_pool: + base_out = feats.view((-1, self.model.num_segments // 2) + feats.size()[1:]) + else: + base_out = feats.view((-1, self.model.num_segments) + feats.size()[1:]) + output = self.model.consensus(base_out) + features = output.to('cpu').squeeze(0).detach().numpy() outs = self.model.head(feats) post_act = torch.nn.Softmax(dim=1)