@@ -34,6 +34,7 @@ class Jaclip(NNOperator):
         import japanese_clip as ja_clip
         sys.path.pop()
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self._modality = modality
         model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="{}/weights/japanese_clip".format(path), device=self.device)
         self.model = model
         self.tfms = preprocess
@@ -51,20 +52,20 @@ class Jaclip(NNOperator):
         return vec.detach().cpu().numpy().flatten()
 
     def _inference_from_text(self, text):
-        encodings = ja_clip.tokenize(
+        encodings = self.ja_clip.tokenize(
             texts=[text],
             max_seq_len=77,
             device=self.device,
             tokenizer=self.tokenizer, # this is optional. if you don't pass, load tokenizer each time
         )
-        text_feature = model.get_text_features(**encodings)
+        text_feature = self.model.get_text_features(**encodings)
         return text_feature
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = self._preprocess(img)
         caption = ''
-        image_feature = self.model.get_image_features(image)
+        image_feature = self.model.get_image_features(img)
         return image_feature
 
     def _preprocess(self, img):
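
For readers unfamiliar with the rinna japanese_clip API this patch exercises, below is a minimal standalone sketch of the text-embedding path, assembled from the calls shown in the hunks above. It is an illustration under assumptions, not the operator's exact code: the tokenizer setup (ja_clip.load_tokenizer()) and any self.ja_clip assignment live in __init__ outside these hunks, and cache_dir is omitted here for brevity.

# Minimal sketch of the text path (assumes the rinna japanese-clip package is
# installed and exposes ja_clip.load, ja_clip.load_tokenizer and ja_clip.tokenize
# as used in the operator above).
import torch
import japanese_clip as ja_clip

device = "cuda" if torch.cuda.is_available() else "cpu"
# Same checkpoint the operator loads; returns the model and an image preprocess transform.
model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", device=device)
tokenizer = ja_clip.load_tokenizer()

encodings = ja_clip.tokenize(
    texts=["犬と猫"],
    max_seq_len=77,
    device=device,
    tokenizer=tokenizer,  # optional; if omitted, a tokenizer is loaded on each call
)
with torch.no_grad():
    text_feature = model.get_text_features(**encodings)  # a (1, embedding_dim) tensor

The switch to self.ja_clip and self.model follows from the first hunk: the import happens inside __init__ (between sys.path.append and sys.path.pop), so the bare ja_clip and model names are not in scope inside _inference_from_text; the module and model are presumably stored on the instance elsewhere in __init__.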
|
|