
Update the clip for serialization.

Signed-off-by: wxywb <xy.wang@zilliz.com>
Branch: main
Author: wxywb, 2 years ago
Commit: 9ecdb950f8
1. __init__.py (4 changes)
2. clip.py (155 changes)
3. requirements.txt (3 changes)

__init__.py (4 changes)

@@ -15,5 +15,5 @@
 from .clip import Clip
-def clip(model_name: str, modality: str):
-    return Clip(model_name, modality)
+def clip(model_name: str, modality: str, device: str = None, checkpoint_path: str = None):
+    return Clip(model_name, modality, device, checkpoint_path)
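For reference, a minimal sketch of the extended factory signature (the direct import of this repo's clip.py and the example inputs are assumptions for illustration; the operator is normally resolved through the Towhee hub):

    from clip import Clip  # clip.py from this repo, assumed to be on sys.path

    # device and checkpoint_path are the newly exposed parameters passed through by the clip() factory.
    text_op = Clip('clip_vit_base_16', 'text', device=None, checkpoint_path=None)
    vec = text_op.inference_single_data('a photo of a cat')  # flattened numpy embedding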

clip.py (155 changes)

@@ -11,10 +11,10 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import sys
 from pathlib import Path
 import torch
+from torch import nn
 from torchvision import transforms
 from towhee.types.image_utils import to_pil
@@ -22,29 +22,55 @@
 from towhee.operator.base import NNOperator, OperatorFlag
 from towhee.types.arg import arg, to_image_color
 from towhee import register
 from transformers import CLIPTokenizer, CLIPTextModel, CLIPModel, CLIPProcessor
-from train_clip_with_hf_trainer import train_with_hf_trainer
+
+
+#@accelerate
+class CLIPModelVision(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, pixel_values):
+        image_embeds = self.model.get_image_features(pixel_values)
+        return image_embeds
+
+
+#@accelerate
+class CLIPModelText(nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask):
+        text_embeds = self.model.get_text_features(input_ids, attention_mask)
+        return text_embeds
+
+
 @register(output_schema=['vec'])
 class Clip(NNOperator):
     """
     CLIP multi-modal embedding operator
     """
-    def __init__(self, model_name: str, modality: str):
+    def __init__(self, model_name: str, modality: str, device, checkpoint_path):
+        self.model_name = model_name
         self.modality = modality
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         cfg = self._configs()[model_name]
-        self.model = CLIPModel.from_pretrained(cfg['name'])
-        self.tokenizer = CLIPTokenizer.from_pretrained(cfg['name'])
-        self.processor = CLIPProcessor.from_pretrained(cfg['name'])
+        clip_model = CLIPModel.from_pretrained(cfg)
+        if self.modality == 'image':
+            self.model = CLIPModelVision(clip_model)
+        elif self.modality == 'text':
+            self.model = CLIPModelText(clip_model)
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+        self.tokenizer = CLIPTokenizer.from_pretrained(cfg)
+        self.processor = CLIPProcessor.from_pretrained(cfg)
 
     def inference_single_data(self, data):
         if self.modality == 'image':
             vec = self._inference_from_image(data)
         elif self.modality == 'text':
             vec = self._inference_from_text(data)
         else:
-            raise ValueError("modality[{}] not implemented.".format(self._modality))
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
         return vec.detach().cpu().numpy().flatten()
 
     def __call__(self, data):
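The CLIPModelVision / CLIPModelText wrappers reduce each modality to a plain nn.Module whose forward takes only tensors, which is the kind of interface torch.jit.trace and torch.onnx.export expect. A rough sketch of tracing the text wrapper on its own (model name and example text are illustrative; the wrapper is imported from this repo's clip.py):

    import torch
    from transformers import CLIPModel, CLIPTokenizer
    from clip import CLIPModelText  # wrapper class added in this commit

    model = CLIPModel.from_pretrained('openai/clip-vit-base-patch32')
    tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-base-patch32')
    wrapper = CLIPModelText(model)
    tokens = tokenizer(['a photo of a cat'], padding=True, return_tensors='pt')

    # Tensor-in / tensor-out, so tracing records a graph of get_text_features only.
    traced = torch.jit.trace(wrapper, (tokens['input_ids'], tokens['attention_mask']), strict=False)
    print(traced(tokens['input_ids'], tokens['attention_mask']).shape)  # e.g. torch.Size([1, 512])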
@@ -63,29 +89,122 @@ class Clip(NNOperator):
     def _inference_from_text(self, text):
         tokens = self.tokenizer([text], padding=True, return_tensors="pt")
-        text_features = self.model.get_text_features(**tokens)
+        text_features = self.model(tokens['input_ids'], tokens['attention_mask'])
         return text_features
 
     @arg(1, to_image_color('RGB'))
     def _inference_from_image(self, img):
         img = to_pil(img)
-        inputs = processor(images=img, return_tensors="pt")
-        image_features = self.model.get_image_features(**inputs)
+        inputs = self.processor(images=img, return_tensors="pt")
+        image_features = self.model(inputs['pixel_values'])
         return image_features
 
     def train(self, **kwargs):
+        import sys
+        import pathlib
+        path = str(pathlib.Path(__file__).parent)
+        print(path)
+        sys.path.append(path)
+        from train_clip_with_hf_trainer import train_with_hf_trainer
         data_args = kwargs.pop('data_args', None)
         training_args = kwargs.pop('training_args', None)
         train_with_hf_trainer(self.model, self.tokenizer, data_args, training_args)
 
     def _configs(self):
         config = {}
-        config['clip_vit_base_32'] = {}
-        config['clip_vit_base_32']['name'] = 'openai/clip-vit-base-patch16'
-        config['clip_vit_base_16'] = {}
-        config['clip_vit_base_16']['name'] = 'openai/clip-vit-base-patch32'
-        config['clip_vit_large_14'] = {}
+        config['clip_vit_base_32'] = 'openai/clip-vit-base-patch16'
+        config['clip_vit_base_16'] = 'openai/clip-vit-base-patch32'
         config['clip_vit_large_14'] = 'openai/clip-vit-large-patch14'
-        config['clip_vit_large_14_336'] = {}
-        config['clip_vit_large_14_336']['name'] = 'openai/clip-vit-large-patch14-336'
+        config['clip_vit_large_14_336'] = 'openai/clip-vit-large-patch14-336'
         return config
+
+    @property
+    def supported_formats(self):
+        onnxes = self.supported_model_names(format='onnx')
+        if self.model_name in onnxes:
+            return ['onnx']
+        else:
+            return ['pytorch']
+
+    @staticmethod
+    def supported_model_names(format: str = None):
+        if format == 'pytorch' or format == 'torchscript' or format == 'onnx':
+            model_list = [
+                'openai/clip-vit-base-patch16',
+                'openai/clip-vit-base-patch32',
+                'openai/clip-vit-large-patch14',
+                'openai/clip-vit-large-patch14-336'
+            ]
+        else:
+            log.error(f'Invalid format "{format}". Currently supported formats: "pytorch", "torchscript".')
+        return model_list
+
+    @property
+    def _model(self):
+        return self.model
+
+    def save_model(self, model_type: str = 'pytorch', output_file: str = 'default'):
+        import os
+        from PIL import Image
+        from torch.onnx import export as onnx_export
+
+        if output_file == 'default':
+            output_file = str(Path(__file__).parent)
+            output_file = os.path.join(output_file, 'saved', model_type)
+            os.makedirs(output_file, exist_ok=True)
+            name = self.model_name.replace('/', '-')
+            output_file = os.path.join(output_file, name)
+            if model_type in ['pytorch', 'torchscript']:
+                output_file = output_file + '.pt'
+            elif model_type == 'onnx':
+                output_file = output_file + '.onnx'
+            else:
+                raise AttributeError('Unsupported model_type.')
+
+        if self.modality == 'image':
+            sz = self.processor.feature_extractor.crop_size
+            dummy_input = Image.new('RGB', (sz, sz), color='red')
+            inputs = self.processor(images=dummy_input, return_tensors='pt')  # a dictionary
+        elif self.modality == 'text':
+            dummy_input = 'dummy'
+            inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt')  # a dictionary
+        else:
+            raise ValueError("modality[{}] not implemented.".format(self.modality))
+
+        if model_type == 'pytorch':
+            torch.save(self._model, output_file)
+        elif model_type == 'torchscript':
+            inputs = list(inputs.values())
+            try:
+                try:
+                    jit_model = torch.jit.script(self._model)
+                except Exception:
+                    jit_model = torch.jit.trace(self._model, inputs, strict=False)
+                torch.jit.save(jit_model, output_file)
+            except Exception as e:
+                log.error(f'Fail to save as torchscript: {e}.')
+                raise RuntimeError(f'Fail to save as torchscript: {e}.')
+        elif model_type == 'onnx':
+            if self.modality == 'image':
+                input_names = ['pixel_values']
+                output_names = ['image_embeds']
+                dynamic_axes = {'pixel_values': {0: 'batch'}, 'image_embeds': {0: 'batch'}}
+            elif self.modality == 'text':
+                input_names = ['input_ids', 'attention_mask']
+                output_names = ['text_embeds']
+                dynamic_axes = {'input_ids': {0: 'batch', 1: 'sequence'}, 'attention_mask': {0: 'batch', 1: 'sequence'}, 'text_embeds': {0: 'batch'}}
+            else:
+                raise ValueError("modality[{}] not implemented.".format(self.modality))
+            onnx_export(self.model,
+                        (dict(inputs),),
+                        f=Path(output_file),
+                        input_names=input_names,
+                        output_names=output_names,
+                        dynamic_axes=dynamic_axes,
+                        do_constant_folding=True,
+                        opset_version=14,
+                        )
+        else:
+            pass
+            raise NotImplementedError
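Under the 'default' layout built in save_model above, the exported file lands in a saved/<model_type>/ directory next to clip.py and is named after the operator's model key. A hedged sketch of exporting the text branch to ONNX and checking the result with onnxruntime (onnxruntime is not listed in requirements.txt and is assumed to be installed separately; paths are relative to the repo directory):

    import numpy as np
    import onnxruntime as ort
    from clip import Clip  # clip.py from this repo

    op = Clip('clip_vit_base_16', 'text', device=None, checkpoint_path=None)
    op.save_model(model_type='onnx')  # writes saved/onnx/clip_vit_base_16.onnx

    tokens = op.tokenizer(['a photo of a cat'], padding=True, return_tensors='np')
    sess = ort.InferenceSession('saved/onnx/clip_vit_base_16.onnx')
    (text_embeds,) = sess.run(['text_embeds'], {
        'input_ids': tokens['input_ids'].astype(np.int64),
        'attention_mask': tokens['attention_mask'].astype(np.int64),
    })
    print(text_embeds.shape)  # e.g. (1, 512)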

requirements.txt (3 changes)

@@ -1,4 +1,5 @@
 torchvision
 torch
 towhee
 towhee.models
+transformers
