|
|
@ -21,7 +21,7 @@ from towhee.types.image_utils import to_pil |
|
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee.types.arg import arg, to_image_color |
|
|
|
from towhee import register |
|
|
|
from towhee.models import clip |
|
|
|
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor |
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
@ -32,15 +32,9 @@ class Clip(NNOperator): |
|
|
|
def __init__(self, model_name: str, modality: str): |
|
|
|
self.modality = modality |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True) |
|
|
|
self.tokenize = clip.tokenize |
|
|
|
self.tfms = transforms.Compose([ |
|
|
|
transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), |
|
|
|
transforms.CenterCrop(224), |
|
|
|
transforms.ToTensor(), |
|
|
|
transforms.Normalize( |
|
|
|
(0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) |
|
|
|
]) |
|
|
|
self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") |
|
|
|
self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") |
|
|
|
self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") |
|
|
|
|
|
|
|
def inference_single_data(self, data): |
|
|
|
if self.modality == 'image': |
|
|
@ -66,13 +60,13 @@ class Clip(NNOperator): |
|
|
|
return results |
|
|
|
|
|
|
|
def _inference_from_text(self, text): |
|
|
|
text = self.tokenize(text).to(self.device) |
|
|
|
text_features = self.model.encode_text(text) |
|
|
|
tokens = self.tokenizer([text], padding=True, return_tensors="pt") |
|
|
|
text_features = self.model.get_text_features(**tokens) |
|
|
|
return text_features |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def _inference_from_image(self, img): |
|
|
|
img = to_pil(img) |
|
|
|
image = self.tfms(img).unsqueeze(0).to(self.device) |
|
|
|
image_features = self.model.encode_image(image) |
|
|
|
inputs = processor(images=img, return_tensors="pt") |
|
|
|
image_features = self.model.get_image_features(**inputs) |
|
|
|
return image_features |
|
|
|