diff --git a/blip.py b/blip.py
index 21e0240..c339ed2 100644
--- a/blip.py
+++ b/blip.py
@@ -30,7 +30,11 @@ from transformers.models.blip.modeling_blip import BlipOutput, blip_loss
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
-#from towhee.dc2 import accelerate
+try:
+    from towhee import accelerate
+except ImportError:
+    def accelerate(func):
+        return func
 
 log = logging.getLogger('run_op')
 warnings.filterwarnings('ignore')
@@ -124,7 +128,7 @@ def create_model(cfg, modality, checkpoint_path, device):
         raise ValueError("modality[{}] not implemented.".format(modality))
     return blip
 
-#@accelerate
+@accelerate
 class BLIPModelVision(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -135,7 +139,7 @@ class BLIPModelVision(nn.Module):
         image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
         return image_embeds
 
-#@accelerate
+@accelerate
 class BLIPModelText(nn.Module):
     def __init__(self, model):
         super().__init__()
@@ -193,14 +197,12 @@ class Blip(NNOperator):
 
     def _inference_from_text(self, text):
         inputs = self.processor(text=text, padding=True, return_tensors='pt')
-        inputs = inputs.to(self.device)
         text_feature = self.model(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask)[0]
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         inputs = self.processor(images=img, return_tensors='pt')
-        inputs = inputs.to(self.device)
         image_feature = self.model(inputs['pixel_values'])
         return image_feature
 