
modify the operator for dc2.

Signed-off-by: wxywb <xy.wang@zilliz.com>

Branch: v2
wxywb committed 2 years ago
parent commit 1809c72972
blip.py (50 lines changed)
@@ -24,18 +24,53 @@ from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee.types.image_utils import from_pil, to_pil
 
+#@accelerate
+class BLIPModelVision(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, image):
+        image_embeds = self.model.visual_encoder(image)
+        return image_embeds
+
+#@accelerate
+class BLIPModelText(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask):
+        text_output = self.model.text_encoder(input_ids, attention_mask=attention_mask,
+                                              return_dict=False, mode='text')
+        return text_output
+
 @register(output_schema=['vec'])
 class Blip(NNOperator):
     """
     BLIP multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str):
+    def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
         super().__init__()
         sys.path.append(str(Path(__file__).parent))
         from models.blip import blip_feature_extractor
-        image_size = 224
-        model_url = self._configs()[model_name]['weights']
-        self.model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
+        self.model_name = model_name
+        self.device = device
+        cfg = self._configs()[model_name]
+        model_url = cfg['weights']
+        image_size = cfg['image_size']
+        model = blip_feature_extractor(pretrained=model_url, image_size=image_size, vit='base')
+
+        self.tokenizer = model.tokenizer
+        if modality == 'image':
+            self.model = BLIPModelVision(model)
+        elif modality == 'text':
+            self.model = BLIPModelText(model)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(modality))
+
         self._modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -58,14 +93,14 @@ class Blip(NNOperator):
         return vec.detach().cpu().numpy().flatten()
 
     def _inference_from_text(self, text):
-        text_feature = self.model(None, text, mode='text', device=self.device)[0,0]
+        tokens = self.tokenizer(text, return_tensors="pt").to(self.device)
+        text_feature = self.model(tokens.input_ids, tokens.attention_mask)
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
-        caption = ''
-        image_feature = self.model(img, caption, mode='image', device=self.device)[0,0]
+        image_feature = self.model(img)
         return image_feature
 
     def _preprocess(self, img):
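_inference_from_text used to hand the raw string to the monolithic model with mode='text'; after this change the operator tokenizes the string itself and feeds input_ids plus attention_mask to the text wrapper. The snippet below shows what that tokenizer step produces. It is a sketch assuming a stock HuggingFace BERT tokenizer, which is what the BLIP reference implementation builds on; the checkpoint name here is only illustrative.

    from transformers import BertTokenizer

    # Assumption: model.tokenizer behaves like a standard BERT tokenizer.
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    tokens = tokenizer('a photo of a dog', return_tensors='pt')
    print(tokens.input_ids)       # shape (1, seq_len): token ids incl. [CLS]/[SEP]
    print(tokens.attention_mask)  # shape (1, seq_len): ones for every real token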
@@ -77,5 +112,6 @@ class Blip(NNOperator):
         config = {}
         config['blip_base'] = {}
         config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'
+        config['blip_base']['image_size'] = 224
         return config
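With image_size now read from the per-model config, registering another checkpoint only takes a weights URL and a resolution in _configs(). End to end, the reworked operator would be used roughly as below. This is a hedged sketch: the hub path (image_text_embedding/blip) and the pipe/ops API shown are assumptions whose exact syntax varies across towhee versions.

    from towhee import pipe, ops

    # Assumption: this operator is published on the towhee hub and exposed
    # as ops.image_text_embedding.blip.
    text_pipe = (
        pipe.input('text')
            .map('text', 'vec', ops.image_text_embedding.blip(model_name='blip_base', modality='text'))
            .output('vec')
    )
    print(text_pipe('a photo of a dog').get())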
