
commit 823619c7af (branch: main)
Author: wxywb, 2 years ago

    update the operator.

    Signed-off-by: wxywb <xy.wang@zilliz.com>
4 changed files:

1. README.md (3 changes)
2. __init__.py (4 changes)
3. blip.py (12 changes)
4. models/blip.py (2 changes)

README.md (3 changes)

````diff
@@ -27,7 +27,6 @@ import towhee
 towhee.glob('./animals.jpg') \
       .image_decode() \
       .image_captioning.blip(model_name='blip_base') \
-      .select() \
       .show()
 ```
 <img src="./cap.png" alt="result1" style="height:20px;"/>
````
```diff
@@ -74,7 +73,7 @@ An image-text embedding operator takes a [towhee image](link/to/towhee/image/api
 **Parameters:**

-***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*

 The image to generate embedding.
```

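For a quick check of the updated example, here is a minimal sketch of the documented pipeline, assuming a local `./animals.jpg` and the DataCollection-style API the README already uses:

```python
import towhee

# Decode a local image and caption it with the BLIP base model,
# mirroring the README pipeline after this change.
towhee.glob('./animals.jpg') \
      .image_decode() \
      .image_captioning.blip(model_name='blip_base') \
      .show()
```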
__init__.py (4 changes)

```diff
@@ -15,5 +15,5 @@
 from .blip import Blip

-def blip(model_name: str, modality: str):
-    return Blip(model_name, modality)
+def blip(model_name: str):
+    return Blip(model_name)
```

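The factory now takes a single `model_name` argument, so any caller still passing the old `modality` value will break. A hedged sketch of the call-site impact, assuming `Blip.__init__` was updated to match the factory:

```python
from blip import Blip  # module path assumed; matches `from .blip import Blip` above

# New-style construction: model name only, as in the updated factory.
op = Blip('blip_base')

# Old-style construction no longer works:
# Blip('blip_base', 'image')  # TypeError: unexpected extra argument
```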
blip.py (12 changes)

```diff
@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode

+import towhee
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -37,7 +38,6 @@ class Blip(NNOperator):
         model_url = self._configs()[model_name]['weights']
         self.model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
-        self._modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
         self.model.eval()
@@ -49,17 +49,17 @@
         ])

     @arg(1, to_image_color('RGB'))
-    def __call__(self, data:):
-        vec = self._inference_from_image(data)
+    def __call__(self, img: towhee._types.Image):
+        vec = self._inference_from_image(img)
         return vec

     @arg(1, to_image_color('RGB'))
-    def _inference_from_image(self, img):
+    def _inference_from_image(self, img: towhee._types.Image):
         img = self._preprocess(img)
-        caption = model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
+        caption = self.model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
         return caption[0]

-    def _preprocess(self, img):
+    def _preprocess(self, img: towhee._types.Image):
         img = to_pil(img)
         processed_img = self.tfms(img).unsqueeze(0).to(self.device)
         return processed_img
```

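With the `__call__` signature fixed (the old `def __call__(self, data:)` did not even parse) and the bare `model.generate` corrected to `self.model.generate`, the operator can be invoked directly. A minimal sketch, assuming `towhee._types.Image` can be built from an HWC `uint8` array plus a color-mode string; inside a pipeline, `image_decode()` produces this type for you:

```python
import numpy as np
import towhee
from blip import Blip  # operator module from this repo

op = Blip('blip_base')  # downloads weights from the configured model_url

# towhee._types.Image subclasses numpy.ndarray; the (array, mode) constructor
# is an assumption here. Passing 'BGR' exercises @arg(1, to_image_color('RGB')),
# which converts the input before inference.
arr = (np.random.rand(384, 384, 3) * 255).astype(np.uint8)
img = towhee._types.Image(arr, 'BGR')

caption = op(img)  # __call__ -> _inference_from_image -> self.model.generate
print(caption)
```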
models/blip.py (2 changes)

```diff
@@ -96,6 +96,8 @@ class BLIP_Decoder(nn.Module):
         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
         self.tokenizer = init_tokenizer()
+        dirpath = str(Path(__file__).parent.parent)
+        med_config = dirpath + '/' + med_config
         med_config = BertConfig.from_json_file(med_config)
         med_config.encoder_width = vision_width
         self.text_decoder = BertLMHeadModel(config=med_config)
```

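The two added lines anchor `med_config` at the package root (two levels above `models/blip.py`) instead of the process working directory, so the BERT config loads no matter where the operator is invoked from. The same idiom in isolation, with an assumed example filename:

```python
from pathlib import Path

med_config = 'configs/med_config.json'  # assumed relative filename

# Resolve relative to this source file, not os.getcwd(), exactly as the
# diff does inside BLIP_Decoder.__init__.
dirpath = str(Path(__file__).parent.parent)
med_config = dirpath + '/' + med_config

# A slightly more idiomatic spelling of the same join:
med_config_path = Path(__file__).resolve().parent.parent / 'configs/med_config.json'
```

The `/` operator (or `os.path.join`) keeps separators portable across platforms, though the plain string concatenation above matches what the diff adds.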