logo
Browse Source

enable accelerate

Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
junjie.jiang 2 years ago
parent
commit
505b0b220d
  1. 8
      timm_image.py

8
timm_image.py

@ -23,7 +23,11 @@ from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
from towhee import register from towhee import register
from towhee.types import Image from towhee.types import Image
# from towhee.dc2 import accelerate
try:
from towhee import accelerate
except:
def accelerate(func):
return func
import torch import torch
from torch import nn from torch import nn
@ -48,7 +52,7 @@ def torch_no_grad(f):
return wrap return wrap
# @accelerate
@accelerate
class Model: class Model:
def __init__(self, model_name, device, num_classes): def __init__(self, model_name, device, num_classes):
self.device = device self.device = device

Loading…
Cancel
Save