diff --git a/clip.py b/clip.py index 0ebd8ea..c64f453 100644 --- a/clip.py +++ b/clip.py @@ -26,7 +26,11 @@ from towhee.types.arg import arg, to_image_color from towhee import register from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor from transformers import logging as t_logging -# from towhee.dc2 import accelerate +try: + from towhee import accelerate +except: + def accelerate(func): + return func log = logging.getLogger('run_op') warnings.filterwarnings('ignore') @@ -70,7 +74,7 @@ class CLIPModelText(nn.Module): text_embeds = self.backbone.get_text_features(input_ids, attention_mask) return text_embeds -# @accelerate +@accelerate class Model: def __init__(self, model_name, modality, checkpoint_path, device): self.model = create_model(model_name, modality, checkpoint_path, device)