diff --git a/clip.py b/clip.py
index 1db657c..a68ca44 100644
--- a/clip.py
+++ b/clip.py
@@ -42,7 +42,7 @@ class Clip(NNOperator):
             (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
         ])
 
-    def __call__(self, data):
+    def inference_single_data(self, data):
         if self.modality == 'image':
             vec = self._inference_from_image(data)
         elif self.modality == 'text':
@@ -51,6 +51,14 @@ class Clip(NNOperator):
             raise ValueError("modality[{}] not implemented.".format(self._modality))
         return vec.detach().cpu().numpy().flatten()
 
+    def __call__(self, data):
+        """Embed one item or a batch: a non-list input is wrapped into a
+        single-element batch, each element is embedded via
+        inference_single_data, and a lone result is unwrapped back to a
+        single vector so scalar-in gives scalar-out."""
+        if not isinstance(data, list):
+            data = [data]
+        results = [self.inference_single_data(single_data) for single_data in data]
+        return results[0] if len(data) == 1 else results
+
     def _inference_from_text(self, text):
         text = self.tokenize(text).to(self.device)
         text_features = self.model.encode_text(text)