@@ -30,7 +30,11 @@ from transformers.models.blip.modeling_blip import BlipOutput, blip_loss
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
#from towhee.dc2 import accelerate
try:
    from towhee import accelerate
except:
    def accelerate(func):
        return func

log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
@@ -124,7 +128,7 @@ def create_model(cfg, modality, checkpoint_path, device):
        raise ValueError("modality[{}] not implemented.".format(modality))
    return blip

#@accelerate
@accelerate
class BLIPModelVision(nn.Module):
    def __init__(self, model):
        super().__init__()
@@ -135,7 +139,7 @@ class BLIPModelVision(nn.Module):
        image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
        return image_embeds

#@accelerate
@accelerate
class BLIPModelText(nn.Module):
    def __init__(self, model):
        super().__init__()
@@ -193,14 +197,12 @@ class Blip(NNOperator):

    def _inference_from_text(self, text):
        inputs = self.processor(text=text, padding=True, return_tensors='pt')
        inputs = inputs.to(self.device)
        text_feature = self.model(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask)[0]
        return text_feature

    @arg(1, to_image_color('RGB'))
    def _inference_from_image(self, img):
        inputs = self.processor(images=img, return_tensors='pt')
        inputs = inputs.to(self.device)
        image_feature = self.model(inputs['pixel_values'])
        return image_feature
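
Note (not part of the diff): a minimal sketch of what the new try/except buys. If `from towhee import accelerate` fails, the locally defined no-op decorator returns its argument unchanged, so the `@accelerate`-decorated wrappers such as BLIPModelVision and BLIPModelText behave exactly like plain nn.Module classes. The stub class below is hypothetical and only illustrates the fallback's behavior.

# Sketch under the assumption that towhee's accelerate is unavailable,
# i.e. the except branch of the new import block is taken.
def accelerate(func):
    # no-op fallback from the diff: hand the decorated class back untouched
    return func

@accelerate
class VisionWrapperStub:                 # hypothetical stand-in for BLIPModelVision
    def __init__(self, model):
        self.backbone = model

stub = VisionWrapperStub(model=None)
assert type(stub) is VisionWrapperStub   # the decorator left the class unchanged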