logo
Browse Source

add omnivore

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
d214180cf4
  1. 11
      omnivore.py

11
omnivore.py

@ -41,12 +41,14 @@ class Omnivore(NNOperator):
def __init__(self,
model_name: str = 'omnivore_swinT',
framework: str = 'pytorch',
input_type: str = 'video',
skip_preprocess: bool = False,
classmap: dict = None,
topk: int = 5,
):
super().__init__(framework=framework)
self.model_name = model_name
self.input_type = input_type
self.skip_preprocess = skip_preprocess
self.topk = topk
self.dataset_name = 'kinetics_400'
@ -99,11 +101,14 @@ class Omnivore(NNOperator):
)
inputs = data.to(self.device)[None, ...]
outs = self.model(inputs,input_type="video")
feats = self.model.forward_features(inputs ,input_type = self.input_type)
features = feats.to('cpu').squeeze(0).detach().numpy()
outs = self.model.head(feats)
post_act = torch.nn.Softmax(dim=1)
preds = post_act(outs)
pred_scores, pred_classes = preds.topk(k=self.topk)
labels = [self.classmap[int(i)] for i in pred_classes[0]]
scores = [round(float(x), 5) for x in pred_scores[0]]
print(labels,scores)
return labels, scores
return labels, scores, features

Loading…
Cancel
Save