
update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
wxywb committed 2 years ago · branch: main · commit 823619c7af
Changed files:
  README.md (3 changes)
  __init__.py (4 changes)
  blip.py (12 changes)
  models/blip.py (2 changes)

README.md (3 changes)

@@ -27,7 +27,6 @@ import towhee
 towhee.glob('./animals.jpg') \
     .image_decode() \
     .image_captioning.blip(model_name='blip_base') \
-    .select() \
     .show()
 ```
 <img src="./cap.png" alt="result1" style="height:20px;"/>
@@ -74,7 +73,7 @@ An image-text embedding operator takes a [towhee image](link/to/towhee/image/api
 **Parameters:**

-***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*

 The image to generate embedding.
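Reassembled from the context lines above, the README example after this commit reads as follows (`./animals.jpg` is the README's own placeholder path):

```python
import towhee

# Decode the image and run the BLIP captioning operator on it.
towhee.glob('./animals.jpg') \
      .image_decode() \
      .image_captioning.blip(model_name='blip_base') \
      .show()
```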

__init__.py (4 changes)

@@ -15,5 +15,5 @@
 from .blip import Blip


-def blip(model_name: str, modality: str):
-    return Blip(model_name, modality)
+def blip(model_name: str):
+    return Blip(model_name)
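With `modality` gone, the factory now forwards only the model name. A minimal usage sketch, assuming the package is importable under the name `blip` and that `decoded_img` is a towhee image produced upstream (both names are assumptions for illustration):

```python
from blip import blip  # assumption: this package importable as `blip`

op = blip(model_name='blip_base')  # no modality argument anymore
caption = op(decoded_img)          # decoded_img: towhee._types.Image
print(caption)                     # the generated caption string
```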

blip.py (12 changes)

@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

+import towhee
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -37,7 +38,6 @@ class Blip(NNOperator):
         model_url = self._configs()[model_name]['weights']
         self.model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
-        self._modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
         self.model.eval()

@@ -49,17 +49,17 @@
         ])

     @arg(1, to_image_color('RGB'))
-    def __call__(self, data:):
-        vec = self._inference_from_image(data)
+    def __call__(self, img: towhee._types.Image):
+        vec = self._inference_from_image(img)
         return vec

     @arg(1, to_image_color('RGB'))
-    def _inference_from_image(self, img):
+    def _inference_from_image(self, img: towhee._types.Image):
         img = self._preprocess(img)
-        caption = model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
+        caption = self.model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
         return caption[0]

-    def _preprocess(self, img):
+    def _preprocess(self, img: towhee._types.Image):
         img = to_pil(img)
         processed_img = self.tfms(img).unsqueeze(0).to(self.device)
         return processed_img
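Besides the type annotations, the substantive fix here replaces the bare `model.generate(...)` with `self.model.generate(...)`, which would otherwise raise a `NameError` at inference time. For context, `_preprocess` converts the towhee image to PIL, applies the `self.tfms` stack built in the constructor, adds a batch dimension, and moves the tensor to the model's device. A standalone sketch of an equivalent transform stack is below; the 384x384 input size and the normalization constants are assumptions based on common BLIP captioning defaults, since the constructor body is outside this diff:

```python
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode

# Assumed values: 384x384 inputs and CLIP-style normalization,
# as commonly used by BLIP captioning checkpoints.
image_size = 384
tfms = transforms.Compose([
    transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                         (0.26862954, 0.26130258, 0.27577711)),
])

def preprocess(pil_img, device='cpu'):
    # Mirrors self.tfms(img).unsqueeze(0).to(self.device):
    # transform, add a batch dimension, move to the target device.
    return tfms(pil_img).unsqueeze(0).to(device)
```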

models/blip.py (2 changes)

@@ -96,6 +96,8 @@ class BLIP_Decoder(nn.Module):

         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
         self.tokenizer = init_tokenizer()
+        dirpath = str(Path(__file__).parent.parent)
+        med_config = dirpath + '/' + med_config
         med_config = BertConfig.from_json_file(med_config)
         med_config.encoder_width = vision_width
         self.text_decoder = BertLMHeadModel(config=med_config)
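The two added lines resolve `med_config` relative to the package directory rather than the current working directory, so the operator still finds its JSON config when invoked from elsewhere. A behaviorally equivalent pathlib sketch (assuming `Path` is already imported in this module, as the added line implies):

```python
from pathlib import Path

def resolve_med_config(med_config: str) -> str:
    # Same effect as the two added lines: anchor the relative
    # config path to the directory above this file.
    return str(Path(__file__).parent.parent / med_config)
```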
