logo
Browse Source

Fix for ONNX export.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
4e37839cc0
  1. 19
      blip.py

19
blip.py

@@ -26,6 +26,7 @@ from transformers import logging as t_logging
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee.dc2 import accelerate
log = logging.getLogger('run_op')
warnings.filterwarnings('ignore')
@@ -57,8 +58,8 @@ class BLIPModelVision(nn.Module):
super().__init__()
self.backbone = model
def forward(self, image):
image_embeds = self.backbone.vision_model(image)[0]
def forward(self, pixel_values):
image_embeds = self.backbone.vision_model(pixel_values)[0]
image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
return image_embeds
@@ -96,12 +97,13 @@ class Blip(NNOperator):
"""
def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
super().__init__()
self.model_name = model_name
real_name = self._configs()[model_name]['name']
self.model = Model(real_name, modality, checkpoint_path, device)
self.modality = modality
self.device = device
self.checkpoint_path = checkpoint_path
self.processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
self.processor = AutoProcessor.from_pretrained(real_name)
def __call__(self, data):
if not isinstance(data, list):
@@ -125,9 +127,9 @@ class Blip(NNOperator):
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
inputs = self.processor(images=img, return_tensors='pt')['pixel_values']
inputs = self.processor(images=img, return_tensors='pt')
inputs = inputs.to(self.device)
image_feature = self.model(inputs)
image_feature = self.model(inputs['pixel_values'])
return image_feature
def inference_single_data(self, data):
@@ -148,7 +150,7 @@ class Blip(NNOperator):
@property
def _model(self):
return self.model
return self.model.model
def train(self, **kwargs):
raise NotImplementedError
@@ -188,8 +190,9 @@ class Blip(NNOperator):
output_file = output_file + '.onnx'
else:
raise AttributeError('Unsupported model_type.')
if self.modality == 'image':
sz = self.processor.feature_extractor.crop_size
sz = self.processor.image_processor.size
if isinstance(sz, int):
h = sz
w = sz
@@ -229,7 +232,7 @@ class Blip(NNOperator):
else:
raise ValueError('modality[{}] not implemented.'.format(self.modality))
onnx_export(self.model,
onnx_export(self._model,
(dict(inputs),),
f=Path(output_file),
input_names= input_names,

Loading…
Cancel
Save