diff --git a/__init__.py b/__init__.py
index 9d2895f..2f252bf 100644
--- a/__init__.py
+++ b/__init__.py
@@ -14,5 +14,5 @@
 from .jclip import Jaclip
 
 
-def jclip(model_name: str, modality: str):
+def japanese_clip(model_name: str, modality: str):
     return Jaclip(model_name, modality)
diff --git a/jclip.py b/jclip.py
index aa3d4dc..c11e8bf 100644
--- a/jclip.py
+++ b/jclip.py
@@ -34,6 +34,7 @@ class Jaclip(NNOperator):
         import japanese_clip as ja_clip
         sys.path.pop()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self._modality = modality
         model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="{}/weights/japanese_clip".format(path), device=self.device)
         self.model = model
         self.tfms = preprocess
@@ -51,20 +52,20 @@ class Jaclip(NNOperator):
         return vec.detach().cpu().numpy().flatten()
 
     def _inference_from_text(self, text):
-        encodings = ja_clip.tokenize(
+        encodings = self.ja_clip.tokenize(
             texts=[text],
             max_seq_len=77,
             device=self.device,
             tokenizer=self.tokenizer, # this is optional. if you don't pass, load tokenizer each time
         )
-        text_feature = model.get_text_features(**encodings)
+        text_feature = self.model.get_text_features(**encodings)
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
         caption = ''
-        image_feature = self.model.get_image_features(image)
+        image_feature = self.model.get_image_features(img)
         return image_feature
 
     def _preprocess(self, img):
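
For reviewers, a short usage sketch of the renamed factory follows. Only the factory signature, the rinna/japanese-clip-vit-b-16 weights, and the flattened numpy output are taken from the code above; the package import name and the direct call on the operator are assumptions for illustration.

    # Hypothetical usage sketch: `japanese_clip_operator` is an assumed import name,
    # and invoking the operator directly on a string is an assumption, not part of this diff.
    from japanese_clip_operator import japanese_clip

    # The factory is now exposed under the new public name `japanese_clip` instead of `jclip`.
    op = japanese_clip('japanese-clip-vit-b-16', modality='text')
    vec = op('窓辺に座っている猫')  # per the operator code, a flattened numpy embedding vector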