diff --git a/timm_image.py b/timm_image.py index a5690a3..515f5c2 100644 --- a/timm_image.py +++ b/timm_image.py @@ -23,7 +23,11 @@ from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register from towhee.types import Image -# from towhee.dc2 import accelerate +try: + from towhee import accelerate +except: + def accelerate(func): + return func import torch from torch import nn @@ -48,7 +52,7 @@ def torch_no_grad(f): return wrap -# @accelerate +@accelerate class Model: def __init__(self, model_name, device, num_classes): self.device = device