logo
Browse Source

fix the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
c9f4bf582a
  1. 2
      __init__.py
  2. 7
      jclip.py

2
__init__.py

@ -14,5 +14,5 @@
from .jclip import Jaclip
def jclip(model_name: str, modality: str):
def japanese_clip(model_name: str, modality: str):
return Jaclip(model_name, modality)

7
jclip.py

@ -34,6 +34,7 @@ class Jaclip(NNOperator):
import japanese_clip as ja_clip
sys.path.pop()
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self._modality = modality
model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="{}/weights/japanese_clip".format(path), device=self.device)
self.model = model
self.tfms = preprocess
@ -51,20 +52,20 @@ class Jaclip(NNOperator):
return vec.detach().cpu().numpy().flatten()
def _inference_from_text(self, text):
encodings = ja_clip.tokenize(
encodings = self.ja_clip.tokenize(
texts=[text],
max_seq_len=77,
device=self.device,
tokenizer=self.tokenizer, # this is optional. if you don't pass, load tokenizer each time
)
text_feature = model.get_text_features(**encodings)
text_feature = self.model.get_text_features(**encodings)
return text_feature
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = self._preprocess(img)
caption = ''
image_feature = self.model.get_image_features(image)
image_feature = self.model.get_image_features(img)
return image_feature
def _preprocess(self, img):

Loading…
Cancel
Save