diff --git a/blip.py b/blip.py
index 83e3ef9..5eecd61 100644
--- a/blip.py
+++ b/blip.py
@@ -26,6 +26,7 @@ from transformers import logging as t_logging
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
+from towhee.dc2 import accelerate
 
 log = logging.getLogger('run_op')
 warnings.filterwarnings('ignore')
@@ -57,8 +58,8 @@ class BLIPModelVision(nn.Module):
         super().__init__()
         self.backbone = model
 
-    def forward(self, image):
-        image_embeds = self.backbone.vision_model(image)[0]
+    def forward(self, pixel_values):
+        image_embeds = self.backbone.vision_model(pixel_values)[0]
         image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
         return image_embeds
 
@@ -96,12 +97,13 @@ class Blip(NNOperator):
     """
     def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
         super().__init__()
+        self.model_name = model_name
         real_name = self._configs()[model_name]['name']
         self.model = Model(real_name, modality, checkpoint_path, device)
         self.modality = modality
         self.device = device
         self.checkpoint_path = checkpoint_path
-        self.processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
+        self.processor = AutoProcessor.from_pretrained(real_name)
 
     def __call__(self, data):
         if not isinstance(data, list):
@@ -125,9 +127,9 @@ class Blip(NNOperator):
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        inputs = self.processor(images=img, return_tensors='pt')['pixel_values']
+        inputs = self.processor(images=img, return_tensors='pt')
         inputs = inputs.to(self.device)
-        image_feature = self.model(inputs)
+        image_feature = self.model(inputs['pixel_values'])
         return image_feature
 
     def inference_single_data(self, data):
@@ -148,7 +150,7 @@ class Blip(NNOperator):
 
     @property
     def _model(self):
-        return self.model
+        return self.model.model
 
     def train(self, **kwargs):
         raise NotImplementedError
@@ -188,8 +190,9 @@ class Blip(NNOperator):
             output_file = output_file + '.onnx'
         else:
             raise AttributeError('Unsupported model_type.')
+
         if self.modality == 'image':
-            sz = self.processor.feature_extractor.crop_size
+            sz = self.processor.image_processor.size
             if isinstance(sz, int):
                 h = sz
                 w = sz
@@ -229,7 +232,7 @@ class Blip(NNOperator):
         else:
             raise ValueError('modality[{}] not implemented.'.format(self.modality))
 
-        onnx_export(self.model,
+        onnx_export(self._model,
                     (dict(inputs),),
                     f=Path(output_file),
                     input_names= input_names,
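
For reviewers, a minimal smoke-test sketch of the operator after this change, assuming Towhee's `pipe`/`ops` API and that `'blip_itm_base_coco'` is a key in `Blip._configs()`; the URL and caption below are illustrative. Since the processor is now loaded from the same resolved `real_name` as the model (rather than the hard-coded `'Salesforce/blip-itm-base-coco'`), preprocessing should stay consistent across BLIP variants:

```python
# A minimal sketch, not part of this diff. Assumes Towhee's pipe/ops API and
# that 'blip_itm_base_coco' maps to a real checkpoint in Blip._configs();
# the URL and caption are illustrative placeholders.
from towhee import pipe, ops

# Image branch: processor and model now both resolve from real_name, and the
# operator passes inputs['pixel_values'] to the vision tower explicitly.
img_pipe = (
    pipe.input('url')
        .map('url', 'img', ops.image_decode.cv2_rgb())
        .map('img', 'vec', ops.image_text_embedding.blip(
            model_name='blip_itm_base_coco', modality='image'))
        .output('vec')
)

# Text branch uses the same checkpoint, so the two vectors are comparable.
text_pipe = (
    pipe.input('text')
        .map('text', 'vec', ops.image_text_embedding.blip(
            model_name='blip_itm_base_coco', modality='text'))
        .output('vec')
)

img_vec = img_pipe('https://example.com/cat.jpg').get()[0]
text_vec = text_pipe('a photo of a cat').get()[0]
```

The same checkpoint consistency matters for `save_model`: ONNX export now traces `self._model` (the unwrapped `model.model`) with input sizes read from `processor.image_processor.size`, matching the `transformers` releases where `feature_extractor.crop_size` is no longer exposed.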