logo
Browse Source

Remove accelerate

Signed-off-by: Jael Gu <mengjia.gu@zilliz.com>
main
Jael Gu 2 years ago
parent
commit
768f0068b8
  1. 4
      timm_image.py

4
timm_image.py

@ -23,7 +23,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
from towhee import register from towhee import register
from towhee.types import Image from towhee.types import Image
from towhee.dc2 import accelerate
# from towhee.dc2 import accelerate
import torch import torch
from torch import nn from torch import nn
@ -41,7 +41,7 @@ warnings.filterwarnings('ignore')
log = logging.getLogger('timm_op') log = logging.getLogger('timm_op')
@accelerate
# @accelerate
class Model: class Model:
def __init__(self, model_name, device, num_classes): def __init__(self, model_name, device, num_classes):
self.model = create_model(model_name, pretrained=True, num_classes=num_classes) self.model = create_model(model_name, pretrained=True, num_classes=num_classes)

Loading…
Cancel
Save