|
@ -26,7 +26,11 @@ from towhee.types.arg import arg, to_image_color |
|
|
from towhee import register |
|
|
from towhee import register |
|
|
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor |
|
|
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor |
|
|
from transformers import logging as t_logging |
|
|
from transformers import logging as t_logging |
|
|
# from towhee.dc2 import accelerate |
|
|
|
|
|
|
|
|
try: |
|
|
|
|
|
from towhee import accelerate |
|
|
|
|
|
except: |
|
|
|
|
|
def accelerate(func): |
|
|
|
|
|
return func |
|
|
|
|
|
|
|
|
log = logging.getLogger('run_op') |
|
|
log = logging.getLogger('run_op') |
|
|
warnings.filterwarnings('ignore') |
|
|
warnings.filterwarnings('ignore') |
|
@ -70,7 +74,7 @@ class CLIPModelText(nn.Module): |
|
|
text_embeds = self.backbone.get_text_features(input_ids, attention_mask) |
|
|
text_embeds = self.backbone.get_text_features(input_ids, attention_mask) |
|
|
return text_embeds |
|
|
return text_embeds |
|
|
|
|
|
|
|
|
# @accelerate |
|
|
|
|
|
|
|
|
@accelerate |
|
|
class Model: |
|
|
class Model: |
|
|
def __init__(self, model_name, modality, checkpoint_path, device): |
|
|
def __init__(self, model_name, modality, checkpoint_path, device): |
|
|
self.model = create_model(model_name, modality, checkpoint_path, device) |
|
|
self.model = create_model(model_name, modality, checkpoint_path, device) |
|
|