logo
Browse Source

add accelerate

main
ChengZi 2 years ago
parent
commit
63e5eaa824
  1. 9
      isc.py

9
isc.py

@ -24,7 +24,12 @@ from towhee.types.arg import arg, to_image_color
from towhee import register from towhee import register
from towhee.models import isc from towhee.models import isc
import sys import sys
# from towhee.dc2 import accelerate
try:
from towhee import accelerate
except:
def accelerate(func):
return func
import torch import torch
from torch import nn from torch import nn
@ -44,7 +49,7 @@ log = logging.getLogger('isc_op')
_ = sys.modules[__name__] _ = sys.modules[__name__]
# @accelerate
@accelerate
class Model: class Model:
def __init__(self, timm_backbone, checkpoint_path, device): def __init__(self, timm_backbone, checkpoint_path, device):
self.device = device self.device = device

Loading…
Cancel
Save