|
@ -45,7 +45,6 @@ log = logging.getLogger('timm_op') |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) |
|
|
self.config = resolve_data_config({}, model=self.model) |
|
|
|
|
|
self.model.eval() |
|
|
self.model.eval() |
|
|
self.model.to(device) |
|
|
self.model.to(device) |
|
|
|
|
|
|
|
@ -83,7 +82,6 @@ class TimmImage(NNOperator): |
|
|
device=self.device, |
|
|
device=self.device, |
|
|
num_classes=num_classes |
|
|
num_classes=num_classes |
|
|
) |
|
|
) |
|
|
self.config = self.model.config |
|
|
|
|
|
self.tfms = create_transform(**self.config) |
|
|
self.tfms = create_transform(**self.config) |
|
|
self.skip_tfms = skip_preprocess |
|
|
self.skip_tfms = skip_preprocess |
|
|
else: |
|
|
else: |
|
@ -117,6 +115,10 @@ class TimmImage(NNOperator): |
|
|
def _model(self): |
|
|
def _model(self): |
|
|
return self.model.model |
|
|
return self.model.model |
|
|
|
|
|
|
|
|
|
|
|
@property |
|
|
|
|
|
def config(self): |
|
|
|
|
|
return resolve_data_config({}, model=self.model) |
|
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
@arg(1, to_image_color('RGB')) |
|
|
def convert_img(self, img: towhee._types.Image): |
|
|
def convert_img(self, img: towhee._types.Image): |
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|
img = PILImage.fromarray(img.astype('uint8'), 'RGB') |
|
|