logo
Browse Source

make accelerate evaliable.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
76713536b1
  1. 8
      clip.py

8
clip.py

@ -26,7 +26,11 @@ from towhee.types.arg import arg, to_image_color
from towhee import register
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
from transformers import logging as t_logging
# from towhee.dc2 import accelerate
try:
from towhee import accelerate
except:
def accelerate(func):
return func
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
@ -70,7 +74,7 @@ class CLIPModelText(nn.Module):
text_embeds = self.backbone.get_text_features(input_ids, attention_mask)
return text_embeds
# @accelerate
@accelerate
class Model:
def __init__(self, model_name, modality, checkpoint_path, device):
self.model = create_model(model_name, modality, checkpoint_path, device)

Loading…
Cancel
Save