diff --git a/timm_image.py b/timm_image.py index c9ad998..53beba5 100644 --- a/timm_image.py +++ b/timm_image.py @@ -45,7 +45,6 @@ log = logging.getLogger('timm_op') class Model: def __init__(self, model_name, device, num_classes): self.model = create_model(model_name, pretrained=True, num_classes=num_classes) - self.config = resolve_data_config({}, model=self.model) self.model.eval() self.model.to(device) @@ -83,7 +82,6 @@ class TimmImage(NNOperator): device=self.device, num_classes=num_classes ) - self.config = self.model.config self.tfms = create_transform(**self.config) self.skip_tfms = skip_preprocess else: @@ -117,6 +115,10 @@ class TimmImage(NNOperator): def _model(self): return self.model.model + @property + def config(self): + return resolve_data_config({}, model=self.model) + @arg(1, to_image_color('RGB')) def convert_img(self, img: towhee._types.Image): img = PILImage.fromarray(img.astype('uint8'), 'RGB')