logo
Browse Source

Debug for tritonserve

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
c0bf53ea6e
  1. 6
      timm_image.py

6
timm_image.py

@ -45,7 +45,6 @@ log = logging.getLogger('timm_op')
class Model: class Model:
def __init__(self, model_name, device, num_classes): def __init__(self, model_name, device, num_classes):
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model = create_model(model_name, pretrained=True, num_classes=num_classes)
self.config = resolve_data_config({}, model=self.model)
self.model.eval() self.model.eval()
self.model.to(device) self.model.to(device)
@ -83,7 +82,6 @@ class TimmImage(NNOperator):
device=self.device, device=self.device,
num_classes=num_classes num_classes=num_classes
) )
self.config = self.model.config
self.tfms = create_transform(**self.config) self.tfms = create_transform(**self.config)
self.skip_tfms = skip_preprocess self.skip_tfms = skip_preprocess
else: else:
@ -117,6 +115,10 @@ class TimmImage(NNOperator):
def _model(self): def _model(self):
return self.model.model return self.model.model
@property
def config(self):
return resolve_data_config({}, model=self.model)
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def convert_img(self, img: towhee._types.Image): def convert_img(self, img: towhee._types.Image):
img = PILImage.fromarray(img.astype('uint8'), 'RGB') img = PILImage.fromarray(img.astype('uint8'), 'RGB')

Loading…
Cancel
Save