logo
Browse Source

make acclerate avaliable.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
2d2ae1d414
  1. 12
      blip.py

12
blip.py

@ -30,7 +30,11 @@ from transformers.models.blip.modeling_blip import BlipOutput, blip_loss
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
#from towhee.dc2 import accelerate
try:
from towhee import accelerate
except:
def accelerate(func):
return func
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
@ -124,7 +128,7 @@ def create_model(cfg, modality, checkpoint_path, device):
raise ValueError("modality[{}] not implemented.".format(modality))
return blip
#@accelerate
@accelerate
class BLIPModelVision(nn.Module):
def __init__(self, model):
super().__init__()
@ -135,7 +139,7 @@ class BLIPModelVision(nn.Module):
image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
return image_embeds
#@accelerate
@accelerate
class BLIPModelText(nn.Module):
def __init__(self, model):
super().__init__()
@ -193,14 +197,12 @@ class Blip(NNOperator):
def _inference_from_text(self, text):
inputs = self.processor(text=text, padding=True, return_tensors='pt')
inputs = inputs.to(self.device)
text_feature = self.model(input_ids = inputs.input_ids, attention_mask = inputs.attention_mask)[0]
return text_feature
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
inputs = self.processor(images=img, return_tensors='pt')
inputs = inputs.to(self.device)
image_feature = self.model(inputs['pixel_values'])
return image_feature

Loading…
Cancel
Save