@@ -55,8 +55,8 @@ class ClipCap(NNOperator):
         self.model = ClipCaptionModel(self.prefix_length)
         model_path = os.path.dirname(__file__) + '/weights/' + config['weights']
         self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+        self.model.to(self.device)
         self.model = self.model.eval()

     @arg(1, to_image_color('RGB'))
     def inference_single_data(self, data):
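
For reference, the first hunk lands on the usual PyTorch inference setup: load the checkpoint onto CPU so the call works on machines without a GPU, move the module to the operator's device, then switch it to eval mode. A minimal sketch of that pattern, assuming a plain nn.Module and a device string (the helper name load_for_inference is illustrative, not part of the patch):

import torch
from torch import nn

def load_for_inference(model: nn.Module, weights_path: str, device: str) -> nn.Module:
    # Load the weights onto CPU first so this succeeds even without CUDA,
    # then move the parameters to the target device, as the added line does.
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    model.load_state_dict(state)
    model.to(device)
    return model.eval()  # eval() freezes dropout / batch-norm for inference
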
@@ -85,7 +85,7 @@ class ClipCap(NNOperator):
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
-        clip_feat = self.clip_model.encode_image(img)
+        clip_feat = self.clip_model.encode_image(img).float()
         self.prefix_length = 10
         prefix_embed = self.model.clip_project(clip_feat).reshape(1, self.prefix_length, -1)
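
The .float() added in the second hunk presumably matters because OpenAI's CLIP loads its weights in half precision on CUDA, so encode_image returns a float16 tensor there, while the caption model's clip_project layer keeps ordinary float32 weights; without the cast the projection trips a dtype mismatch. A hedged sketch of that interaction, with an illustrative torch.nn.Linear standing in for clip_project and sizes chosen only for the example:

import torch
import clip  # OpenAI CLIP (https://github.com/openai/CLIP)
from PIL import Image

device = 'cuda' if torch.cuda.is_available() else 'cpu'
clip_model, preprocess = clip.load('ViT-B/32', device=device, jit=False)

prefix_length, gpt_dim = 10, 768  # illustrative sizes, not taken from the patch
clip_project = torch.nn.Linear(512, prefix_length * gpt_dim).to(device)  # stand-in projection, fp32 weights

with torch.no_grad():
    dummy = preprocess(Image.new('RGB', (224, 224))).unsqueeze(0).to(device)
    clip_feat = clip_model.encode_image(dummy)  # float16 on CUDA, float32 on CPU
    clip_feat = clip_feat.float()               # the cast the patch adds
    prefix_embed = clip_project(clip_feat).reshape(1, prefix_length, -1)
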