
Bugfix for the operator: fix the syntax errors in __init__ and __call__, replace the undefined names (CPU, model, image, tokenizer, generate_beam) with working references, add the missing imports, and drop the duplicated weights/ prefix from the checkpoint paths.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
wxywb · 2 years ago
Commit: d354584a2c
1 changed file: clipcap.py (38 lines changed)

clipcap.py
@@ -11,11 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import sys
 import os
+import torch
 from pathlib import Path
-import torch
 from torchvision import transforms
+from transformers import GPT2Tokenizer
+from towhee.types.arg import arg, to_image_color
 from towhee.types.image_utils import to_pil
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee import register
@@ -26,11 +31,16 @@ class ClipCap(NNOperator):
     ClipCap image captioning operator
     """
     def __init__(self, model_name: str):
-        super().__init__():
+        super().__init__()
         sys.path.append(str(Path(__file__).parent))
-        from models.clipcap import ClipCaptionModel
+        from models.clipcap import ClipCaptionModel, generate_beam
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.generate_beam = generate_beam
+        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         config = self._configs()[model_name]
+        self.prefix_length = 10
         self.clip_tfms = self.tfms = transforms.Compose([
             transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.CenterCrop(224),
@@ -42,38 +52,38 @@ class ClipCap(NNOperator):
         clip_model_type = 'clip_vit_b32'
         self.clip_model = clip.create_model(model_name=clip_model_type, pretrained=True, jit=True)
-        self.model = ClipCaptionModel(prefix = 10)
+        self.model = ClipCaptionModel(self.prefix_length)
         model_path = os.path.dirname(__file__) + '/weights/' + config['weights']
-        self.model.load_state_dict(torch.load(model_path, map_location=CPU))
-        self.model = model.eval()
+        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        self.model = self.model.eval()

     @arg(1, to_image_color('RGB'))
-    def __call__(self, data:):
+    def __call__(self, data):
         vec = self._inference_from_image(data)
         return vec

     def _preprocess(self, img):
         img = to_pil(img)
-        processed_img = self.self.clip_tfms(img).unsqueeze(0).to(self.device)
+        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
         return processed_img

     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
-        clip_feat = self.clip_model.encode_image(image)
-        prefix_length = 10
-        prefix_embed = self.model.clip_project(clip_feat).reshape(1, prefix_length, -1)
-        generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+        clip_feat = self.clip_model.encode_image(img)
+        self.prefix_length = 10
+        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
+        generated_text_prefix = self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
         return generated_text_prefix

     def _configs(self):
         config = {}
         config['clipcap_coco'] = {}
-        config['clipcap_coco']['weights'] = 'weights/coco_weights.pt'
+        config['clipcap_coco']['weights'] = 'coco_weights.pt'
         config['clipcap_conceptual'] = {}
-        config['clipcap_conceptual']['weights'] = 'weights/conceptual_weights.pt'
+        config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
         return config
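
For reference, a smoke test for the patched operator (not part of the commit): a minimal sketch assuming the towhee 0.x DataCollection API used by these operator repos, a placeholder ./image.jpg, and that towhee resolves image_captioning.clipcap to this repo and fetches the checkpoint on first use.

import towhee

# Decode an image and caption it with the fixed operator; with the bugs above
# (undefined CPU/model/image/tokenizer, bad weight paths), this call would
# previously have raised before producing a caption.
towhee.glob('./image.jpg') \
      .image_decode() \
      .image_captioning.clipcap(model_name='clipcap_coco') \
      .show()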
