diff --git a/clipcap.py b/clipcap.py
index f99c8f8..635b5c9 100644
--- a/clipcap.py
+++ b/clipcap.py
@@ -11,11 +11,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+import sys
 import os
-import torch
 from pathlib import Path
+
+import torch
 from torchvision import transforms
+from transformers import GPT2Tokenizer
+from towhee.types.arg import arg, to_image_color
 from towhee.types.image_utils import to_pil
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee import register
@@ -26,11 +31,16 @@ class ClipCap(NNOperator):
     """
     ClipCap image captioning operator
     """
     def __init__(self, model_name: str):
-        super().__init__():
+        super().__init__()
         sys.path.append(str(Path(__file__).parent))
-        from models.clipcap import ClipCaptionModel
+        from models.clipcap import ClipCaptionModel, generate_beam
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.generate_beam = generate_beam
+        self.tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
         config = self._configs()[model_name]
+        self.prefix_length = 10
+
         self.clip_tfms = self.tfms = transforms.Compose([
             transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),
             transforms.CenterCrop(224),
@@ -42,38 +52,38 @@ class ClipCap(NNOperator):
         clip_model_type = 'clip_vit_b32'
         self.clip_model = clip.create_model(model_name=clip_model_type, pretrained=True, jit=True)
 
-        self.model = ClipCaptionModel(prefix = 10)
+        self.model = ClipCaptionModel(self.prefix_length)
         model_path = os.path.dirname(__file__) + '/weights/' + config['weights']
-        self.model.load_state_dict(torch.load(model_path, map_location=CPU))
-        self.model = model.eval()
+        self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        self.model = self.model.eval()
 
     @arg(1, to_image_color('RGB'))
-    def __call__(self, data:):
+    def __call__(self, data):
         vec = self._inference_from_image(data)
         return vec
 
     def _preprocess(self, img):
         img = to_pil(img)
-        processed_img = self.self.clip_tfms(img).unsqueeze(0).to(self.device)
+        processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
         return processed_img
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
-        clip_feat = self.clip_model.encode_image(image)
+        clip_feat = self.clip_model.encode_image(img)
 
-        prefix_length = 10
-        prefix_embed = self.model.clip_project(clip_feat).reshape(1, prefix_length, -1)
+        self.prefix_length = 10
+        prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
 
-        generated_text_prefix = generate_beam(model, tokenizer, embed=prefix_embed)[0]
+        generated_text_prefix = self.generate_beam(self.model, self.tokenizer, embed=prefix_embed)[0]
         return generated_text_prefix
 
     def _configs(self):
         config = {}
         config['clipcap_coco'] = {}
-        config['clipcap_coco']['weights'] = 'weights/coco_weights.pt'
+        config['clipcap_coco']['weights'] = 'coco_weights.pt'
         config['clipcap_conceptual'] = {}
-        config['clipcap_conceptual']['weights'] = 'weights/conceptual_weights.pt'
+        config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
         return config
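
Reviewer note: below is a minimal smoke-test sketch for the patched operator, not part of the diff. It assumes a local checkout that contains weights/coco_weights.pt and the models/ package imported above, and it assumes towhee.types.image_utils exposes a from_pil helper alongside the to_pil that clipcap.py already imports; the checkout path and the sample image name are placeholders.

    # Untested smoke-test sketch for the patched ClipCap operator; paths are placeholders.
    import sys
    sys.path.append('/path/to/image-captioning-clipcap')  # placeholder: operator checkout dir

    from PIL import Image
    from towhee.types.image_utils import from_pil  # assumed helper next to to_pil
    from clipcap import ClipCap

    # 'clipcap_coco' loads weights/coco_weights.pt relative to clipcap.py (see _configs above).
    op = ClipCap(model_name='clipcap_coco')

    # Convert a PIL image into towhee's image type, then run captioning.
    img = from_pil(Image.open('example.jpg').convert('RGB'))  # placeholder image
    print(op(img))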