logo
Browse Source

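Update blip.py: regroup the torch/torchvision imports, drop the ipdb import and breakpoint, remove commented-out code, and add the missing `self` parameter to `_configs`.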

Signed-off-by: wxywb <xy.wang@zilliz.com>
v2
wxywb 3 years ago
parent
commit
d237e9fc15
  1. 16
      blip.py

16
blip.py

@ -14,14 +14,15 @@
import sys import sys
from pathlib import Path from pathlib import Path
import torch
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from towhee import register from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color from towhee.types.arg import arg, to_image_color
import torch
import ipdb
from towhee.types.image_utils import from_pil, to_pil from towhee.types.image_utils import from_pil, to_pil
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
@register(output_schema=['vec']) @register(output_schema=['vec'])
class Blip(NNOperator): class Blip(NNOperator):
@ -46,7 +47,7 @@ class Blip(NNOperator):
]) ])
def __call__(self, data): def __call__(self, data):
ipdb.set_trace()
print('call')
if self._modality == 'image': if self._modality == 'image':
vec = self._inference_from_image(data) vec = self._inference_from_image(data)
elif self._modality == 'text': elif self._modality == 'text':
@ -61,9 +62,6 @@ class Blip(NNOperator):
@arg(1, to_image_color('RGB')) @arg(1, to_image_color('RGB'))
def _inference_from_image(self, img): def _inference_from_image(self, img):
#img = to_pil(img)
#image = self.tfms(img).unsqueeze(0).to(self.device)
#image_features = self.model.encode_image(image)
img = self._preprocess(img) img = self._preprocess(img)
caption = '' caption = ''
image_feature = self.model(img, caption, mode='image', device=self.device)[0,0] image_feature = self.model(img, caption, mode='image', device=self.device)[0,0]
@ -74,7 +72,7 @@ class Blip(NNOperator):
processed_img = self.tfms(img).unsqueeze(0).to(self.device) processed_img = self.tfms(img).unsqueeze(0).to(self.device)
return processed_img return processed_img
def _configs():
def _configs(self):
config = {} config = {}
config['blip_base'] = {} config['blip_base'] = {}
config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth' config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth'

Loading…
Cancel
Save