logo
Browse Source

add omnivore

Signed-off-by: gexy5 <xinyu.ge@zilliz.com>
main
gexy5 2 years ago
parent
commit
d214180cf4
  1. 11
      omnivore.py

11
omnivore.py

@ -41,12 +41,14 @@ class Omnivore(NNOperator):
def __init__(self, def __init__(self,
model_name: str = 'omnivore_swinT', model_name: str = 'omnivore_swinT',
framework: str = 'pytorch', framework: str = 'pytorch',
input_type: str = 'video',
skip_preprocess: bool = False, skip_preprocess: bool = False,
classmap: dict = None, classmap: dict = None,
topk: int = 5, topk: int = 5,
): ):
super().__init__(framework=framework) super().__init__(framework=framework)
self.model_name = model_name self.model_name = model_name
self.input_type = input_type
self.skip_preprocess = skip_preprocess self.skip_preprocess = skip_preprocess
self.topk = topk self.topk = topk
self.dataset_name = 'kinetics_400' self.dataset_name = 'kinetics_400'
@ -99,11 +101,14 @@ class Omnivore(NNOperator):
) )
inputs = data.to(self.device)[None, ...] inputs = data.to(self.device)[None, ...]
outs = self.model(inputs,input_type="video")
feats = self.model.forward_features(inputs ,input_type = self.input_type)
features = feats.to('cpu').squeeze(0).detach().numpy()
outs = self.model.head(feats)
post_act = torch.nn.Softmax(dim=1) post_act = torch.nn.Softmax(dim=1)
preds = post_act(outs) preds = post_act(outs)
pred_scores, pred_classes = preds.topk(k=self.topk) pred_scores, pred_classes = preds.topk(k=self.topk)
labels = [self.classmap[int(i)] for i in pred_classes[0]] labels = [self.classmap[int(i)] for i in pred_classes[0]]
scores = [round(float(x), 5) for x in pred_scores[0]] scores = [round(float(x), 5) for x in pred_scores[0]]
print(labels,scores)
return labels, scores
return labels, scores, features

Loading…
Cancel
Save