diff --git a/isc.py b/isc.py index a187b6f..0ce6ded 100644 --- a/isc.py +++ b/isc.py @@ -24,7 +24,12 @@ from towhee.types.arg import arg, to_image_color from towhee import register from towhee.models import isc import sys -# from towhee.dc2 import accelerate + +try: + from towhee import accelerate +except: + def accelerate(func): + return func import torch from torch import nn @@ -44,7 +49,7 @@ log = logging.getLogger('isc_op') _ = sys.modules[__name__] -# @accelerate +@accelerate class Model: def __init__(self, timm_backbone, checkpoint_path, device): self.device = device