Browse Source
adjust for batch data.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
1 changed files with
17 additions and
3 deletions
-
clipcap.py
|
|
@ -59,15 +59,29 @@ class ClipCap(NNOperator): |
|
|
|
|
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def __call__(self, data): |
|
|
|
vec = self._inference_from_image(data) |
|
|
|
return vec |
|
|
|
def inference_single_data(self, data): |
|
|
|
text = self._inference_from_image(data) |
|
|
|
return text |
|
|
|
|
|
|
|
def _preprocess(self, img): |
|
|
|
img = to_pil(img) |
|
|
|
processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device) |
|
|
|
return processed_img |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
if not isinstance(data, list): |
|
|
|
data = [data] |
|
|
|
else: |
|
|
|
data = data |
|
|
|
results = [] |
|
|
|
for single_data in data: |
|
|
|
result = self.inference_single_data(single_data) |
|
|
|
results.append(result) |
|
|
|
if len(data) == 1: |
|
|
|
return results[0] |
|
|
|
else: |
|
|
|
return results |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def _inference_from_image(self, img): |
|
|
|
img = self._preprocess(img) |
|
|
|