diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/taiyi.py b/taiyi.py
new file mode 100644
index 0000000..d8931ba
--- /dev/null
+++ b/taiyi.py
@@ -0,0 +1,74 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from transformers import BertForSequenceClassification, BertTokenizer
+from transformers import CLIPProcessor, CLIPModel
+
+from towhee import register
+from towhee.operator.base import NNOperator
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+
+
+@register(output_schema=['vec'])
+class Taiyi(NNOperator):
+    """
+    Taiyi multi-modal embedding operator.
+
+    Encodes an image or a piece of Chinese text into the shared
+    Taiyi-CLIP embedding space, depending on `modality`.
+    """
+    def __init__(self, model_name: str, modality: str):
+        super().__init__()
+        self.model_name = model_name
+        self.modality = modality
+        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        # Text tower: Chinese RoBERTa fine-tuned for Taiyi-CLIP.
+        self.text_tokenizer = BertTokenizer.from_pretrained("IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese")
+        self.text_encoder = BertForSequenceClassification.from_pretrained(
+            "IDEA-CCNL/Taiyi-CLIP-Roberta-large-326M-Chinese").eval().to(self.device)
+        # Image tower: the original OpenAI CLIP vision encoder.
+        self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14").eval().to(self.device)
+        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-large-patch14")
+
+    def inference_single_data(self, data):
+        if self.modality == 'image':
+            vec = self._inference_from_image(data)
+        elif self.modality == 'text':
+            vec = self._inference_from_text(data)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        return vec.detach().cpu().numpy().flatten()
+
+    def __call__(self, data):
+        if not isinstance(data, list):
+            data = [data]
+        results = []
+        for single_data in data:
+            results.append(self.inference_single_data(single_data))
+        if len(results) == 1:
+            return results[0]
+        return results
+
+    def _inference_from_text(self, text):
+        tokens = self.text_tokenizer(text, return_tensors='pt', padding=True)['input_ids'].to(self.device)
+        text_features = self.text_encoder(tokens).logits
+        return text_features
+
+    @arg(1, to_image_color('RGB'))
+    def _inference_from_image(self, img):
+        image = to_pil(img)
+        inputs = self.processor(images=image, return_tensors="pt").to(self.device)
+        image_features = self.clip_model.get_image_features(**inputs)
+        return image_features
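
Usage: a minimal sketch of calling the operator class above directly (outside a Towhee pipeline). The model_name value and the sample text are illustrative assumptions; the constructor stores model_name but always loads the Taiyi-CLIP checkpoints named in the diff.

    # Minimal usage sketch; assumes towhee and transformers are installed
    # and that this file is importable as `taiyi`.
    from taiyi import Taiyi

    # Text modality: returns a flattened numpy embedding from the Taiyi text tower.
    text_op = Taiyi(model_name='taiyi-clip-roberta-large-326m-chinese', modality='text')
    text_vec = text_op('一只可爱的猫')  # "a cute cat"
    print(text_vec.shape)

    # Image modality works the same way with modality='image'; the input is
    # expected to be a Towhee image, which @arg(1, to_image_color('RGB'))
    # converts to RGB before _inference_from_image runs.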