
add default value for device.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb committed 2 years ago · commit 3af58627d6
clip.py: 10 changed lines (6 additions, 4 deletions)
@@ -22,6 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
+#from towhee.dc2 import accelerate
 
 #@accelerate
 class CLIPModelVision(nn.Module):
@@ -49,10 +50,10 @@ class Clip(NNOperator):
     """
     CLIP multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str, device, checkpoint_path):
+    def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
         self.model_name = model_name
         self.modality = modality
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         cfg = self._configs()[model_name]
         try:
             clip_model = CLIPModel.from_pretrained(cfg)
@@ -71,6 +72,7 @@ class Clip(NNOperator):
             self.model = CLIPModelText(clip_model)
         else:
             raise ValueError("modality[{}] not implemented.".format(self.modality))
+        self.model.to(self.device)
         self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
         self.processor = CLIPProcessor.from_pretrained(cfg)
@@ -99,14 +101,14 @@ class Clip(NNOperator):
     def _inference_from_text(self, text):
         tokens = self.tokenizer([text], padding=True, return_tensors="pt")
-        text_features = self.model(tokens['input_ids'],tokens['attention_mask'])
+        text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
         return text_features
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = to_pil(img)
         inputs = self.processor(images=img, return_tensors="pt")
-        image_features = self.model(inputs['pixel_values'])
+        image_features = self.model(inputs['pixel_values'].to(self.device))
         return image_features
 
     def train(self, **kwargs):
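
Note: with this change the constructor defaults to device='cpu' rather than auto-detecting CUDA, so callers who want GPU inference must opt in explicitly. A minimal usage sketch under that assumption; the model_name key below is a placeholder, not a real key — use one of the keys defined in _configs():

import torch

# Default after this commit: the model and all input tensors stay on CPU.
op_cpu = Clip(model_name='clip_vit_base_patch16',  # placeholder key, see _configs()
              modality='image')

# To recover the removed auto-detection behavior, choose the device at the
# call site and pass it in; __init__ then moves the model via self.model.to(),
# and _inference_from_text/_inference_from_image move the inputs to match.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
op_gpu = Clip(model_name='clip_vit_base_patch16',  # placeholder key
              modality='text', device=device)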
