From 823619c7af8f74356ff8ce334463711f65699156 Mon Sep 17 00:00:00 2001
From: wxywb
Date: Wed, 3 Aug 2022 17:54:04 +0800
Subject: [PATCH] update the operator.

Signed-off-by: wxywb
---
 README.md      |  3 +--
 __init__.py    |  4 ++--
 blip.py        | 12 ++++++------
 models/blip.py |  2 ++
 4 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index cfcbc6a..194851b 100644
--- a/README.md
+++ b/README.md
@@ -27,7 +27,6 @@ import towhee
 towhee.glob('./animals.jpg') \
       .image_decode() \
       .image_captioning.blip(model_name='blip_base') \
-      .select() \
       .show()
 ```
 result1
@@ -74,7 +73,7 @@ An image-text embedding operator takes a [towhee image](link/to/towhee/image/api
 
 **Parameters:**
 
-​ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)* or *str*
+​ ***img:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
 
 ​ The image to generate embedding.
 
diff --git a/__init__.py b/__init__.py
index 3a4024d..9f50597 100644
--- a/__init__.py
+++ b/__init__.py
@@ -15,5 +15,5 @@
 from .blip import Blip
 
 
-def blip(model_name: str, modality: str):
-    return Blip(model_name, modality)
+def blip(model_name: str):
+    return Blip(model_name)
diff --git a/blip.py b/blip.py
index b58066a..916f675 100644
--- a/blip.py
+++ b/blip.py
@@ -19,6 +19,7 @@ import torch
 from torchvision import transforms
 from torchvision.transforms.functional import InterpolationMode
 
+import towhee
 from towhee import register
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
@@ -37,7 +38,6 @@ class Blip(NNOperator):
         model_url = self._configs()[model_name]['weights']
         self.model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base')
 
-        self._modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
         self.model.eval()
@@ -49,17 +49,17 @@ class Blip(NNOperator):
         ])
 
     @arg(1, to_image_color('RGB'))
-    def __call__(self, data:):
-        vec = self._inference_from_image(data)
+    def __call__(self, img: towhee._types.Image):
+        vec = self._inference_from_image(img)
         return vec
 
     @arg(1, to_image_color('RGB'))
-    def _inference_from_image(self, img):
+    def _inference_from_image(self, img: towhee._types.Image):
         img = self._preprocess(img)
-        caption = model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
+        caption = self.model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
         return caption[0]
 
-    def _preprocess(self, img):
+    def _preprocess(self, img: towhee._types.Image):
         img = to_pil(img)
         processed_img = self.tfms(img).unsqueeze(0).to(self.device)
         return processed_img
diff --git a/models/blip.py b/models/blip.py
index 5d3619f..cb8e1b3 100644
--- a/models/blip.py
+++ b/models/blip.py
@@ -96,6 +96,8 @@ class BLIP_Decoder(nn.Module):
         self.visual_encoder, vision_width = create_vit(vit,image_size, vit_grad_ckpt, vit_ckpt_layer)
         self.tokenizer = init_tokenizer()
 
+        dirpath = str(Path(__file__).parent.parent)
+        med_config = dirpath + '/' + med_config
         med_config = BertConfig.from_json_file(med_config)
         med_config.encoder_width = vision_width
         self.text_decoder = BertLMHeadModel(config=med_config)
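
For quick verification, here is a minimal sketch of the call path after this change, mirroring the updated README example; it assumes a local `./animals.jpg` and that towhee resolves `image_captioning.blip` to this operator repo.

```python
# Minimal sketch of the post-patch pipeline (mirrors the updated README).
# Assumes './animals.jpg' exists locally and that towhee resolves
# image_captioning.blip to this operator; blip() now takes only model_name
# (the modality argument is removed), and the .select() stage is dropped.
import towhee

towhee.glob('./animals.jpg') \
      .image_decode() \
      .image_captioning.blip(model_name='blip_base') \
      .show()
```

Under the hood, `image_decode()` yields a `towhee._types.Image`, the `@arg(1, to_image_color('RGB'))` decorator converts it to RGB, and `Blip.__call__` returns the first beam-search caption from `self.model.generate(...)`.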