Browse Source
enable accelerate
Signed-off-by: junjie.jiang <junjie.jiang@zilliz.com>
main
1 changed files with
6 additions and
2 deletions
-
timm_image.py
|
@ -23,7 +23,11 @@ from towhee.operator.base import NNOperator, OperatorFlag |
|
|
from towhee.types.arg import arg, to_image_color |
|
|
from towhee.types.arg import arg, to_image_color |
|
|
from towhee import register |
|
|
from towhee import register |
|
|
from towhee.types import Image |
|
|
from towhee.types import Image |
|
|
# from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
from towhee import accelerate |
|
|
|
|
|
except: |
|
|
|
|
|
def accelerate(func): |
|
|
|
|
|
return func |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch |
|
|
from torch import nn |
|
|
from torch import nn |
|
@ -48,7 +52,7 @@ def torch_no_grad(f): |
|
|
return wrap |
|
|
return wrap |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
|
|
|
|
|
@accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
def __init__(self, model_name, device, num_classes): |
|
|
self.device = device |
|
|
self.device = device |
|
|