logo
Browse Source

adjust for batch data input.

Signed-off-by: wxywb <xy.wang@zilliz.com>
hf
wxywb 2 years ago
parent
commit
69ee0f99b5
  1. 16
      clip.py

16
clip.py

@ -42,7 +42,7 @@ class Clip(NNOperator):
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
]) ])
def __call__(self, data):
def inference_single_data(self, data):
if self.modality == 'image': if self.modality == 'image':
vec = self._inference_from_image(data) vec = self._inference_from_image(data)
elif self.modality == 'text': elif self.modality == 'text':
@ -51,6 +51,20 @@ class Clip(NNOperator):
raise ValueError("modality[{}] not implemented.".format(self._modality)) raise ValueError("modality[{}] not implemented.".format(self._modality))
return vec.detach().cpu().numpy().flatten() return vec.detach().cpu().numpy().flatten()
def __call__(self, data):
if not isinstance(data, list):
data = [data]
else:
data = data
results = []
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results
def _inference_from_text(self, text): def _inference_from_text(self, text):
text = self.tokenize(text).to(self.device) text = self.tokenize(text).to(self.device)
text_features = self.model.encode_text(text) text_features = self.model.encode_text(text)

Loading…
Cancel
Save