Browse Source
fix the operator.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
2 changed files with
5 additions and
4 deletions
-
__init__.py
-
jclip.py
|
|
@ -14,5 +14,5 @@ |
|
|
|
|
|
|
|
from .jclip import Jaclip |
|
|
|
|
|
|
|
def jclip(model_name: str, modality: str): |
|
|
|
def japanese_clip(model_name: str, modality: str): |
|
|
|
return Jaclip(model_name, modality) |
|
|
|
|
|
@ -34,6 +34,7 @@ class Jaclip(NNOperator): |
|
|
|
import japanese_clip as ja_clip |
|
|
|
sys.path.pop() |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
self._modality = modality |
|
|
|
model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="{}/weights/japanese_clip".format(path), device=self.device) |
|
|
|
self.model = model |
|
|
|
self.tfms = preprocess |
|
|
@ -51,20 +52,20 @@ class Jaclip(NNOperator): |
|
|
|
return vec.detach().cpu().numpy().flatten() |
|
|
|
|
|
|
|
def _inference_from_text(self, text): |
|
|
|
encodings = ja_clip.tokenize( |
|
|
|
encodings = self.ja_clip.tokenize( |
|
|
|
texts=[text], |
|
|
|
max_seq_len=77, |
|
|
|
device=self.device, |
|
|
|
tokenizer=self.tokenizer, # this is optional. if you don't pass, load tokenizer each time |
|
|
|
) |
|
|
|
text_feature = model.get_text_features(**encodings) |
|
|
|
text_feature = self.model.get_text_features(**encodings) |
|
|
|
return text_feature |
|
|
|
|
|
|
|
@arg(1, to_image_color('RGB')) |
|
|
|
def _inference_from_image(self, img): |
|
|
|
img = self._preprocess(img) |
|
|
|
caption = '' |
|
|
|
image_feature = self.model.get_image_features(image) |
|
|
|
image_feature = self.model.get_image_features(img) |
|
|
|
return image_feature |
|
|
|
|
|
|
|
def _preprocess(self, img): |
|
|
|