logo
Browse Source

add default value for device.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
3af58627d6
  1. 10
      clip.py

10
clip.py

@ -22,6 +22,7 @@ from towhee.operator.base import NNOperator, OperatorFlag
from towhee.types.arg import arg, to_image_color
from towhee import register
from transformers import CLIPTokenizer, CLIPTextModel ,CLIPModel,CLIPProcessor
#from towhee.dc2 import accelerate
#@accelerate
class CLIPModelVision(nn.Module):
@ -49,10 +50,10 @@ class Clip(NNOperator):
"""
CLIP multi-modal embedding operator
"""
def __init__(self, model_name: str, modality: str, device, checkpoint_path):
def __init__(self, model_name: str, modality: str, device: str = 'cpu', checkpoint_path: str = None):
self.model_name = model_name
self.modality = modality
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
cfg = self._configs()[model_name]
try:
clip_model = CLIPModel.from_pretrained(cfg)
@ -71,6 +72,7 @@ class Clip(NNOperator):
self.model = CLIPModelText(clip_model)
else:
raise ValueError("modality[{}] not implemented.".format(self.modality))
self.model.to(self.device)
self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
self.processor = CLIPProcessor.from_pretrained(cfg)
@ -99,14 +101,14 @@ class Clip(NNOperator):
def _inference_from_text(self, text):
tokens = self.tokenizer([text], padding=True, return_tensors="pt")
text_features = self.model(tokens['input_ids'],tokens['attention_mask'])
text_features = self.model(tokens['input_ids'].to(self.device), tokens['attention_mask'].to(self.device))
return text_features
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = to_pil(img)
inputs = self.processor(images=img, return_tensors="pt")
image_features = self.model(inputs['pixel_values'])
image_features = self.model(inputs['pixel_values'].to(self.device))
return image_features
def train(self, **kwargs):

Loading…
Cancel
Save