From 414d2b2f6066bc11150d9d803e8f5508332e5564 Mon Sep 17 00:00:00 2001 From: wxywb Date: Fri, 30 Dec 2022 11:02:51 +0000 Subject: [PATCH] test for hf clip. Signed-off-by: wxywb --- clip.py | 22 ++++++++-------------- 1 file changed, 8 insertions(+), 14 deletions(-) diff --git a/clip.py b/clip.py index a68ca44..9741d96 100644 --- a/clip.py +++ b/clip.py @@ -21,7 +21,7 @@ from towhee.types.image_utils import to_pil from towhee.operator.base import NNOperator, OperatorFlag from towhee.types.arg import arg, to_image_color from towhee import register -from towhee.models import clip +from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor @register(output_schema=['vec']) @@ -32,15 +32,9 @@ class Clip(NNOperator): def __init__(self, model_name: str, modality: str): self.modality = modality self.device = "cuda" if torch.cuda.is_available() else "cpu" - self.model = clip.create_model(model_name=model_name, pretrained=True, jit=True) - self.tokenize = clip.tokenize - self.tfms = transforms.Compose([ - transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC), - transforms.CenterCrop(224), - transforms.ToTensor(), - transforms.Normalize( - (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)) - ]) + self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32") + self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32") + self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32") def inference_single_data(self, data): if self.modality == 'image': @@ -66,13 +60,13 @@ class Clip(NNOperator): return results def _inference_from_text(self, text): - text = self.tokenize(text).to(self.device) - text_features = self.model.encode_text(text) + tokens = self.tokenizer([text], padding=True, return_tensors="pt") + text_features = self.model.get_text_features(**tokens) return text_features @arg(1, to_image_color('RGB')) def _inference_from_image(self, img): img = to_pil(img) - image = self.tfms(img).unsqueeze(0).to(self.device) - image_features = self.model.encode_image(image) + inputs = processor(images=img, return_tensors="pt") + image_features = self.model.get_image_features(**inputs) return image_features