Browse Source
adjust for batch data input.
Signed-off-by: wxywb <xy.wang@zilliz.com>
hf
wxywb
3 years ago
1 changed files with
15 additions and
1 deletions
-
clip.py
|
|
@ -42,7 +42,7 @@ class Clip(NNOperator): |
|
|
|
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) |
|
|
|
]) |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
def inference_single_data(self, data): |
|
|
|
if self.modality == 'image': |
|
|
|
vec = self._inference_from_image(data) |
|
|
|
elif self.modality == 'text': |
|
|
@ -51,6 +51,20 @@ class Clip(NNOperator): |
|
|
|
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
|
|
|
return vec.detach().cpu().numpy().flatten() |
|
|
|
|
|
|
|
def __call__(self, data): |
|
|
|
if not isinstance(data, list): |
|
|
|
data = [data] |
|
|
|
else: |
|
|
|
data = data |
|
|
|
results = [] |
|
|
|
for single_data in data: |
|
|
|
result = self.inference_single_data(single_data) |
|
|
|
results.append(result) |
|
|
|
if len(data) == 1: |
|
|
|
return results[0] |
|
|
|
else: |
|
|
|
return results |
|
|
|
|
|
|
|
def _inference_from_text(self, text): |
|
|
|
text = self.tokenize(text).to(self.device) |
|
|
|
text_features = self.model.encode_text(text) |
|
|
|