|
|
@@ -19,6 +19,7 @@ import torch |
|
|
|
from torchvision import transforms |
|
|
|
from torchvision.transforms.functional import InterpolationMode |
|
|
|
|
|
|
|
import towhee |
|
|
|
from towhee import register |
|
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee.types.arg import arg, to_image_color |
|
|
@@ -37,7 +38,6 @@ class Blip(NNOperator): |
|
|
|
model_url = self._configs()[model_name]['weights'] |
|
|
|
self.model = blip_decoder(pretrained=model_url, image_size=image_size, vit='base') |
|
|
|
|
|
|
|
self._modality = modality |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
self.model.to(self.device) |
|
|
|
self.model.eval() |
|
|
@@ -49,17 +49,17 @@ class Blip(NNOperator): |
|
|
|
]) |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB'))
def __call__(self, img: towhee._types.Image):
    """Caption the given image.

    The span contained both the pre- and post-diff versions of this method;
    the old signature (``def __call__(self, data:)``) was syntactically
    invalid, so the annotated "new" side is kept.

    Args:
        img: input image; the ``@arg`` decorator converts it to RGB
            before the call.

    Returns:
        The result of ``_inference_from_image`` — a single caption
        (``caption[0]`` in that method, so presumably a string — TODO
        confirm against ``blip_decoder.generate``).
    """
    vec = self._inference_from_image(img)
    return vec
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img: towhee._types.Image):
    """Run BLIP caption generation on a single image.

    The image is preprocessed into a batched tensor on ``self.device``,
    then decoded deterministically with beam search (3 beams, caption
    length bounded to 5–20 tokens).

    Args:
        img: input image, converted to RGB by the ``@arg`` decorator.

    Returns:
        The first (best) generated caption.
    """
    img = self._preprocess(img)
    # The pre-diff line called the bare name `model` (a NameError at
    # runtime); the model lives on the instance, so `self.model` is correct.
    caption = self.model.generate(img, sample=False, num_beams=3, max_length=20, min_length=5)
    return caption[0]
|
|
|
|
|
|
|
def _preprocess(self, img: towhee._types.Image):
    """Convert a towhee image into a model-ready batched tensor.

    Args:
        img: input image (towhee image type).

    Returns:
        A tensor of shape ``(1, ...)`` — ``unsqueeze(0)`` adds the batch
        dimension — after applying ``self.tfms`` and moving the result
        to ``self.device``.
    """
    # NOTE(review): `to_pil` is not defined in the visible portion of this
    # file — presumably imported/defined elsewhere in the module; verify.
    img = to_pil(img)
    processed_img = self.tfms(img).unsqueeze(0).to(self.device)
    return processed_img
|
|
|