diff --git a/timm_image.py b/timm_image.py index 53beba5..8ddea6f 100644 --- a/timm_image.py +++ b/timm_image.py @@ -23,7 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register from towhee.types import Image -from towhee.dc2 import accelerate +# from towhee.dc2 import accelerate import torch from torch import nn @@ -41,7 +41,7 @@ warnings.filterwarnings('ignore') log = logging.getLogger('timm_op') -@accelerate +# @accelerate class Model: def __init__(self, model_name, device, num_classes): self.model = create_model(model_name, pretrained=True, num_classes=num_classes)