@@ -26,6 +26,7 @@ from transformers import logging as t_logging
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
+from towhee.dc2 import accelerate
 
 log = logging.getLogger('run_op')
 warnings.filterwarnings('ignore')
@@ -57,8 +58,8 @@ class BLIPModelVision(nn.Module):
         super().__init__()
         self.backbone = model
 
-    def forward(self, image):
-        image_embeds = self.backbone.vision_model(image)[0]
+    def forward(self, pixel_values):
+        image_embeds = self.backbone.vision_model(pixel_values)[0]
         image_embeds = self.backbone.vision_proj(image_embeds[:,0,:])
         return image_embeds
 
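Renaming the forward argument to `pixel_values` makes the wrapper keyword-compatible with the dict returned by the HuggingFace processor, so processor output can be unpacked straight into it. A minimal sketch of the calling pattern this enables (the checkpoint name and wiring here are illustrative, not part of the patch):

```python
import torch
from PIL import Image
from transformers import AutoProcessor, BlipForImageTextRetrieval

processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
blip = BlipForImageTextRetrieval.from_pretrained('Salesforce/blip-itm-base-coco')
vision = BLIPModelVision(blip)  # the wrapper patched above

inputs = processor(images=Image.new('RGB', (384, 384)), return_tensors='pt')
with torch.no_grad():
    feats = vision(**inputs)  # unpacks as pixel_values=..., matching the new signature
```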
@@ -96,12 +97,13 @@ class Blip(NNOperator):
     """
     def __init__(self, model_name: str, modality: str, device:str = 'cpu', checkpoint_path: str = None):
         super().__init__()
         self.model_name = model_name
-        self.model = Model(model_name, modality, checkpoint_path, device)
+        real_name = self._configs()[model_name]['name']
+        self.model = Model(real_name, modality, checkpoint_path, device)
         self.modality = modality
         self.device = device
         self.checkpoint_path = checkpoint_path
-        self.processor = AutoProcessor.from_pretrained('Salesforce/blip-itm-base-coco')
+        self.processor = AutoProcessor.from_pretrained(real_name)
 
     def __call__(self, data):
         if not isinstance(data, list):
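`real_name` resolves the operator's short model alias to the underlying checkpoint, and the processor is now loaded from that same resolved name instead of a hard-coded one, so model and processor can no longer drift apart. Roughly, `_configs()` is a mapping of this shape (the alias and entry below are illustrative, not the operator's actual table):

```python
# Illustrative shape of the alias table consulted via self._configs();
# the real keys and checkpoints live in the operator's config.
def _configs():
    return {
        'blip_itm_base': {'name': 'Salesforce/blip-itm-base-coco'},
        # ... one entry per supported alias
    }

real_name = _configs()['blip_itm_base']['name']  # -> 'Salesforce/blip-itm-base-coco'
```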
@@ -125,9 +127,9 @@ class Blip(NNOperator):
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
-        inputs = self.processor(images=img, return_tensors='pt')['pixel_values']
+        inputs = self.processor(images=img, return_tensors='pt')
         inputs = inputs.to(self.device)
-        image_feature = self.model(inputs)
+        image_feature = self.model(inputs['pixel_values'])
         return image_feature
 
     def inference_single_data(self, data):
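The image path now keeps the whole `BatchFeature` returned by the processor, so `.to(self.device)` moves every tensor it contains, and `pixel_values` is only indexed out at call time. The same pattern in isolation (a sketch, with `processor`, `img`, and `model` standing in for the operator's attributes):

```python
inputs = processor(images=img, return_tensors='pt')  # a transformers BatchFeature
inputs = inputs.to('cuda:0')   # BatchFeature.to() relocates all contained tensors
feature = model(inputs['pixel_values'])
```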
@@ -148,7 +150,7 @@ class Blip(NNOperator):
 
     @property
     def _model(self):
-        return self.model
+        return self.model.model
 
     def train(self, **kwargs):
         raise NotImplementedError
@@ -188,8 +190,9 @@ class Blip(NNOperator):
             output_file = output_file + '.onnx'
         else:
             raise AttributeError('Unsupported model_type.')
 
         if self.modality == 'image':
-            sz = self.processor.feature_extractor.crop_size
-            h = sz
-            w = sz
+            sz = self.processor.image_processor.size
+            if isinstance(sz, int):
+                h = sz
+                w = sz
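`image_processor.size` is not always an int: depending on the transformers version and the processor config it can also be a dict. The new `isinstance` check covers the int case; a fuller sketch of the branching one would typically need (the dict keys are an assumption, they vary by processor):

```python
sz = processor.image_processor.size
if isinstance(sz, int):
    h = w = sz
elif isinstance(sz, dict):
    # assumption: common key sets are {'height', 'width'} or {'shortest_edge'}
    h = sz.get('height', sz.get('shortest_edge'))
    w = sz.get('width', sz.get('shortest_edge'))
```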
@@ -229,7 +232,7 @@ class Blip(NNOperator):
         else:
             raise ValueError('modality[{}] not implemented.'.format(self.modality))
 
-        onnx_export(self.model,
+        onnx_export(self._model,
                     (dict(inputs),),
                     f=Path(output_file),
                     input_names= input_names,
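Exporting `self._model` rather than `self.model` hands the ONNX exporter the bare `nn.Module` (`self.model.model`, per the `_model` property above) instead of the accelerate-aware wrapper. A quick way to sanity-check the exported file, assuming `onnxruntime` is installed and the file name matches `output_file`:

```python
import onnxruntime as ort

# 'blip_image.onnx' is a placeholder for the actual output_file path
sess = ort.InferenceSession('blip_image.onnx', providers=['CPUExecutionProvider'])
print([(i.name, i.shape) for i in sess.get_inputs()])  # expect a pixel_values input
```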