japanese-clip
copied
wxywb
2 years ago
19 changed files with 3788 additions and 0 deletions
@ -0,0 +1 @@ |
|||||
|
*.pyc |
@ -0,0 +1,18 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .jclip import Jaclip |
||||
|
|
||||
|
def jclip(model_name: str, modality: str): |
||||
|
return Jaclip(model_name, modality) |
@ -0,0 +1,19 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from .clip import CLIPModel, CLIPConfig |
||||
|
from .cloob import CLOOBModel, CLOOBConfig |
||||
|
from .auto_model import load, available_models |
||||
|
from .tokenizer import load_tokenizer, tokenize |
@ -0,0 +1,95 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from typing import Union |
||||
|
import json |
||||
|
import torch |
||||
|
from torchvision import transforms as T |
||||
|
from huggingface_hub import hf_hub_url, cached_download |
||||
|
import os |
||||
|
|
||||
|
from .clip import CLIPModel |
||||
|
from .cloob import CLOOBModel |
||||
|
|
||||
|
# TODO: Fill in repo_ids |
||||
|
MODELS = { |
||||
|
'rinna/japanese-clip-vit-b-16': { |
||||
|
'repo_id': 'rinna/japanese-clip-vit-b-16', |
||||
|
'model_class': CLIPModel, |
||||
|
}, |
||||
|
'rinna/japanese-cloob-vit-b-16': { |
||||
|
'repo_id': 'rinna/japanese-cloob-vit-b-16', |
||||
|
'model_class': CLOOBModel, |
||||
|
} |
||||
|
} |
||||
|
MODEL_CLASSES = { |
||||
|
"cloob": CLOOBModel, |
||||
|
"clip": CLIPModel, |
||||
|
} |
||||
|
MODEL_FILE = "pytorch_model.bin" |
||||
|
CONFIG_FILE = "config.json" |
||||
|
|
||||
|
|
||||
|
def available_models(): |
||||
|
return list(MODELS.keys()) |
||||
|
|
||||
|
|
||||
|
def _convert_to_rgb(image): |
||||
|
return image.convert('RGB') |
||||
|
|
||||
|
|
||||
|
def _transform(image_size): |
||||
|
return T.Compose([ |
||||
|
T.Resize(image_size, interpolation=T.InterpolationMode.BILINEAR), |
||||
|
T.CenterCrop(image_size), |
||||
|
_convert_to_rgb, |
||||
|
T.ToTensor(), |
||||
|
T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711),) |
||||
|
]) |
||||
|
|
||||
|
|
||||
|
def _download(repo_id: str, cache_dir: str): |
||||
|
config_file_url = hf_hub_url(repo_id=repo_id, filename=CONFIG_FILE) |
||||
|
cached_download(config_file_url, cache_dir=cache_dir) |
||||
|
model_file_url = hf_hub_url(repo_id=repo_id, filename=MODEL_FILE) |
||||
|
cached_download(model_file_url, cache_dir=cache_dir) |
||||
|
|
||||
|
|
||||
|
def load( |
||||
|
model_name: str, |
||||
|
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", |
||||
|
**kwargs |
||||
|
): |
||||
|
""" |
||||
|
Args: |
||||
|
model_name: model unique name or path to pre-downloaded model |
||||
|
device: device to put the loaded model |
||||
|
kwargs: kwargs for huggingface pretrained model class |
||||
|
Return: |
||||
|
(torch.nn.Module, A torchvision transform) |
||||
|
""" |
||||
|
if model_name in MODELS.keys(): |
||||
|
ModelClass = CLIPModel if 'clip' in model_name else CLOOBModel |
||||
|
elif os.path.exists(model_name): |
||||
|
assert os.path.exists(os.path.join(model_name, CONFIG_FILE)) |
||||
|
with open(os.path.join(model_name, CONFIG_FILE), "r", encoding="utf-8") as f: |
||||
|
j = json.load(f) |
||||
|
ModelClass = MODEL_CLASSES[j["model_type"]] |
||||
|
else: |
||||
|
RuntimeError(f"Model {model_name} not found; available models = {available_models()}") |
||||
|
|
||||
|
model = ModelClass.from_pretrained(model_name, **kwargs) |
||||
|
model = model.eval().requires_grad_(False).to(device) |
||||
|
return model, _transform(model.config.vision_config.image_size) |
@ -0,0 +1,16 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
from .modeling_clip import * |
||||
|
from .configuration_clip import * |
@ -0,0 +1,219 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
""" CLIP model configuration""" |
||||
|
import logging |
||||
|
import copy |
||||
|
import os |
||||
|
from typing import Union |
||||
|
|
||||
|
import numpy as np |
||||
|
from transformers import AutoConfig, PretrainedConfig |
||||
|
|
||||
|
|
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
|
|
||||
|
class CLIPTextConfig(PretrainedConfig): |
||||
|
model_type = "clip_text_model" |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
vocab_size=49408, |
||||
|
hidden_size=512, |
||||
|
intermediate_size=2048, |
||||
|
num_hidden_layers=12, |
||||
|
num_attention_heads=8, |
||||
|
max_position_embeddings=77, |
||||
|
hidden_act="quick_gelu", |
||||
|
layer_norm_eps=0.00001, |
||||
|
dropout=0.0, |
||||
|
attention_dropout=0.0, |
||||
|
initializer_range=0.02, |
||||
|
initializer_factor=1.0, |
||||
|
pad_token_id=1, |
||||
|
bos_token_id=0, |
||||
|
eos_token_id=2, |
||||
|
**kwargs |
||||
|
): |
||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) |
||||
|
|
||||
|
self.vocab_size = vocab_size |
||||
|
self.hidden_size = hidden_size |
||||
|
self.intermediate_size = intermediate_size |
||||
|
self.dropout = dropout |
||||
|
self.num_hidden_layers = num_hidden_layers |
||||
|
self.num_attention_heads = num_attention_heads |
||||
|
self.max_position_embeddings = max_position_embeddings |
||||
|
self.layer_norm_eps = layer_norm_eps |
||||
|
self.hidden_act = hidden_act |
||||
|
self.initializer_range = initializer_range |
||||
|
self.initializer_factor = initializer_factor |
||||
|
self.attention_dropout = attention_dropout |
||||
|
|
||||
|
@classmethod |
||||
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": |
||||
|
|
||||
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) |
||||
|
|
||||
|
# get the text config dict if we are loading from CLIPConfig |
||||
|
if config_dict.get("model_type") == "clip": |
||||
|
config_dict = config_dict["text_config"] |
||||
|
|
||||
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: |
||||
|
logger.warning( |
||||
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " |
||||
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." |
||||
|
) |
||||
|
|
||||
|
return cls.from_dict(config_dict, **kwargs) |
||||
|
|
||||
|
|
||||
|
class CLIPVisionConfig(PretrainedConfig): |
||||
|
model_type = "clip_vision_model" |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
hidden_size=768, |
||||
|
intermediate_size=3072, |
||||
|
num_hidden_layers=12, |
||||
|
num_attention_heads=12, |
||||
|
image_size=224, |
||||
|
patch_size=32, |
||||
|
hidden_act="quick_gelu", |
||||
|
layer_norm_eps=0.00001, |
||||
|
dropout=0.0, |
||||
|
attention_dropout=0.0, |
||||
|
initializer_range=0.02, |
||||
|
initializer_factor=1.0, |
||||
|
**kwargs |
||||
|
): |
||||
|
super().__init__(**kwargs) |
||||
|
|
||||
|
self.hidden_size = hidden_size |
||||
|
self.intermediate_size = intermediate_size |
||||
|
self.dropout = dropout |
||||
|
self.num_hidden_layers = num_hidden_layers |
||||
|
self.num_attention_heads = num_attention_heads |
||||
|
self.patch_size = patch_size |
||||
|
self.image_size = image_size |
||||
|
self.initializer_range = initializer_range |
||||
|
self.initializer_factor = initializer_factor |
||||
|
self.attention_dropout = attention_dropout |
||||
|
self.layer_norm_eps = layer_norm_eps |
||||
|
self.hidden_act = hidden_act |
||||
|
|
||||
|
@classmethod |
||||
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": |
||||
|
|
||||
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) |
||||
|
|
||||
|
# get the vision config dict if we are loading from CLIPConfig |
||||
|
if config_dict.get("model_type") == "clip": |
||||
|
config_dict = config_dict["vision_config"] |
||||
|
|
||||
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: |
||||
|
logger.warning( |
||||
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " |
||||
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." |
||||
|
) |
||||
|
|
||||
|
return cls.from_dict(config_dict, **kwargs) |
||||
|
|
||||
|
|
||||
|
class CLIPConfig(PretrainedConfig): |
||||
|
r""" |
||||
|
[`CLIPConfig`] is the configuration class to store the configuration of a [`CLIPModel`]. It is used to instantiate |
||||
|
CLIP model according to the specified arguments, defining the text model and vision model configs. |
||||
|
|
||||
|
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the |
||||
|
documentation from [`PretrainedConfig`] for more information. |
||||
|
|
||||
|
Args: |
||||
|
text_config_dict (`dict`, *optional*): |
||||
|
Dictionary of configuration options used to initialize [`CLIPTextConfig`]. |
||||
|
vision_config_dict (`dict`, *optional*): |
||||
|
Dictionary of configuration options used to initialize [`CLIPVisionConfig`]. |
||||
|
projection_dim (`int`, *optional*, defaults to 512): |
||||
|
Dimentionality of text and vision projection layers. |
||||
|
logit_scale_init_value (`float`, *optional*, defaults to 2.6592): |
||||
|
The inital value of the *logit_scale* paramter. Default is used as per the original CLIP implementation. |
||||
|
kwargs (*optional*): |
||||
|
Dictionary of keyword arguments. |
||||
|
""" |
||||
|
|
||||
|
model_type = "clip" |
||||
|
is_composition = True |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
text_config=None, |
||||
|
vision_config=None, |
||||
|
projection_dim=512, |
||||
|
logit_scale_init_value=None, |
||||
|
**kwargs |
||||
|
): |
||||
|
super().__init__(text_config=text_config, vision_config=vision_config, **kwargs) |
||||
|
|
||||
|
if vision_config is None: |
||||
|
raise ValueError("`vision_config` can not be `None`.") |
||||
|
|
||||
|
if text_config is None: |
||||
|
raise ValueError("`text_config` can not be `None`.") |
||||
|
|
||||
|
vision_model_type = vision_config.pop("model_type") |
||||
|
text_model_type = text_config.pop("model_type") |
||||
|
|
||||
|
if vision_model_type == "clip_vision_model": |
||||
|
self.vision_config = CLIPVisionConfig(**vision_config) |
||||
|
else: |
||||
|
self.vision_config = AutoConfig.for_model( |
||||
|
vision_model_type, **vision_config |
||||
|
) |
||||
|
|
||||
|
if text_model_type == "clip_text_model": |
||||
|
self.text_config = CLIPTextConfig(**text_config) |
||||
|
else: |
||||
|
self.text_config = AutoConfig.for_model( |
||||
|
text_model_type, **text_config |
||||
|
) |
||||
|
|
||||
|
self.projection_dim = projection_dim |
||||
|
self.logit_scale_init_value = logit_scale_init_value if logit_scale_init_value is not None else np.log(1 / 0.07) |
||||
|
self.initializer_factor = 1.0 |
||||
|
|
||||
|
@classmethod |
||||
|
def from_text_vision_configs(cls, text_config: CLIPTextConfig, vision_config: CLIPVisionConfig, **kwargs): |
||||
|
r""" |
||||
|
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model |
||||
|
configuration. |
||||
|
|
||||
|
Returns: |
||||
|
[`CLIPConfig`]: An instance of a configuration object |
||||
|
""" |
||||
|
|
||||
|
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs) |
||||
|
|
||||
|
def to_dict(self): |
||||
|
""" |
||||
|
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. |
||||
|
|
||||
|
Returns: |
||||
|
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, |
||||
|
""" |
||||
|
output = copy.deepcopy(self.__dict__) |
||||
|
output["text_config"] = self.text_config.to_dict() |
||||
|
output["vision_config"] = self.vision_config.to_dict() |
||||
|
output["model_type"] = self.__class__.model_type |
||||
|
return output |
@ -0,0 +1,815 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
import logging |
||||
|
from dataclasses import dataclass |
||||
|
from typing import Any, Optional, Tuple, Union |
||||
|
|
||||
|
import torch |
||||
|
import torch.utils.checkpoint |
||||
|
from torch import nn |
||||
|
|
||||
|
from transformers import AutoModel |
||||
|
from transformers.activations import ACT2FN |
||||
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling |
||||
|
from transformers.modeling_utils import PreTrainedModel, ModelOutput |
||||
|
from .configuration_clip import CLIPConfig, CLIPTextConfig, CLIPVisionConfig |
||||
|
|
||||
|
|
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
|
|
||||
|
# Copied from transformers.models.bart.modeling_bart._expand_mask |
||||
|
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None): |
||||
|
""" |
||||
|
Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`. |
||||
|
""" |
||||
|
bsz, src_len = mask.size() |
||||
|
tgt_len = tgt_len if tgt_len is not None else src_len |
||||
|
|
||||
|
expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype) |
||||
|
|
||||
|
inverted_mask = 1.0 - expanded_mask |
||||
|
|
||||
|
return inverted_mask.masked_fill(inverted_mask.bool(), torch.finfo(dtype).min) |
||||
|
|
||||
|
|
||||
|
# contrastive loss function, adapted from |
||||
|
# https://sachinruk.github.io/blog/pytorch/pytorch%20lightning/loss%20function/gpu/2021/03/07/CLIP.html |
||||
|
def contrastive_loss(logits: torch.Tensor) -> torch.Tensor: |
||||
|
return nn.functional.cross_entropy(logits, torch.arange(len(logits), device=logits.device)) |
||||
|
|
||||
|
|
||||
|
def clip_loss(similarity: torch.Tensor) -> torch.Tensor: |
||||
|
caption_loss = contrastive_loss(similarity) |
||||
|
image_loss = contrastive_loss(similarity.T) |
||||
|
return (caption_loss + image_loss) / 2.0 |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class CLIPOutput(ModelOutput): |
||||
|
loss: Optional[torch.FloatTensor] = None |
||||
|
logits_per_image: torch.FloatTensor = None |
||||
|
logits_per_text: torch.FloatTensor = None |
||||
|
text_embeds: torch.FloatTensor = None |
||||
|
image_embeds: torch.FloatTensor = None |
||||
|
text_model_output: BaseModelOutputWithPooling = None |
||||
|
vision_model_output: BaseModelOutputWithPooling = None |
||||
|
|
||||
|
def to_tuple(self) -> Tuple[Any]: |
||||
|
return tuple( |
||||
|
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() |
||||
|
for k in self.keys() |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLIPVisionEmbeddings(nn.Module): |
||||
|
def __init__(self, config: CLIPVisionConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.embed_dim = config.hidden_size |
||||
|
self.image_size = config.image_size |
||||
|
self.patch_size = config.patch_size |
||||
|
|
||||
|
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) |
||||
|
|
||||
|
self.patch_embedding = nn.Conv2d( |
||||
|
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False |
||||
|
) |
||||
|
|
||||
|
self.num_patches = (self.image_size // self.patch_size) ** 2 |
||||
|
self.num_positions = self.num_patches + 1 |
||||
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) |
||||
|
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) |
||||
|
|
||||
|
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: |
||||
|
batch_size = pixel_values.shape[0] |
||||
|
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] |
||||
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) |
||||
|
|
||||
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1) |
||||
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) |
||||
|
embeddings = embeddings + self.position_embedding(self.position_ids) |
||||
|
return embeddings |
||||
|
|
||||
|
|
||||
|
class CLIPTextEmbeddings(nn.Module): |
||||
|
def __init__(self, config: CLIPTextConfig): |
||||
|
super().__init__() |
||||
|
embed_dim = config.hidden_size |
||||
|
|
||||
|
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) |
||||
|
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) |
||||
|
|
||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized |
||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.LongTensor] = None, |
||||
|
position_ids: Optional[torch.LongTensor] = None, |
||||
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
||||
|
) -> torch.Tensor: |
||||
|
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
||||
|
|
||||
|
if position_ids is None: |
||||
|
position_ids = self.position_ids[:, :seq_length] |
||||
|
|
||||
|
if inputs_embeds is None: |
||||
|
inputs_embeds = self.token_embedding(input_ids) |
||||
|
|
||||
|
position_embeddings = self.position_embedding(position_ids) |
||||
|
embeddings = inputs_embeds + position_embeddings |
||||
|
|
||||
|
return embeddings |
||||
|
|
||||
|
|
||||
|
class CLIPAttention(nn.Module): |
||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
||||
|
|
||||
|
def __init__(self, config): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.embed_dim = config.hidden_size |
||||
|
self.num_heads = config.num_attention_heads |
||||
|
self.head_dim = self.embed_dim // self.num_heads |
||||
|
if self.head_dim * self.num_heads != self.embed_dim: |
||||
|
raise ValueError( |
||||
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" |
||||
|
f" {self.num_heads})." |
||||
|
) |
||||
|
self.scale = self.head_dim**-0.5 |
||||
|
self.dropout = config.attention_dropout |
||||
|
|
||||
|
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
|
||||
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
||||
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
hidden_states: torch.Tensor, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
causal_attention_mask: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = False, |
||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
||||
|
"""Input shape: Batch x Time x Channel""" |
||||
|
|
||||
|
bsz, tgt_len, embed_dim = hidden_states.size() |
||||
|
|
||||
|
# get query proj |
||||
|
query_states = self.q_proj(hidden_states) * self.scale |
||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
||||
|
|
||||
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim) |
||||
|
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) |
||||
|
key_states = key_states.view(*proj_shape) |
||||
|
value_states = value_states.view(*proj_shape) |
||||
|
|
||||
|
src_len = key_states.size(1) |
||||
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) |
||||
|
|
||||
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): |
||||
|
raise ValueError( |
||||
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" |
||||
|
f" {attn_weights.size()}" |
||||
|
) |
||||
|
|
||||
|
# apply the causal_attention_mask first |
||||
|
if causal_attention_mask is not None: |
||||
|
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): |
||||
|
raise ValueError( |
||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" |
||||
|
f" {causal_attention_mask.size()}" |
||||
|
) |
||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask |
||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
||||
|
|
||||
|
if attention_mask is not None: |
||||
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len): |
||||
|
raise ValueError( |
||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" |
||||
|
) |
||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask |
||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
||||
|
|
||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1) |
||||
|
|
||||
|
if output_attentions: |
||||
|
# this operation is a bit akward, but it's required to |
||||
|
# make sure that attn_weights keeps its gradient. |
||||
|
# In order to do so, attn_weights have to reshaped |
||||
|
# twice and have to be reused in the following |
||||
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
||||
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) |
||||
|
else: |
||||
|
attn_weights_reshaped = None |
||||
|
|
||||
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
||||
|
|
||||
|
attn_output = torch.bmm(attn_probs, value_states) |
||||
|
|
||||
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): |
||||
|
raise ValueError( |
||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" |
||||
|
f" {attn_output.size()}" |
||||
|
) |
||||
|
|
||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) |
||||
|
attn_output = attn_output.transpose(1, 2) |
||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) |
||||
|
|
||||
|
attn_output = self.out_proj(attn_output) |
||||
|
|
||||
|
return attn_output, attn_weights_reshaped |
||||
|
|
||||
|
|
||||
|
class CLIPMLP(nn.Module): |
||||
|
def __init__(self, config): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.activation_fn = ACT2FN[config.hidden_act] |
||||
|
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) |
||||
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) |
||||
|
|
||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
||||
|
hidden_states = self.fc1(hidden_states) |
||||
|
hidden_states = self.activation_fn(hidden_states) |
||||
|
hidden_states = self.fc2(hidden_states) |
||||
|
return hidden_states |
||||
|
|
||||
|
|
||||
|
class CLIPEncoderLayer(nn.Module): |
||||
|
def __init__(self, config: CLIPConfig): |
||||
|
super().__init__() |
||||
|
self.embed_dim = config.hidden_size |
||||
|
self.self_attn = CLIPAttention(config) |
||||
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim) |
||||
|
self.mlp = CLIPMLP(config) |
||||
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
hidden_states: torch.Tensor, |
||||
|
attention_mask: torch.Tensor, |
||||
|
causal_attention_mask: torch.Tensor, |
||||
|
output_attentions: Optional[bool] = False, |
||||
|
) -> Tuple[torch.FloatTensor]: |
||||
|
""" |
||||
|
Args: |
||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
||||
|
attention_mask (`torch.FloatTensor`): attention mask of size |
||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
||||
|
`(config.encoder_attention_heads,)`. |
||||
|
output_attentions (`bool`, *optional*): |
||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
||||
|
returned tensors for more detail. |
||||
|
""" |
||||
|
residual = hidden_states |
||||
|
|
||||
|
hidden_states = self.layer_norm1(hidden_states) |
||||
|
hidden_states, attn_weights = self.self_attn( |
||||
|
hidden_states=hidden_states, |
||||
|
attention_mask=attention_mask, |
||||
|
causal_attention_mask=causal_attention_mask, |
||||
|
output_attentions=output_attentions, |
||||
|
) |
||||
|
hidden_states = residual + hidden_states |
||||
|
|
||||
|
residual = hidden_states |
||||
|
hidden_states = self.layer_norm2(hidden_states) |
||||
|
hidden_states = self.mlp(hidden_states) |
||||
|
hidden_states = residual + hidden_states |
||||
|
|
||||
|
outputs = (hidden_states,) |
||||
|
|
||||
|
if output_attentions: |
||||
|
outputs += (attn_weights,) |
||||
|
|
||||
|
return outputs |
||||
|
|
||||
|
|
||||
|
class CLIPPreTrainedModel(PreTrainedModel): |
||||
|
""" |
||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
||||
|
models. |
||||
|
""" |
||||
|
|
||||
|
config_class = CLIPConfig |
||||
|
base_model_prefix = "clip" |
||||
|
supports_gradient_checkpointing = True |
||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"] |
||||
|
|
||||
|
def _init_weights(self, module): |
||||
|
"""Initialize the weights""" |
||||
|
factor = self.config.initializer_factor |
||||
|
if isinstance(module, CLIPTextEmbeddings): |
||||
|
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) |
||||
|
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) |
||||
|
elif isinstance(module, CLIPVisionEmbeddings): |
||||
|
factor = self.config.initializer_factor |
||||
|
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) |
||||
|
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) |
||||
|
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) |
||||
|
elif isinstance(module, CLIPAttention): |
||||
|
factor = self.config.initializer_factor |
||||
|
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor |
||||
|
out_proj_std = (module.embed_dim**-0.5) * factor |
||||
|
nn.init.normal_(module.q_proj.weight, std=in_proj_std) |
||||
|
nn.init.normal_(module.k_proj.weight, std=in_proj_std) |
||||
|
nn.init.normal_(module.v_proj.weight, std=in_proj_std) |
||||
|
nn.init.normal_(module.out_proj.weight, std=out_proj_std) |
||||
|
elif isinstance(module, CLIPMLP): |
||||
|
factor = self.config.initializer_factor |
||||
|
in_proj_std = ( |
||||
|
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor |
||||
|
) |
||||
|
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor |
||||
|
nn.init.normal_(module.fc1.weight, std=fc_std) |
||||
|
nn.init.normal_(module.fc2.weight, std=in_proj_std) |
||||
|
elif isinstance(module, CLIPModel): |
||||
|
nn.init.normal_( |
||||
|
module.text_projection.weight, |
||||
|
std=module.text_embed_dim**-0.5 * self.config.initializer_factor, |
||||
|
) |
||||
|
nn.init.normal_( |
||||
|
module.visual_projection.weight, |
||||
|
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, |
||||
|
) |
||||
|
|
||||
|
if isinstance(module, nn.LayerNorm): |
||||
|
module.bias.data.zero_() |
||||
|
module.weight.data.fill_(1.0) |
||||
|
if isinstance(module, nn.Linear) and module.bias is not None: |
||||
|
module.bias.data.zero_() |
||||
|
|
||||
|
def _set_gradient_checkpointing(self, module, value=False): |
||||
|
if isinstance(module, CLIPEncoder): |
||||
|
module.gradient_checkpointing = value |
||||
|
|
||||
|
|
||||
|
class CLIPEncoder(nn.Module): |
||||
|
""" |
||||
|
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a |
||||
|
[`CLIPEncoderLayer`]. |
||||
|
Args: |
||||
|
config: CLIPConfig |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, config: CLIPConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.layers = nn.ModuleList([CLIPEncoderLayer(config) for _ in range(config.num_hidden_layers)]) |
||||
|
self.gradient_checkpointing = False |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
inputs_embeds, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
causal_attention_mask: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutput]: |
||||
|
r""" |
||||
|
Args: |
||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
||||
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
||||
|
than the model's internal embedding lookup matrix. |
||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
||||
|
- 1 for tokens that are **not masked**, |
||||
|
- 0 for tokens that are **masked**. |
||||
|
[What are attention masks?](../glossary#attention-mask) |
||||
|
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
||||
|
Causal mask for the text model. Mask values selected in `[0, 1]`: |
||||
|
- 1 for tokens that are **not masked**, |
||||
|
- 0 for tokens that are **masked**. |
||||
|
[What are attention masks?](../glossary#attention-mask) |
||||
|
output_attentions (`bool`, *optional*): |
||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
||||
|
returned tensors for more detail. |
||||
|
output_hidden_states (`bool`, *optional*): |
||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
||||
|
for more detail. |
||||
|
return_dict (`bool`, *optional*): |
||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
||||
|
""" |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
encoder_states = () if output_hidden_states else None |
||||
|
all_attentions = () if output_attentions else None |
||||
|
|
||||
|
hidden_states = inputs_embeds |
||||
|
for idx, encoder_layer in enumerate(self.layers): |
||||
|
if output_hidden_states: |
||||
|
encoder_states = encoder_states + (hidden_states,) |
||||
|
if self.gradient_checkpointing and self.training: |
||||
|
|
||||
|
def create_custom_forward(module): |
||||
|
def custom_forward(*inputs): |
||||
|
return module(*inputs, output_attentions) |
||||
|
|
||||
|
return custom_forward |
||||
|
|
||||
|
layer_outputs = torch.utils.checkpoint.checkpoint( |
||||
|
create_custom_forward(encoder_layer), |
||||
|
hidden_states, |
||||
|
attention_mask, |
||||
|
causal_attention_mask, |
||||
|
) |
||||
|
else: |
||||
|
layer_outputs = encoder_layer( |
||||
|
hidden_states, |
||||
|
attention_mask, |
||||
|
causal_attention_mask, |
||||
|
output_attentions=output_attentions, |
||||
|
) |
||||
|
|
||||
|
hidden_states = layer_outputs[0] |
||||
|
|
||||
|
if output_attentions: |
||||
|
all_attentions = all_attentions + (layer_outputs[1],) |
||||
|
|
||||
|
if output_hidden_states: |
||||
|
encoder_states = encoder_states + (hidden_states,) |
||||
|
|
||||
|
if not return_dict: |
||||
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) |
||||
|
return BaseModelOutput( |
||||
|
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLIPTextTransformer(nn.Module): |
||||
|
def __init__(self, config: CLIPTextConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
embed_dim = config.hidden_size |
||||
|
self.embeddings = CLIPTextEmbeddings(config) |
||||
|
self.encoder = CLIPEncoder(config) |
||||
|
self.final_layer_norm = nn.LayerNorm(embed_dim) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.Tensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
if input_ids is None: |
||||
|
raise ValueError("You have to specify either input_ids") |
||||
|
|
||||
|
input_shape = input_ids.size() |
||||
|
input_ids = input_ids.view(-1, input_shape[-1]) |
||||
|
|
||||
|
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) |
||||
|
|
||||
|
bsz, seq_len = input_shape |
||||
|
# CLIP's text model uses causal mask, prepare it here. |
||||
|
# https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324 |
||||
|
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device) |
||||
|
# expand attention_mask |
||||
|
if attention_mask is not None: |
||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] |
||||
|
attention_mask = _expand_mask(attention_mask, hidden_states.dtype) |
||||
|
|
||||
|
encoder_outputs = self.encoder( |
||||
|
inputs_embeds=hidden_states, |
||||
|
attention_mask=attention_mask, |
||||
|
causal_attention_mask=causal_attention_mask, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
last_hidden_state = encoder_outputs[0] |
||||
|
last_hidden_state = self.final_layer_norm(last_hidden_state) |
||||
|
|
||||
|
# text_embeds.shape = [batch_size, sequence_length, transformer.width] |
||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence) |
||||
|
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] |
||||
|
|
||||
|
if not return_dict: |
||||
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
||||
|
|
||||
|
return BaseModelOutputWithPooling( |
||||
|
last_hidden_state=last_hidden_state, |
||||
|
pooler_output=pooled_output, |
||||
|
hidden_states=encoder_outputs.hidden_states, |
||||
|
attentions=encoder_outputs.attentions, |
||||
|
) |
||||
|
|
||||
|
def _build_causal_attention_mask(self, bsz, seq_len): |
||||
|
# lazily create causal attention mask, with full attention between the vision tokens |
||||
|
# pytorch uses additive attention mask; fill with -inf |
||||
|
mask = torch.empty(bsz, seq_len, seq_len) |
||||
|
mask.fill_(float("-inf")) |
||||
|
mask.triu_(1) # zero out the lower diagonal |
||||
|
mask = mask.unsqueeze(1) # expand mask |
||||
|
return mask |
||||
|
|
||||
|
|
||||
|
class CLIPTextModel(CLIPPreTrainedModel): |
||||
|
config_class = CLIPTextConfig |
||||
|
|
||||
|
def __init__(self, config: CLIPTextConfig): |
||||
|
super().__init__(config) |
||||
|
self.text_model = CLIPTextTransformer(config) |
||||
|
# Initialize weights and apply final processing |
||||
|
self.post_init() |
||||
|
|
||||
|
def get_input_embeddings(self) -> nn.Module: |
||||
|
return self.text_model.embeddings.token_embedding |
||||
|
|
||||
|
def set_input_embeddings(self, value): |
||||
|
self.text_model.embeddings.token_embedding = value |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.Tensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
return self.text_model( |
||||
|
input_ids=input_ids, |
||||
|
attention_mask=attention_mask, |
||||
|
position_ids=position_ids, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLIPVisionTransformer(nn.Module): |
||||
|
def __init__(self, config: CLIPVisionConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
embed_dim = config.hidden_size |
||||
|
|
||||
|
self.embeddings = CLIPVisionEmbeddings(config) |
||||
|
self.pre_layrnorm = nn.LayerNorm(embed_dim) |
||||
|
self.encoder = CLIPEncoder(config) |
||||
|
self.post_layernorm = nn.LayerNorm(embed_dim) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
r""" |
||||
|
Returns: |
||||
|
""" |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
if pixel_values is None: |
||||
|
raise ValueError("You have to specify pixel_values") |
||||
|
|
||||
|
hidden_states = self.embeddings(pixel_values) |
||||
|
hidden_states = self.pre_layrnorm(hidden_states) |
||||
|
|
||||
|
encoder_outputs = self.encoder( |
||||
|
inputs_embeds=hidden_states, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
last_hidden_state = encoder_outputs[0] |
||||
|
pooled_output = last_hidden_state[:, 0, :] |
||||
|
pooled_output = self.post_layernorm(pooled_output) |
||||
|
|
||||
|
if not return_dict: |
||||
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
||||
|
|
||||
|
return BaseModelOutputWithPooling( |
||||
|
last_hidden_state=last_hidden_state, |
||||
|
pooler_output=pooled_output, |
||||
|
hidden_states=encoder_outputs.hidden_states, |
||||
|
attentions=encoder_outputs.attentions, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLIPVisionModel(CLIPPreTrainedModel): |
||||
|
config_class = CLIPVisionConfig |
||||
|
main_input_name = "pixel_values" |
||||
|
|
||||
|
def __init__(self, config: CLIPVisionConfig): |
||||
|
super().__init__(config) |
||||
|
self.vision_model = CLIPVisionTransformer(config) |
||||
|
# Initialize weights and apply final processing |
||||
|
self.post_init() |
||||
|
|
||||
|
def get_input_embeddings(self) -> nn.Module: |
||||
|
return self.vision_model.embeddings.patch_embedding |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
return self.vision_model( |
||||
|
pixel_values=pixel_values, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLIPModel(CLIPPreTrainedModel): |
||||
|
config_class = CLIPConfig |
||||
|
|
||||
|
def __init__(self, config: CLIPConfig): |
||||
|
super().__init__(config) |
||||
|
text_config = config.text_config |
||||
|
vision_config = config.vision_config |
||||
|
|
||||
|
self.projection_dim = config.projection_dim |
||||
|
self.text_embed_dim = text_config.hidden_size |
||||
|
self.vision_embed_dim = vision_config.hidden_size |
||||
|
|
||||
|
if isinstance(text_config, CLIPTextConfig): |
||||
|
text_model = CLIPTextTransformer(text_config) |
||||
|
else: |
||||
|
text_model = AutoModel.from_config(config.text_config, add_pooling_layer=False) |
||||
|
|
||||
|
if isinstance(config.vision_config, CLIPVisionConfig): |
||||
|
vision_model = CLIPVisionModel(config.vision_config) |
||||
|
else: |
||||
|
vision_model = AutoModel.from_config(config.vision_config, add_pooling_layer=False) |
||||
|
|
||||
|
self.text_model = text_model |
||||
|
self.vision_model = vision_model |
||||
|
|
||||
|
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) |
||||
|
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) |
||||
|
self.logit_scale = nn.Parameter(torch.ones([]) * self.config.logit_scale_init_value) |
||||
|
|
||||
|
# Initialize weights and apply final processing |
||||
|
self.post_init() |
||||
|
|
||||
|
def encode_text(self, input_ids, **kwargs): |
||||
|
return self.get_text_features(input_ids=input_ids, **kwargs) |
||||
|
|
||||
|
def get_text_features( |
||||
|
self, |
||||
|
input_ids: Optional[torch.Tensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> torch.FloatTensor: |
||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components. |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
text_outputs = self.text_model( |
||||
|
input_ids=input_ids, |
||||
|
attention_mask=attention_mask, |
||||
|
position_ids=position_ids, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
pooled_output = text_outputs.last_hidden_state[:, 0, :] |
||||
|
text_features = self.text_projection(pooled_output) |
||||
|
|
||||
|
return text_features |
||||
|
|
||||
|
def encode_image(self, pixel_values, **kwargs): |
||||
|
return self.get_image_features(pixel_values=pixel_values, **kwargs) |
||||
|
|
||||
|
def get_image_features( |
||||
|
self, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> torch.FloatTensor: |
||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components. |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
vision_outputs = self.vision_model( |
||||
|
pixel_values=pixel_values, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
pooled_output = vision_outputs.last_hidden_state[:, 0, :] |
||||
|
image_features = self.visual_projection(pooled_output) |
||||
|
|
||||
|
return image_features |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.LongTensor] = None, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.LongTensor] = None, |
||||
|
return_loss: Optional[bool] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, CLIPOutput]: |
||||
|
# Use CLIP model's config for some fields (if specified) instead of those of vision & text components. |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
vision_outputs = self.vision_model( |
||||
|
pixel_values=pixel_values, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
text_outputs = self.text_model( |
||||
|
input_ids=input_ids, |
||||
|
attention_mask=attention_mask, |
||||
|
position_ids=position_ids, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
image_embeds = vision_outputs.last_hidden_state[:, 0, :] |
||||
|
image_embeds = self.visual_projection(image_embeds) |
||||
|
|
||||
|
text_embeds = text_outputs.last_hidden_state[:, 0, :] |
||||
|
text_embeds = self.text_projection(text_embeds) |
||||
|
|
||||
|
# normalized features |
||||
|
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) |
||||
|
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) |
||||
|
|
||||
|
# cosine similarity as logits |
||||
|
logit_scale = self.logit_scale.exp() |
||||
|
logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale |
||||
|
logits_per_image = logits_per_text.T |
||||
|
|
||||
|
loss = None |
||||
|
if return_loss: |
||||
|
loss = clip_loss(logits_per_text) |
||||
|
|
||||
|
if not return_dict: |
||||
|
output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs) |
||||
|
return ((loss,) + output) if loss is not None else output |
||||
|
|
||||
|
return CLIPOutput( |
||||
|
loss=loss, |
||||
|
logits_per_image=logits_per_image, |
||||
|
logits_per_text=logits_per_text, |
||||
|
text_embeds=text_embeds, |
||||
|
image_embeds=image_embeds, |
||||
|
text_model_output=text_outputs, |
||||
|
vision_model_output=vision_outputs, |
||||
|
) |
||||
|
|
@ -0,0 +1,16 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
from .configuration_cloob import * |
||||
|
from .modeling_cloob import * |
@ -0,0 +1,203 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
""" CLOOB model configuration""" |
||||
|
import logging |
||||
|
import copy |
||||
|
import os |
||||
|
from typing import Union |
||||
|
|
||||
|
from transformers import AutoConfig, PretrainedConfig |
||||
|
|
||||
|
|
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
|
|
||||
|
class CLOOBTextConfig(PretrainedConfig): |
||||
|
model_type = "cloob_text_model" |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
vocab_size=49408, |
||||
|
hidden_size=512, |
||||
|
intermediate_size=2048, |
||||
|
num_hidden_layers=12, |
||||
|
num_attention_heads=8, |
||||
|
max_position_embeddings=77, |
||||
|
hidden_act="quick_gelu", |
||||
|
layer_norm_eps=0.00001, |
||||
|
dropout=0.0, |
||||
|
attention_dropout=0.0, |
||||
|
initializer_range=0.02, |
||||
|
initializer_factor=1.0, |
||||
|
pad_token_id=1, |
||||
|
bos_token_id=0, |
||||
|
eos_token_id=2, |
||||
|
**kwargs |
||||
|
): |
||||
|
super().__init__(pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs) |
||||
|
|
||||
|
self.vocab_size = vocab_size |
||||
|
self.hidden_size = hidden_size |
||||
|
self.intermediate_size = intermediate_size |
||||
|
self.dropout = dropout |
||||
|
self.num_hidden_layers = num_hidden_layers |
||||
|
self.num_attention_heads = num_attention_heads |
||||
|
self.max_position_embeddings = max_position_embeddings |
||||
|
self.layer_norm_eps = layer_norm_eps |
||||
|
self.hidden_act = hidden_act |
||||
|
self.initializer_range = initializer_range |
||||
|
self.initializer_factor = initializer_factor |
||||
|
self.attention_dropout = attention_dropout |
||||
|
|
||||
|
@classmethod |
||||
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": |
||||
|
|
||||
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) |
||||
|
|
||||
|
# get the text config dict if we are loading from CLIPConfig |
||||
|
if config_dict.get("model_type") == "clip": |
||||
|
config_dict = config_dict["text_config"] |
||||
|
|
||||
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: |
||||
|
logger.warning( |
||||
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " |
||||
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." |
||||
|
) |
||||
|
|
||||
|
return cls.from_dict(config_dict, **kwargs) |
||||
|
|
||||
|
|
||||
|
class CLOOBVisionConfig(PretrainedConfig): |
||||
|
model_type = "cloob_vision_model" |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
hidden_size=768, |
||||
|
intermediate_size=3072, |
||||
|
num_hidden_layers=12, |
||||
|
num_attention_heads=12, |
||||
|
image_size=224, |
||||
|
patch_size=32, |
||||
|
hidden_act="quick_gelu", |
||||
|
layer_norm_eps=0.00001, |
||||
|
dropout=0.0, |
||||
|
attention_dropout=0.0, |
||||
|
initializer_range=0.02, |
||||
|
initializer_factor=1.0, |
||||
|
**kwargs |
||||
|
): |
||||
|
super().__init__(**kwargs) |
||||
|
|
||||
|
self.hidden_size = hidden_size |
||||
|
self.intermediate_size = intermediate_size |
||||
|
self.dropout = dropout |
||||
|
self.num_hidden_layers = num_hidden_layers |
||||
|
self.num_attention_heads = num_attention_heads |
||||
|
self.patch_size = patch_size |
||||
|
self.image_size = image_size |
||||
|
self.initializer_range = initializer_range |
||||
|
self.initializer_factor = initializer_factor |
||||
|
self.attention_dropout = attention_dropout |
||||
|
self.layer_norm_eps = layer_norm_eps |
||||
|
self.hidden_act = hidden_act |
||||
|
|
||||
|
@classmethod |
||||
|
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": |
||||
|
|
||||
|
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) |
||||
|
|
||||
|
# get the vision config dict if we are loading from CLIPConfig |
||||
|
if config_dict.get("model_type") == "clip": |
||||
|
config_dict = config_dict["vision_config"] |
||||
|
|
||||
|
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: |
||||
|
logger.warning( |
||||
|
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " |
||||
|
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." |
||||
|
) |
||||
|
|
||||
|
return cls.from_dict(config_dict, **kwargs) |
||||
|
|
||||
|
|
||||
|
class CLOOBConfig(PretrainedConfig): |
||||
|
model_type = "cloob" |
||||
|
is_composition = True |
||||
|
|
||||
|
def __init__( |
||||
|
self, |
||||
|
text_config=None, |
||||
|
vision_config=None, |
||||
|
projection_dim=512, |
||||
|
init_inv_tau=30.0, |
||||
|
scale_hopfield=15.0, |
||||
|
**kwargs |
||||
|
): |
||||
|
super().__init__(text_config=text_config, vision_config=vision_config, **kwargs) |
||||
|
|
||||
|
if vision_config is None: |
||||
|
raise ValueError("`vision_config` can not be `None`.") |
||||
|
|
||||
|
if text_config is None: |
||||
|
raise ValueError("`text_config` can not be `None`.") |
||||
|
|
||||
|
vision_model_type = vision_config.pop("model_type") |
||||
|
text_model_type = text_config.pop("model_type") |
||||
|
|
||||
|
if vision_model_type == "cloob_vision_model": |
||||
|
self.vision_config = CLOOBVisionConfig(**vision_config) |
||||
|
else: |
||||
|
self.vision_config = AutoConfig.for_model( |
||||
|
vision_model_type, **vision_config |
||||
|
) |
||||
|
|
||||
|
if text_model_type == "cloob_text_model": |
||||
|
self.text_config = CLOOBTextConfig(**text_config) |
||||
|
else: |
||||
|
self.text_config = AutoConfig.for_model( |
||||
|
text_model_type, **text_config |
||||
|
) |
||||
|
|
||||
|
self.projection_dim = projection_dim |
||||
|
self.initializer_factor = 1.0 |
||||
|
self.init_inv_tau = init_inv_tau |
||||
|
self.scale_hopfield = scale_hopfield |
||||
|
|
||||
|
|
||||
|
@classmethod |
||||
|
def from_text_vision_configs(cls, text_config: CLOOBTextConfig, vision_config: CLOOBVisionConfig, **kwargs): |
||||
|
r""" |
||||
|
Instantiate a [`CLIPConfig`] (or a derived class) from clip text model configuration and clip vision model |
||||
|
configuration. |
||||
|
|
||||
|
Returns: |
||||
|
[`CLIPConfig`]: An instance of a configuration object |
||||
|
""" |
||||
|
|
||||
|
return cls(text_config_dict=text_config.to_dict(), vision_config_dict=vision_config.to_dict(), **kwargs) |
||||
|
|
||||
|
def to_dict(self): |
||||
|
""" |
||||
|
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`]. |
||||
|
|
||||
|
Returns: |
||||
|
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance, |
||||
|
""" |
||||
|
output = copy.deepcopy(self.__dict__) |
||||
|
output["text_config"] = self.text_config.to_dict() |
||||
|
output["vision_config"] = self.vision_config.to_dict() |
||||
|
output["model_type"] = self.__class__.model_type |
||||
|
return output |
||||
|
|
||||
|
|
@ -0,0 +1,58 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import torch |
||||
|
import torch.nn.functional as F |
||||
|
|
||||
|
|
||||
|
def cloob_loss(image_features, text_features, inv_tau, scale_hopfield): |
||||
|
""" |
||||
|
Note: this loss has been rescaled from the original CLOOB loss for interpretability, |
||||
|
to convert to the original, divide it by inv_tau / 2. |
||||
|
""" |
||||
|
p_xx, p_yy, p_xy, p_yx = hopfield_retrieval(image_features, text_features, scale_hopfield) |
||||
|
identity = torch.eye(p_xx.shape[1]) > 0.5 |
||||
|
i = identity.to(p_xx.device) |
||||
|
loss_img = infoLOOB_loss(p_xx.T, p_xy.T, i, inv_tau=inv_tau) |
||||
|
loss_txt = infoLOOB_loss(p_yy.T, p_yx.T, i, inv_tau=inv_tau) |
||||
|
return (loss_img + loss_txt) / 2 |
||||
|
|
||||
|
|
||||
|
def infoLOOB_loss(x, y, i, inv_tau): |
||||
|
tau = 1 / inv_tau |
||||
|
k = x @ y.T / tau |
||||
|
positives = -torch.mean(torch.sum(k * i, dim=1)) |
||||
|
|
||||
|
# For logsumexp the zero entries must be equal to a very large negative number |
||||
|
large_neg = -10000.0 |
||||
|
arg_lse = k * torch.logical_not(i) + i * large_neg |
||||
|
negatives = torch.mean(torch.logsumexp(arg_lse, dim=1)) |
||||
|
return positives + negatives |
||||
|
|
||||
|
|
||||
|
def hopfield_retrieval(image_features, text_features, scale_hopfield): |
||||
|
patterns_xx = hopfield(state_patterns=image_features, stored_patterns=image_features, scale_hopfield=scale_hopfield) |
||||
|
patterns_yy = hopfield(state_patterns=text_features, stored_patterns=text_features, scale_hopfield=scale_hopfield) |
||||
|
patterns_xy = hopfield(state_patterns=text_features, stored_patterns=image_features, scale_hopfield=scale_hopfield) |
||||
|
patterns_yx = hopfield(state_patterns=image_features, stored_patterns=text_features, scale_hopfield=scale_hopfield) |
||||
|
|
||||
|
return patterns_xx, patterns_yy, patterns_xy, patterns_yx |
||||
|
|
||||
|
|
||||
|
def hopfield(state_patterns, stored_patterns, scale_hopfield): |
||||
|
retrieved_patterns = stored_patterns.T @ F.softmax(scale_hopfield * stored_patterns @ state_patterns.T, dim=0) |
||||
|
# Row vectors -> dim=1 to normalize the row vectors |
||||
|
retrieved_patterns = retrieved_patterns / retrieved_patterns.norm(dim=0, keepdim=True) |
||||
|
return retrieved_patterns |
@ -0,0 +1,783 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
import logging |
||||
|
from dataclasses import dataclass |
||||
|
from typing import Any, Optional, Tuple, Union |
||||
|
|
||||
|
import torch |
||||
|
import torch.utils.checkpoint |
||||
|
from torch import nn |
||||
|
|
||||
|
from transformers import AutoModel |
||||
|
from transformers.activations import ACT2FN |
||||
|
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling |
||||
|
from transformers.modeling_utils import PreTrainedModel, ModelOutput |
||||
|
from .configuration_cloob import CLOOBConfig, CLOOBTextConfig, CLOOBVisionConfig |
||||
|
from .loss import cloob_loss |
||||
|
from ..clip.modeling_clip import _expand_mask |
||||
|
|
||||
|
logger = logging.getLogger(__name__) |
||||
|
|
||||
|
|
||||
|
@dataclass |
||||
|
class CLOOBOutput(ModelOutput): |
||||
|
loss: Optional[torch.FloatTensor] = None |
||||
|
inv_tau: Union[torch.FloatTensor, float] = None |
||||
|
text_embeds: torch.FloatTensor = None |
||||
|
image_embeds: torch.FloatTensor = None |
||||
|
text_model_output: BaseModelOutputWithPooling = None |
||||
|
vision_model_output: BaseModelOutputWithPooling = None |
||||
|
|
||||
|
def to_tuple(self) -> Tuple[Any]: |
||||
|
return tuple( |
||||
|
self[k] if k not in ["text_model_output", "vision_model_output"] else getattr(self, k).to_tuple() |
||||
|
for k in self.keys() |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLOOBVisionEmbeddings(nn.Module): |
||||
|
def __init__(self, config: CLOOBVisionConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.embed_dim = config.hidden_size |
||||
|
self.image_size = config.image_size |
||||
|
self.patch_size = config.patch_size |
||||
|
|
||||
|
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim)) |
||||
|
|
||||
|
self.patch_embedding = nn.Conv2d( |
||||
|
in_channels=3, out_channels=self.embed_dim, kernel_size=self.patch_size, stride=self.patch_size, bias=False |
||||
|
) |
||||
|
|
||||
|
self.num_patches = (self.image_size // self.patch_size) ** 2 |
||||
|
self.num_positions = self.num_patches + 1 |
||||
|
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) |
||||
|
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1))) |
||||
|
|
||||
|
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: |
||||
|
batch_size = pixel_values.shape[0] |
||||
|
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid] |
||||
|
patch_embeds = patch_embeds.flatten(2).transpose(1, 2) |
||||
|
|
||||
|
class_embeds = self.class_embedding.expand(batch_size, 1, -1) |
||||
|
embeddings = torch.cat([class_embeds, patch_embeds], dim=1) |
||||
|
embeddings = embeddings + self.position_embedding(self.position_ids) |
||||
|
return embeddings |
||||
|
|
||||
|
|
||||
|
class CLOOBTextEmbeddings(nn.Module): |
||||
|
def __init__(self, config: CLOOBTextConfig): |
||||
|
super().__init__() |
||||
|
embed_dim = config.hidden_size |
||||
|
|
||||
|
self.token_embedding = nn.Embedding(config.vocab_size, embed_dim) |
||||
|
self.position_embedding = nn.Embedding(config.max_position_embeddings, embed_dim) |
||||
|
|
||||
|
# position_ids (1, len position emb) is contiguous in memory and exported when serialized |
||||
|
self.register_buffer("position_ids", torch.arange(config.max_position_embeddings).expand((1, -1))) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.LongTensor] = None, |
||||
|
position_ids: Optional[torch.LongTensor] = None, |
||||
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
||||
|
) -> torch.Tensor: |
||||
|
seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2] |
||||
|
|
||||
|
if position_ids is None: |
||||
|
position_ids = self.position_ids[:, :seq_length] |
||||
|
|
||||
|
if inputs_embeds is None: |
||||
|
inputs_embeds = self.token_embedding(input_ids) |
||||
|
|
||||
|
position_embeddings = self.position_embedding(position_ids) |
||||
|
embeddings = inputs_embeds + position_embeddings |
||||
|
|
||||
|
return embeddings |
||||
|
|
||||
|
|
||||
|
class CLOOBAttention(nn.Module): |
||||
|
"""Multi-headed attention from 'Attention Is All You Need' paper""" |
||||
|
|
||||
|
def __init__(self, config): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.embed_dim = config.hidden_size |
||||
|
self.num_heads = config.num_attention_heads |
||||
|
self.head_dim = self.embed_dim // self.num_heads |
||||
|
if self.head_dim * self.num_heads != self.embed_dim: |
||||
|
raise ValueError( |
||||
|
f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" |
||||
|
f" {self.num_heads})." |
||||
|
) |
||||
|
self.scale = self.head_dim**-0.5 |
||||
|
self.dropout = config.attention_dropout |
||||
|
|
||||
|
self.k_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
self.v_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
self.q_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
self.out_proj = nn.Linear(self.embed_dim, self.embed_dim) |
||||
|
|
||||
|
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): |
||||
|
return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
hidden_states: torch.Tensor, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
causal_attention_mask: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = False, |
||||
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
||||
|
"""Input shape: Batch x Time x Channel""" |
||||
|
|
||||
|
bsz, tgt_len, embed_dim = hidden_states.size() |
||||
|
|
||||
|
# get query proj |
||||
|
query_states = self.q_proj(hidden_states) * self.scale |
||||
|
key_states = self._shape(self.k_proj(hidden_states), -1, bsz) |
||||
|
value_states = self._shape(self.v_proj(hidden_states), -1, bsz) |
||||
|
|
||||
|
proj_shape = (bsz * self.num_heads, -1, self.head_dim) |
||||
|
query_states = self._shape(query_states, tgt_len, bsz).view(*proj_shape) |
||||
|
key_states = key_states.view(*proj_shape) |
||||
|
value_states = value_states.view(*proj_shape) |
||||
|
|
||||
|
src_len = key_states.size(1) |
||||
|
attn_weights = torch.bmm(query_states, key_states.transpose(1, 2)) |
||||
|
|
||||
|
if attn_weights.size() != (bsz * self.num_heads, tgt_len, src_len): |
||||
|
raise ValueError( |
||||
|
f"Attention weights should be of size {(bsz * self.num_heads, tgt_len, src_len)}, but is" |
||||
|
f" {attn_weights.size()}" |
||||
|
) |
||||
|
|
||||
|
# apply the causal_attention_mask first |
||||
|
if causal_attention_mask is not None: |
||||
|
if causal_attention_mask.size() != (bsz, 1, tgt_len, src_len): |
||||
|
raise ValueError( |
||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is" |
||||
|
f" {causal_attention_mask.size()}" |
||||
|
) |
||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + causal_attention_mask |
||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
||||
|
|
||||
|
if attention_mask is not None: |
||||
|
if attention_mask.size() != (bsz, 1, tgt_len, src_len): |
||||
|
raise ValueError( |
||||
|
f"Attention mask should be of size {(bsz, 1, tgt_len, src_len)}, but is {attention_mask.size()}" |
||||
|
) |
||||
|
attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) + attention_mask |
||||
|
attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) |
||||
|
|
||||
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1) |
||||
|
|
||||
|
if output_attentions: |
||||
|
# this operation is a bit akward, but it's required to |
||||
|
# make sure that attn_weights keeps its gradient. |
||||
|
# In order to do so, attn_weights have to reshaped |
||||
|
# twice and have to be reused in the following |
||||
|
attn_weights_reshaped = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) |
||||
|
attn_weights = attn_weights_reshaped.view(bsz * self.num_heads, tgt_len, src_len) |
||||
|
else: |
||||
|
attn_weights_reshaped = None |
||||
|
|
||||
|
attn_probs = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training) |
||||
|
|
||||
|
attn_output = torch.bmm(attn_probs, value_states) |
||||
|
|
||||
|
if attn_output.size() != (bsz * self.num_heads, tgt_len, self.head_dim): |
||||
|
raise ValueError( |
||||
|
f"`attn_output` should be of size {(bsz, self.num_heads, tgt_len, self.head_dim)}, but is" |
||||
|
f" {attn_output.size()}" |
||||
|
) |
||||
|
|
||||
|
attn_output = attn_output.view(bsz, self.num_heads, tgt_len, self.head_dim) |
||||
|
attn_output = attn_output.transpose(1, 2) |
||||
|
attn_output = attn_output.reshape(bsz, tgt_len, embed_dim) |
||||
|
|
||||
|
attn_output = self.out_proj(attn_output) |
||||
|
|
||||
|
return attn_output, attn_weights_reshaped |
||||
|
|
||||
|
|
||||
|
class CLOOBMLP(nn.Module): |
||||
|
def __init__(self, config): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.activation_fn = ACT2FN[config.hidden_act] |
||||
|
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size) |
||||
|
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size) |
||||
|
|
||||
|
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: |
||||
|
hidden_states = self.fc1(hidden_states) |
||||
|
hidden_states = self.activation_fn(hidden_states) |
||||
|
hidden_states = self.fc2(hidden_states) |
||||
|
return hidden_states |
||||
|
|
||||
|
|
||||
|
class CLOOBEncoderLayer(nn.Module): |
||||
|
def __init__(self, config: CLOOBConfig): |
||||
|
super().__init__() |
||||
|
self.embed_dim = config.hidden_size |
||||
|
self.self_attn = CLOOBAttention(config) |
||||
|
self.layer_norm1 = nn.LayerNorm(self.embed_dim) |
||||
|
self.mlp = CLOOBMLP(config) |
||||
|
self.layer_norm2 = nn.LayerNorm(self.embed_dim) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
hidden_states: torch.Tensor, |
||||
|
attention_mask: torch.Tensor, |
||||
|
causal_attention_mask: torch.Tensor, |
||||
|
output_attentions: Optional[bool] = False, |
||||
|
) -> Tuple[torch.FloatTensor]: |
||||
|
""" |
||||
|
Args: |
||||
|
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` |
||||
|
attention_mask (`torch.FloatTensor`): attention mask of size |
||||
|
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values. |
||||
|
`(config.encoder_attention_heads,)`. |
||||
|
output_attentions (`bool`, *optional*): |
||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
||||
|
returned tensors for more detail. |
||||
|
""" |
||||
|
residual = hidden_states |
||||
|
|
||||
|
hidden_states = self.layer_norm1(hidden_states) |
||||
|
hidden_states, attn_weights = self.self_attn( |
||||
|
hidden_states=hidden_states, |
||||
|
attention_mask=attention_mask, |
||||
|
causal_attention_mask=causal_attention_mask, |
||||
|
output_attentions=output_attentions, |
||||
|
) |
||||
|
hidden_states = residual + hidden_states |
||||
|
|
||||
|
residual = hidden_states |
||||
|
hidden_states = self.layer_norm2(hidden_states) |
||||
|
hidden_states = self.mlp(hidden_states) |
||||
|
hidden_states = residual + hidden_states |
||||
|
|
||||
|
outputs = (hidden_states,) |
||||
|
|
||||
|
if output_attentions: |
||||
|
outputs += (attn_weights,) |
||||
|
|
||||
|
return outputs |
||||
|
|
||||
|
|
||||
|
class CLOOBPreTrainedModel(PreTrainedModel): |
||||
|
""" |
||||
|
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained |
||||
|
models. |
||||
|
""" |
||||
|
|
||||
|
config_class = CLOOBConfig |
||||
|
base_model_prefix = "cloob" |
||||
|
supports_gradient_checkpointing = True |
||||
|
_keys_to_ignore_on_load_missing = [r"position_ids"] |
||||
|
|
||||
|
def _init_weights(self, module): |
||||
|
"""Initialize the weights""" |
||||
|
factor = self.config.initializer_factor |
||||
|
if isinstance(module, CLOOBTextEmbeddings): |
||||
|
module.token_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) |
||||
|
module.position_embedding.weight.data.normal_(mean=0.0, std=factor * 0.02) |
||||
|
elif isinstance(module, CLOOBVisionEmbeddings): |
||||
|
factor = self.config.initializer_factor |
||||
|
nn.init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor) |
||||
|
nn.init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor) |
||||
|
nn.init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor) |
||||
|
elif isinstance(module, CLOOBAttention): |
||||
|
factor = self.config.initializer_factor |
||||
|
in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor |
||||
|
out_proj_std = (module.embed_dim**-0.5) * factor |
||||
|
nn.init.normal_(module.q_proj.weight, std=in_proj_std) |
||||
|
nn.init.normal_(module.k_proj.weight, std=in_proj_std) |
||||
|
nn.init.normal_(module.v_proj.weight, std=in_proj_std) |
||||
|
nn.init.normal_(module.out_proj.weight, std=out_proj_std) |
||||
|
elif isinstance(module, CLOOBMLP): |
||||
|
factor = self.config.initializer_factor |
||||
|
in_proj_std = ( |
||||
|
(module.config.hidden_size**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor |
||||
|
) |
||||
|
fc_std = (2 * module.config.hidden_size) ** -0.5 * factor |
||||
|
nn.init.normal_(module.fc1.weight, std=fc_std) |
||||
|
nn.init.normal_(module.fc2.weight, std=in_proj_std) |
||||
|
elif isinstance(module, CLOOBModel): |
||||
|
nn.init.normal_( |
||||
|
module.text_projection.weight, |
||||
|
std=module.text_embed_dim**-0.5 * self.config.initializer_factor, |
||||
|
) |
||||
|
nn.init.normal_( |
||||
|
module.visual_projection.weight, |
||||
|
std=module.vision_embed_dim**-0.5 * self.config.initializer_factor, |
||||
|
) |
||||
|
|
||||
|
if isinstance(module, nn.LayerNorm): |
||||
|
module.bias.data.zero_() |
||||
|
module.weight.data.fill_(1.0) |
||||
|
if isinstance(module, nn.Linear) and module.bias is not None: |
||||
|
module.bias.data.zero_() |
||||
|
|
||||
|
def _set_gradient_checkpointing(self, module, value=False): |
||||
|
if isinstance(module, CLOOBEncoder): |
||||
|
module.gradient_checkpointing = value |
||||
|
|
||||
|
|
||||
|
class CLOOBEncoder(nn.Module): |
||||
|
""" |
||||
|
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a |
||||
|
[`CLOOBEncoderLayer`]. |
||||
|
Args: |
||||
|
config: CLOOBConfig |
||||
|
""" |
||||
|
|
||||
|
def __init__(self, config: CLOOBConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
self.layers = nn.ModuleList([CLOOBEncoderLayer(config) for _ in range(config.num_hidden_layers)]) |
||||
|
self.gradient_checkpointing = False |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
inputs_embeds, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
causal_attention_mask: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutput]: |
||||
|
r""" |
||||
|
Args: |
||||
|
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`): |
||||
|
Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. |
||||
|
This is useful if you want more control over how to convert `input_ids` indices into associated vectors |
||||
|
than the model's internal embedding lookup matrix. |
||||
|
attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
||||
|
Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: |
||||
|
- 1 for tokens that are **not masked**, |
||||
|
- 0 for tokens that are **masked**. |
||||
|
[What are attention masks?](../glossary#attention-mask) |
||||
|
causal_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): |
||||
|
Causal mask for the text model. Mask values selected in `[0, 1]`: |
||||
|
- 1 for tokens that are **not masked**, |
||||
|
- 0 for tokens that are **masked**. |
||||
|
[What are attention masks?](../glossary#attention-mask) |
||||
|
output_attentions (`bool`, *optional*): |
||||
|
Whether or not to return the attentions tensors of all attention layers. See `attentions` under |
||||
|
returned tensors for more detail. |
||||
|
output_hidden_states (`bool`, *optional*): |
||||
|
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors |
||||
|
for more detail. |
||||
|
return_dict (`bool`, *optional*): |
||||
|
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. |
||||
|
""" |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
encoder_states = () if output_hidden_states else None |
||||
|
all_attentions = () if output_attentions else None |
||||
|
|
||||
|
hidden_states = inputs_embeds |
||||
|
for idx, encoder_layer in enumerate(self.layers): |
||||
|
if output_hidden_states: |
||||
|
encoder_states = encoder_states + (hidden_states,) |
||||
|
if self.gradient_checkpointing and self.training: |
||||
|
|
||||
|
def create_custom_forward(module): |
||||
|
def custom_forward(*inputs): |
||||
|
return module(*inputs, output_attentions) |
||||
|
|
||||
|
return custom_forward |
||||
|
|
||||
|
layer_outputs = torch.utils.checkpoint.checkpoint( |
||||
|
create_custom_forward(encoder_layer), |
||||
|
hidden_states, |
||||
|
attention_mask, |
||||
|
causal_attention_mask, |
||||
|
) |
||||
|
else: |
||||
|
layer_outputs = encoder_layer( |
||||
|
hidden_states, |
||||
|
attention_mask, |
||||
|
causal_attention_mask, |
||||
|
output_attentions=output_attentions, |
||||
|
) |
||||
|
|
||||
|
hidden_states = layer_outputs[0] |
||||
|
|
||||
|
if output_attentions: |
||||
|
all_attentions = all_attentions + (layer_outputs[1],) |
||||
|
|
||||
|
if output_hidden_states: |
||||
|
encoder_states = encoder_states + (hidden_states,) |
||||
|
|
||||
|
if not return_dict: |
||||
|
return tuple(v for v in [hidden_states, encoder_states, all_attentions] if v is not None) |
||||
|
return BaseModelOutput( |
||||
|
last_hidden_state=hidden_states, hidden_states=encoder_states, attentions=all_attentions |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLOOBTextTransformer(nn.Module): |
||||
|
def __init__(self, config: CLOOBTextConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
embed_dim = config.hidden_size |
||||
|
self.embeddings = CLOOBTextEmbeddings(config) |
||||
|
self.encoder = CLOOBEncoder(config) |
||||
|
self.final_layer_norm = nn.LayerNorm(embed_dim) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.Tensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
if input_ids is None: |
||||
|
raise ValueError("You have to specify either input_ids") |
||||
|
|
||||
|
input_shape = input_ids.size() |
||||
|
input_ids = input_ids.view(-1, input_shape[-1]) |
||||
|
|
||||
|
hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids) |
||||
|
|
||||
|
bsz, seq_len = input_shape |
||||
|
# CLOOB's text model uses causal mask, prepare it here. |
||||
|
# https://github.com/openai/CLOOB/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/CLOOB/model.py#L324 |
||||
|
causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len).to(hidden_states.device) |
||||
|
# expand attention_mask |
||||
|
if attention_mask is not None: |
||||
|
# [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len] |
||||
|
attention_mask = _expand_mask(attention_mask, hidden_states.dtype) |
||||
|
|
||||
|
encoder_outputs = self.encoder( |
||||
|
inputs_embeds=hidden_states, |
||||
|
attention_mask=attention_mask, |
||||
|
causal_attention_mask=causal_attention_mask, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
last_hidden_state = encoder_outputs[0] |
||||
|
last_hidden_state = self.final_layer_norm(last_hidden_state) |
||||
|
|
||||
|
# text_embeds.shape = [batch_size, sequence_length, transformer.width] |
||||
|
# take features from the eot embedding (eot_token is the highest number in each sequence) |
||||
|
pooled_output = last_hidden_state[torch.arange(last_hidden_state.shape[0]), input_ids.argmax(dim=-1)] |
||||
|
|
||||
|
if not return_dict: |
||||
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
||||
|
|
||||
|
return BaseModelOutputWithPooling( |
||||
|
last_hidden_state=last_hidden_state, |
||||
|
pooler_output=pooled_output, |
||||
|
hidden_states=encoder_outputs.hidden_states, |
||||
|
attentions=encoder_outputs.attentions, |
||||
|
) |
||||
|
|
||||
|
def _build_causal_attention_mask(self, bsz, seq_len): |
||||
|
# lazily create causal attention mask, with full attention between the vision tokens |
||||
|
# pytorch uses additive attention mask; fill with -inf |
||||
|
mask = torch.empty(bsz, seq_len, seq_len) |
||||
|
mask.fill_(float("-inf")) |
||||
|
mask.triu_(1) # zero out the lower diagonal |
||||
|
mask = mask.unsqueeze(1) # expand mask |
||||
|
return mask |
||||
|
|
||||
|
|
||||
|
class CLOOBTextModel(CLOOBPreTrainedModel): |
||||
|
config_class = CLOOBTextConfig |
||||
|
|
||||
|
def __init__(self, config: CLOOBTextConfig): |
||||
|
super().__init__(config) |
||||
|
self.text_model = CLOOBTextTransformer(config) |
||||
|
# Initialize weights and apply final processing |
||||
|
self.post_init() |
||||
|
|
||||
|
def get_input_embeddings(self) -> nn.Module: |
||||
|
return self.text_model.embeddings.token_embedding |
||||
|
|
||||
|
def set_input_embeddings(self, value): |
||||
|
self.text_model.embeddings.token_embedding = value |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.Tensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
return self.text_model( |
||||
|
input_ids=input_ids, |
||||
|
attention_mask=attention_mask, |
||||
|
position_ids=position_ids, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLOOBVisionTransformer(nn.Module): |
||||
|
def __init__(self, config: CLOOBVisionConfig): |
||||
|
super().__init__() |
||||
|
self.config = config |
||||
|
embed_dim = config.hidden_size |
||||
|
|
||||
|
self.embeddings = CLOOBVisionEmbeddings(config) |
||||
|
self.pre_layrnorm = nn.LayerNorm(embed_dim) |
||||
|
self.encoder = CLOOBEncoder(config) |
||||
|
self.post_layernorm = nn.LayerNorm(embed_dim) |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
r""" |
||||
|
Returns: |
||||
|
""" |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
if pixel_values is None: |
||||
|
raise ValueError("You have to specify pixel_values") |
||||
|
|
||||
|
hidden_states = self.embeddings(pixel_values) |
||||
|
hidden_states = self.pre_layrnorm(hidden_states) |
||||
|
|
||||
|
encoder_outputs = self.encoder( |
||||
|
inputs_embeds=hidden_states, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
last_hidden_state = encoder_outputs[0] |
||||
|
pooled_output = last_hidden_state[:, 0, :] |
||||
|
pooled_output = self.post_layernorm(pooled_output) |
||||
|
|
||||
|
if not return_dict: |
||||
|
return (last_hidden_state, pooled_output) + encoder_outputs[1:] |
||||
|
|
||||
|
return BaseModelOutputWithPooling( |
||||
|
last_hidden_state=last_hidden_state, |
||||
|
pooler_output=pooled_output, |
||||
|
hidden_states=encoder_outputs.hidden_states, |
||||
|
attentions=encoder_outputs.attentions, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLOOBVisionModel(CLOOBPreTrainedModel): |
||||
|
config_class = CLOOBVisionConfig |
||||
|
main_input_name = "pixel_values" |
||||
|
|
||||
|
def __init__(self, config: CLOOBVisionConfig): |
||||
|
super().__init__(config) |
||||
|
self.vision_model = CLOOBVisionTransformer(config) |
||||
|
# Initialize weights and apply final processing |
||||
|
self.post_init() |
||||
|
|
||||
|
def get_input_embeddings(self) -> nn.Module: |
||||
|
return self.vision_model.embeddings.patch_embedding |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, BaseModelOutputWithPooling]: |
||||
|
return self.vision_model( |
||||
|
pixel_values=pixel_values, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
|
||||
|
class CLOOBModel(CLOOBPreTrainedModel): |
||||
|
config_class = CLOOBConfig |
||||
|
|
||||
|
def __init__(self, config: CLOOBConfig): |
||||
|
super().__init__(config) |
||||
|
text_config = config.text_config |
||||
|
vision_config = config.vision_config |
||||
|
|
||||
|
self.projection_dim = config.projection_dim |
||||
|
self.text_embed_dim = text_config.hidden_size |
||||
|
self.vision_embed_dim = vision_config.hidden_size |
||||
|
|
||||
|
if isinstance(text_config, CLOOBTextConfig): |
||||
|
text_model = CLOOBTextTransformer(text_config) |
||||
|
else: |
||||
|
text_model = AutoModel.from_config(config.text_config, add_pooling_layer=False) |
||||
|
|
||||
|
if isinstance(config.vision_config, CLOOBVisionConfig): |
||||
|
vision_model = CLOOBVisionModel(config.vision_config) |
||||
|
else: |
||||
|
vision_model = AutoModel.from_config(config.vision_config, add_pooling_layer=False) |
||||
|
|
||||
|
self.text_model = text_model |
||||
|
self.vision_model = vision_model |
||||
|
|
||||
|
self.visual_projection = nn.Linear(self.vision_embed_dim, self.projection_dim, bias=False) |
||||
|
self.text_projection = nn.Linear(self.text_embed_dim, self.projection_dim, bias=False) |
||||
|
|
||||
|
self.inv_tau = config.init_inv_tau |
||||
|
self.scale_hopfield = config.scale_hopfield |
||||
|
|
||||
|
# Initialize weights and apply final processing |
||||
|
self.post_init() |
||||
|
|
||||
|
def encode_text(self, input_ids, **kwargs): |
||||
|
return self.get_text_features(input_ids=input_ids, **kwargs) |
||||
|
|
||||
|
def get_text_features( |
||||
|
self, |
||||
|
input_ids: Optional[torch.Tensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.Tensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> torch.FloatTensor: |
||||
|
# Use CLOOB model's config for some fields (if specified) instead of those of vision & text components. |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
text_outputs = self.text_model( |
||||
|
input_ids=input_ids, |
||||
|
attention_mask=attention_mask, |
||||
|
position_ids=position_ids, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
pooled_output = text_outputs.last_hidden_state[:, 0, :] |
||||
|
text_features = self.text_projection(pooled_output) |
||||
|
|
||||
|
return text_features |
||||
|
|
||||
|
def encode_image(self, pixel_values, **kwargs): |
||||
|
return self.get_image_features(pixel_values=pixel_values, **kwargs) |
||||
|
|
||||
|
def get_image_features( |
||||
|
self, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> torch.FloatTensor: |
||||
|
# Use CLOOB model's config for some fields (if specified) instead of those of vision & text components. |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
vision_outputs = self.vision_model( |
||||
|
pixel_values=pixel_values, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
pooled_output = vision_outputs.last_hidden_state[:, 0, :] |
||||
|
image_features = self.visual_projection(pooled_output) |
||||
|
|
||||
|
return image_features |
||||
|
|
||||
|
def forward( |
||||
|
self, |
||||
|
input_ids: Optional[torch.LongTensor] = None, |
||||
|
pixel_values: Optional[torch.FloatTensor] = None, |
||||
|
attention_mask: Optional[torch.Tensor] = None, |
||||
|
position_ids: Optional[torch.LongTensor] = None, |
||||
|
return_loss: Optional[bool] = None, |
||||
|
output_attentions: Optional[bool] = None, |
||||
|
output_hidden_states: Optional[bool] = None, |
||||
|
return_dict: Optional[bool] = None, |
||||
|
) -> Union[Tuple, CLOOBOutput]: |
||||
|
# Use CLOOB model's config for some fields (if specified) instead of those of vision & text components. |
||||
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
||||
|
output_hidden_states = ( |
||||
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
||||
|
) |
||||
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
||||
|
|
||||
|
vision_outputs = self.vision_model( |
||||
|
pixel_values=pixel_values, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
|
||||
|
text_outputs = self.text_model( |
||||
|
input_ids=input_ids, |
||||
|
attention_mask=attention_mask, |
||||
|
position_ids=position_ids, |
||||
|
output_attentions=output_attentions, |
||||
|
output_hidden_states=output_hidden_states, |
||||
|
return_dict=return_dict, |
||||
|
) |
||||
|
image_embeds = vision_outputs.last_hidden_state[:, 0, :] |
||||
|
image_embeds = self.visual_projection(image_embeds) |
||||
|
|
||||
|
text_embeds = text_outputs.last_hidden_state[:, 0, :] |
||||
|
text_embeds = self.text_projection(text_embeds) |
||||
|
|
||||
|
# normalized features |
||||
|
image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True) |
||||
|
text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True) |
||||
|
|
||||
|
loss = None |
||||
|
if return_loss: |
||||
|
loss = cloob_loss(image_embeds, text_embeds, self.inv_tau, self.scale_hopfield) |
||||
|
|
||||
|
if not return_dict: |
||||
|
output = (text_embeds, image_embeds, self.inv_tau, text_outputs, vision_outputs) |
||||
|
return ((loss,) + output) if loss is not None else output |
||||
|
|
||||
|
return CLOOBOutput( |
||||
|
loss=loss, |
||||
|
text_embeds=text_embeds, |
||||
|
image_embeds=image_embeds, |
||||
|
inv_tau=self.inv_tau, |
||||
|
text_model_output=text_outputs, |
||||
|
vision_model_output=vision_outputs, |
||||
|
) |
@ -0,0 +1,63 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from typing import Union, List |
||||
|
import torch |
||||
|
from transformers import T5Tokenizer |
||||
|
|
||||
|
|
||||
|
def load_tokenizer(): |
||||
|
""" |
||||
|
https://huggingface.co/rinna/japanese-roberta-base |
||||
|
""" |
||||
|
tokenizer = T5Tokenizer.from_pretrained("rinna/japanese-roberta-base") |
||||
|
tokenizer.do_lower_case = True # due to some bug of tokenizer config loading |
||||
|
return tokenizer |
||||
|
|
||||
|
|
||||
|
def tokenize( |
||||
|
texts: Union[str, List[str]], |
||||
|
tokenizer: T5Tokenizer = None, |
||||
|
max_seq_len: int = 77, |
||||
|
device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", |
||||
|
): |
||||
|
""" |
||||
|
This is a function that have the original clip's code has. |
||||
|
https://github.com/openai/CLIP/blob/main/clip/clip.py#L195 |
||||
|
""" |
||||
|
if isinstance(texts, str): |
||||
|
texts = [texts] |
||||
|
if tokenizer is None: |
||||
|
tokenizer = load_tokenizer() |
||||
|
inputs = tokenizer( |
||||
|
texts, |
||||
|
max_length=max_seq_len-1, |
||||
|
padding="max_length", |
||||
|
truncation=True, |
||||
|
add_special_tokens=False, |
||||
|
) |
||||
|
# add cls token at first place |
||||
|
input_ids = [[tokenizer.cls_token_id] + ids for ids in inputs['input_ids']] |
||||
|
attention_mask = [[1] + am for am in inputs['attention_mask']] |
||||
|
position_ids = [list(range(0, len(input_ids[0])))] * len(texts) |
||||
|
|
||||
|
input_ids = torch.tensor(input_ids, dtype=torch.long) |
||||
|
attention_mask = torch.tensor(attention_mask, dtype=torch.long) |
||||
|
position_ids = torch.tensor(position_ids, dtype=torch.long) |
||||
|
return { |
||||
|
"input_ids": input_ids.to(device), |
||||
|
"attention_mask": attention_mask.to(device), |
||||
|
"position_ids": position_ids.to(device), |
||||
|
} |
@ -0,0 +1,96 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
from tqdm.auto import tqdm |
||||
|
import numpy as np |
||||
|
import torch |
||||
|
|
||||
|
|
||||
|
def accuracy(output, target, topk=(1,)): |
||||
|
output = torch.from_numpy(np.asarray(output)) |
||||
|
target = torch.from_numpy(np.asarray(target)) |
||||
|
pred = output.topk(max(topk), dim=1, largest=True, sorted=True)[1].t() |
||||
|
correct = pred.eq(target.view(1, -1).expand_as(pred)) |
||||
|
return [ |
||||
|
float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) |
||||
|
for k in topk |
||||
|
] |
||||
|
|
||||
|
|
||||
|
class ImagenetClassificationCallback: |
||||
|
def __init__( |
||||
|
self, |
||||
|
imagenet_classes, |
||||
|
imagenet_templates, |
||||
|
imagenet_dataloader, |
||||
|
): |
||||
|
self.imagenet_classes = imagenet_classes |
||||
|
self.imagenet_templates = imagenet_templates |
||||
|
self.imagenet_dataloader = imagenet_dataloader |
||||
|
|
||||
|
def tokenize(self, tokenizer, examples, device): |
||||
|
encoding_inputs = tokenizer(examples, max_length=76, padding="max_length", truncation=True, add_special_tokens=False) |
||||
|
# add cls token at first place |
||||
|
input_ids = [[tokenizer.cls_token_id] + ids for ids in encoding_inputs['input_ids']] |
||||
|
attention_mask = [[1] + am for am in encoding_inputs['attention_mask']] |
||||
|
position_ids = [list(range(0, len(input_ids[0])))] * len(examples) |
||||
|
|
||||
|
input_ids = torch.tensor(input_ids, dtype=torch.long, device=device) |
||||
|
attention_mask = torch.tensor(attention_mask, dtype=torch.long, device=device) |
||||
|
position_ids = torch.tensor(position_ids, dtype=torch.long, device=device) |
||||
|
return { |
||||
|
"input_ids": input_ids, |
||||
|
"attention_mask": attention_mask, |
||||
|
"position_ids": position_ids, |
||||
|
} |
||||
|
|
||||
|
def zeroshot_classifier(self, model, tokenizer, classnames, templates): |
||||
|
zeroshot_weights = [] |
||||
|
for classname in tqdm(classnames): |
||||
|
texts = [template.format(classname) for template in templates] |
||||
|
class_embeddings = model.get_text_features(**self.tokenize(tokenizer, texts, model.device)).detach().cpu().numpy() |
||||
|
class_embeddings = class_embeddings / np.linalg.norm( |
||||
|
class_embeddings, axis=-1, keepdims=True |
||||
|
) |
||||
|
class_embedding = np.mean(class_embeddings, axis=0) |
||||
|
class_embedding /= np.linalg.norm(class_embedding, axis=-1) |
||||
|
zeroshot_weights.append(class_embedding) |
||||
|
zeroshot_weights = np.stack(zeroshot_weights, axis=1) |
||||
|
return zeroshot_weights |
||||
|
|
||||
|
def zeroshot(self, model, tokenizer) -> dict: |
||||
|
print("Imagenet Zeroshot Classification...") |
||||
|
zeroshot_weights = self.zeroshot_classifier(model, tokenizer, self.imagenet_classes, self.imagenet_templates) |
||||
|
top_ns = [1, 5, 10, 100] |
||||
|
acc_counters = [0.0 for _ in top_ns] |
||||
|
n = 0.0 |
||||
|
|
||||
|
for i, (images, target) in enumerate(tqdm(self.imagenet_dataloader)): |
||||
|
target = target.numpy() |
||||
|
# predict |
||||
|
image_features = model.get_image_features(images.to(model.device)).detach().cpu().numpy() |
||||
|
image_features = image_features / np.linalg.norm(image_features, axis=-1, keepdims=True) |
||||
|
logits = 100.0 * image_features @ zeroshot_weights |
||||
|
|
||||
|
# measure accuracy |
||||
|
accs = accuracy(logits, target, topk=top_ns) |
||||
|
for j in range(len(top_ns)): |
||||
|
acc_counters[j] += accs[j] |
||||
|
n += images.shape[0] |
||||
|
|
||||
|
tops = {f"imagenet/top{top_ns[i]}": acc_counters[i] / n * 100 for i in range(len(top_ns))} |
||||
|
|
||||
|
return tops |
||||
|
|
File diff suppressed because it is too large
@ -0,0 +1,248 @@ |
|||||
|
imagenet_classnames = ["tench", "goldfish", "great white shark", "tiger shark", "hammerhead shark", "electric ray", |
||||
|
"stingray", "rooster", "hen", "ostrich", "brambling", "goldfinch", "house finch", "junco", |
||||
|
"indigo bunting", "American robin", "bulbul", "jay", "magpie", "chickadee", "American dipper", |
||||
|
"kite (bird of prey)", "bald eagle", "vulture", "great grey owl", "fire salamander", |
||||
|
"smooth newt", "newt", "spotted salamander", "axolotl", "American bullfrog", "tree frog", |
||||
|
"tailed frog", "loggerhead sea turtle", "leatherback sea turtle", "mud turtle", "terrapin", |
||||
|
"box turtle", "banded gecko", "green iguana", "Carolina anole", |
||||
|
"desert grassland whiptail lizard", "agama", "frilled-necked lizard", "alligator lizard", |
||||
|
"Gila monster", "European green lizard", "chameleon", "Komodo dragon", "Nile crocodile", |
||||
|
"American alligator", "triceratops", "worm snake", "ring-necked snake", |
||||
|
"eastern hog-nosed snake", "smooth green snake", "kingsnake", "garter snake", "water snake", |
||||
|
"vine snake", "night snake", "boa constrictor", "African rock python", "Indian cobra", |
||||
|
"green mamba", "sea snake", "Saharan horned viper", "eastern diamondback rattlesnake", |
||||
|
"sidewinder rattlesnake", "trilobite", "harvestman", "scorpion", "yellow garden spider", |
||||
|
"barn spider", "European garden spider", "southern black widow", "tarantula", "wolf spider", |
||||
|
"tick", "centipede", "black grouse", "ptarmigan", "ruffed grouse", "prairie grouse", "peafowl", |
||||
|
"quail", "partridge", "african grey parrot", "macaw", "sulphur-crested cockatoo", "lorikeet", |
||||
|
"coucal", "bee eater", "hornbill", "hummingbird", "jacamar", "toucan", "duck", |
||||
|
"red-breasted merganser", "goose", "black swan", "tusker", "echidna", "platypus", "wallaby", |
||||
|
"koala", "wombat", "jellyfish", "sea anemone", "brain coral", "flatworm", "nematode", "conch", |
||||
|
"snail", "slug", "sea slug", "chiton", "chambered nautilus", "Dungeness crab", "rock crab", |
||||
|
"fiddler crab", "red king crab", "American lobster", "spiny lobster", "crayfish", "hermit crab", |
||||
|
"isopod", "white stork", "black stork", "spoonbill", "flamingo", "little blue heron", |
||||
|
"great egret", "bittern bird", "crane bird", "limpkin", "common gallinule", "American coot", |
||||
|
"bustard", "ruddy turnstone", "dunlin", "common redshank", "dowitcher", "oystercatcher", |
||||
|
"pelican", "king penguin", "albatross", "grey whale", "killer whale", "dugong", "sea lion", |
||||
|
"Chihuahua", "Japanese Chin", "Maltese", "Pekingese", "Shih Tzu", "King Charles Spaniel", |
||||
|
"Papillon", "toy terrier", "Rhodesian Ridgeback", "Afghan Hound", "Basset Hound", "Beagle", |
||||
|
"Bloodhound", "Bluetick Coonhound", "Black and Tan Coonhound", "Treeing Walker Coonhound", |
||||
|
"English foxhound", "Redbone Coonhound", "borzoi", "Irish Wolfhound", "Italian Greyhound", |
||||
|
"Whippet", "Ibizan Hound", "Norwegian Elkhound", "Otterhound", "Saluki", "Scottish Deerhound", |
||||
|
"Weimaraner", "Staffordshire Bull Terrier", "American Staffordshire Terrier", |
||||
|
"Bedlington Terrier", "Border Terrier", "Kerry Blue Terrier", "Irish Terrier", |
||||
|
"Norfolk Terrier", "Norwich Terrier", "Yorkshire Terrier", "Wire Fox Terrier", |
||||
|
"Lakeland Terrier", "Sealyham Terrier", "Airedale Terrier", "Cairn Terrier", |
||||
|
"Australian Terrier", "Dandie Dinmont Terrier", "Boston Terrier", "Miniature Schnauzer", |
||||
|
"Giant Schnauzer", "Standard Schnauzer", "Scottish Terrier", "Tibetan Terrier", |
||||
|
"Australian Silky Terrier", "Soft-coated Wheaten Terrier", "West Highland White Terrier", |
||||
|
"Lhasa Apso", "Flat-Coated Retriever", "Curly-coated Retriever", "Golden Retriever", |
||||
|
"Labrador Retriever", "Chesapeake Bay Retriever", "German Shorthaired Pointer", "Vizsla", |
||||
|
"English Setter", "Irish Setter", "Gordon Setter", "Brittany dog", "Clumber Spaniel", |
||||
|
"English Springer Spaniel", "Welsh Springer Spaniel", "Cocker Spaniel", "Sussex Spaniel", |
||||
|
"Irish Water Spaniel", "Kuvasz", "Schipperke", "Groenendael dog", "Malinois", "Briard", |
||||
|
"Australian Kelpie", "Komondor", "Old English Sheepdog", "Shetland Sheepdog", "collie", |
||||
|
"Border Collie", "Bouvier des Flandres dog", "Rottweiler", "German Shepherd Dog", "Dobermann", |
||||
|
"Miniature Pinscher", "Greater Swiss Mountain Dog", "Bernese Mountain Dog", |
||||
|
"Appenzeller Sennenhund", "Entlebucher Sennenhund", "Boxer", "Bullmastiff", "Tibetan Mastiff", |
||||
|
"French Bulldog", "Great Dane", "St. Bernard", "husky", "Alaskan Malamute", "Siberian Husky", |
||||
|
"Dalmatian", "Affenpinscher", "Basenji", "pug", "Leonberger", "Newfoundland dog", |
||||
|
"Great Pyrenees dog", "Samoyed", "Pomeranian", "Chow Chow", "Keeshond", "brussels griffon", |
||||
|
"Pembroke Welsh Corgi", "Cardigan Welsh Corgi", "Toy Poodle", "Miniature Poodle", |
||||
|
"Standard Poodle", "Mexican hairless dog (xoloitzcuintli)", "grey wolf", "Alaskan tundra wolf", |
||||
|
"red wolf or maned wolf", "coyote", "dingo", "dhole", "African wild dog", "hyena", "red fox", |
||||
|
"kit fox", "Arctic fox", "grey fox", "tabby cat", "tiger cat", "Persian cat", "Siamese cat", |
||||
|
"Egyptian Mau", "cougar", "lynx", "leopard", "snow leopard", "jaguar", "lion", "tiger", |
||||
|
"cheetah", "brown bear", "American black bear", "polar bear", "sloth bear", "mongoose", |
||||
|
"meerkat", "tiger beetle", "ladybug", "ground beetle", "longhorn beetle", "leaf beetle", |
||||
|
"dung beetle", "rhinoceros beetle", "weevil", "fly", "bee", "ant", "grasshopper", |
||||
|
"cricket insect", "stick insect", "cockroach", "praying mantis", "cicada", "leafhopper", |
||||
|
"lacewing", "dragonfly", "damselfly", "red admiral butterfly", "ringlet butterfly", |
||||
|
"monarch butterfly", "small white butterfly", "sulphur butterfly", "gossamer-winged butterfly", |
||||
|
"starfish", "sea urchin", "sea cucumber", "cottontail rabbit", "hare", "Angora rabbit", |
||||
|
"hamster", "porcupine", "fox squirrel", "marmot", "beaver", "guinea pig", "common sorrel horse", |
||||
|
"zebra", "pig", "wild boar", "warthog", "hippopotamus", "ox", "water buffalo", "bison", |
||||
|
"ram (adult male sheep)", "bighorn sheep", "Alpine ibex", "hartebeest", "impala (antelope)", |
||||
|
"gazelle", "arabian camel", "llama", "weasel", "mink", "European polecat", |
||||
|
"black-footed ferret", "otter", "skunk", "badger", "armadillo", "three-toed sloth", "orangutan", |
||||
|
"gorilla", "chimpanzee", "gibbon", "siamang", "guenon", "patas monkey", "baboon", "macaque", |
||||
|
"langur", "black-and-white colobus", "proboscis monkey", "marmoset", "white-headed capuchin", |
||||
|
"howler monkey", "titi monkey", "Geoffroy's spider monkey", "common squirrel monkey", |
||||
|
"ring-tailed lemur", "indri", "Asian elephant", "African bush elephant", "red panda", |
||||
|
"giant panda", "snoek fish", "eel", "silver salmon", "rock beauty fish", "clownfish", |
||||
|
"sturgeon", "gar fish", "lionfish", "pufferfish", "abacus", "abaya", "academic gown", |
||||
|
"accordion", "acoustic guitar", "aircraft carrier", "airliner", "airship", "altar", "ambulance", |
||||
|
"amphibious vehicle", "analog clock", "apiary", "apron", "trash can", "assault rifle", |
||||
|
"backpack", "bakery", "balance beam", "balloon", "ballpoint pen", "Band-Aid", "banjo", |
||||
|
"baluster / handrail", "barbell", "barber chair", "barbershop", "barn", "barometer", "barrel", |
||||
|
"wheelbarrow", "baseball", "basketball", "bassinet", "bassoon", "swimming cap", "bath towel", |
||||
|
"bathtub", "station wagon", "lighthouse", "beaker", "military hat (bearskin or shako)", |
||||
|
"beer bottle", "beer glass", "bell tower", "baby bib", "tandem bicycle", "bikini", |
||||
|
"ring binder", "binoculars", "birdhouse", "boathouse", "bobsleigh", "bolo tie", "poke bonnet", |
||||
|
"bookcase", "bookstore", "bottle cap", "hunting bow", "bow tie", "brass memorial plaque", "bra", |
||||
|
"breakwater", "breastplate", "broom", "bucket", "buckle", "bulletproof vest", |
||||
|
"high-speed train", "butcher shop", "taxicab", "cauldron", "candle", "cannon", "canoe", |
||||
|
"can opener", "cardigan", "car mirror", "carousel", "tool kit", "cardboard box / carton", |
||||
|
"car wheel", "automated teller machine", "cassette", "cassette player", "castle", "catamaran", |
||||
|
"CD player", "cello", "mobile phone", "chain", "chain-link fence", "chain mail", "chainsaw", |
||||
|
"storage chest", "chiffonier", "bell or wind chime", "china cabinet", "Christmas stocking", |
||||
|
"church", "movie theater", "cleaver", "cliff dwelling", "cloak", "clogs", "cocktail shaker", |
||||
|
"coffee mug", "coffeemaker", "spiral or coil", "combination lock", "computer keyboard", |
||||
|
"candy store", "container ship", "convertible", "corkscrew", "cornet", "cowboy boot", |
||||
|
"cowboy hat", "cradle", "construction crane", "crash helmet", "crate", "infant bed", |
||||
|
"Crock Pot", "croquet ball", "crutch", "cuirass", "dam", "desk", "desktop computer", |
||||
|
"rotary dial telephone", "diaper", "digital clock", "digital watch", "dining table", |
||||
|
"dishcloth", "dishwasher", "disc brake", "dock", "dog sled", "dome", "doormat", "drilling rig", |
||||
|
"drum", "drumstick", "dumbbell", "Dutch oven", "electric fan", "electric guitar", |
||||
|
"electric locomotive", "entertainment center", "envelope", "espresso machine", "face powder", |
||||
|
"feather boa", "filing cabinet", "fireboat", "fire truck", "fire screen", "flagpole", "flute", |
||||
|
"folding chair", "football helmet", "forklift", "fountain", "fountain pen", "four-poster bed", |
||||
|
"freight car", "French horn", "frying pan", "fur coat", "garbage truck", |
||||
|
"gas mask or respirator", "gas pump", "goblet", "go-kart", "golf ball", "golf cart", "gondola", |
||||
|
"gong", "gown", "grand piano", "greenhouse", "radiator grille", "grocery store", "guillotine", |
||||
|
"hair clip", "hair spray", "half-track", "hammer", "hamper", "hair dryer", "hand-held computer", |
||||
|
"handkerchief", "hard disk drive", "harmonica", "harp", "combine harvester", "hatchet", |
||||
|
"holster", "home theater", "honeycomb", "hook", "hoop skirt", "gymnastic horizontal bar", |
||||
|
"horse-drawn vehicle", "hourglass", "iPod", "clothes iron", "carved pumpkin", "jeans", "jeep", |
||||
|
"T-shirt", "jigsaw puzzle", "rickshaw", "joystick", "kimono", "knee pad", "knot", "lab coat", |
||||
|
"ladle", "lampshade", "laptop computer", "lawn mower", "lens cap", "letter opener", "library", |
||||
|
"lifeboat", "lighter", "limousine", "ocean liner", "lipstick", "slip-on shoe", "lotion", |
||||
|
"music speaker", "loupe magnifying glass", "sawmill", "magnetic compass", "messenger bag", |
||||
|
"mailbox", "tights", "one-piece bathing suit", "manhole cover", "maraca", "marimba", "mask", |
||||
|
"matchstick", "maypole", "maze", "measuring cup", "medicine cabinet", "megalith", "microphone", |
||||
|
"microwave oven", "military uniform", "milk can", "minibus", "miniskirt", "minivan", "missile", |
||||
|
"mitten", "mixing bowl", "mobile home", "ford model t", "modem", "monastery", "monitor", |
||||
|
"moped", "mortar and pestle", "graduation cap", "mosque", "mosquito net", "vespa", |
||||
|
"mountain bike", "tent", "computer mouse", "mousetrap", "moving van", "muzzle", "metal nail", |
||||
|
"neck brace", "necklace", "baby pacifier", "notebook computer", "obelisk", "oboe", "ocarina", |
||||
|
"odometer", "oil filter", "pipe organ", "oscilloscope", "overskirt", "bullock cart", |
||||
|
"oxygen mask", "product packet / packaging", "paddle", "paddle wheel", "padlock", "paintbrush", |
||||
|
"pajamas", "palace", "pan flute", "paper towel", "parachute", "parallel bars", "park bench", |
||||
|
"parking meter", "railroad car", "patio", "payphone", "pedestal", "pencil case", |
||||
|
"pencil sharpener", "perfume", "Petri dish", "photocopier", "plectrum", "Pickelhaube", |
||||
|
"picket fence", "pickup truck", "pier", "piggy bank", "pill bottle", "pillow", "ping-pong ball", |
||||
|
"pinwheel", "pirate ship", "drink pitcher", "block plane", "planetarium", "plastic bag", |
||||
|
"plate rack", "farm plow", "plunger", "Polaroid camera", "pole", "police van", "poncho", |
||||
|
"pool table", "soda bottle", "plant pot", "potter's wheel", "power drill", "prayer rug", |
||||
|
"printer", "prison", "missile", "projector", "hockey puck", "punching bag", "purse", "quill", |
||||
|
"quilt", "race car", "racket", "radiator", "radio", "radio telescope", "rain barrel", |
||||
|
"recreational vehicle", "fishing casting reel", "reflex camera", "refrigerator", |
||||
|
"remote control", "restaurant", "revolver", "rifle", "rocking chair", "rotisserie", "eraser", |
||||
|
"rugby ball", "ruler measuring stick", "sneaker", "safe", "safety pin", "salt shaker", "sandal", |
||||
|
"sarong", "saxophone", "scabbard", "weighing scale", "school bus", "schooner", "scoreboard", |
||||
|
"CRT monitor", "screw", "screwdriver", "seat belt", "sewing machine", "shield", "shoe store", |
||||
|
"shoji screen / room divider", "shopping basket", "shopping cart", "shovel", "shower cap", |
||||
|
"shower curtain", "ski", "balaclava ski mask", "sleeping bag", "slide rule", "sliding door", |
||||
|
"slot machine", "snorkel", "snowmobile", "snowplow", "soap dispenser", "soccer ball", "sock", |
||||
|
"solar thermal collector", "sombrero", "soup bowl", "keyboard space bar", "space heater", |
||||
|
"space shuttle", "spatula", "motorboat", "spider web", "spindle", "sports car", "spotlight", |
||||
|
"stage", "steam locomotive", "through arch bridge", "steel drum", "stethoscope", "scarf", |
||||
|
"stone wall", "stopwatch", "stove", "strainer", "tram", "stretcher", "couch", "stupa", |
||||
|
"submarine", "suit", "sundial", "sunglasses", "sunglasses", "sunscreen", "suspension bridge", |
||||
|
"mop", "sweatshirt", "swim trunks / shorts", "swing", "electrical switch", "syringe", |
||||
|
"table lamp", "tank", "tape player", "teapot", "teddy bear", "television", "tennis ball", |
||||
|
"thatched roof", "front curtain", "thimble", "threshing machine", "throne", "tile roof", |
||||
|
"toaster", "tobacco shop", "toilet seat", "torch", "totem pole", "tow truck", "toy store", |
||||
|
"tractor", "semi-trailer truck", "tray", "trench coat", "tricycle", "trimaran", "tripod", |
||||
|
"triumphal arch", "trolleybus", "trombone", "hot tub", "turnstile", "typewriter keyboard", |
||||
|
"umbrella", "unicycle", "upright piano", "vacuum cleaner", "vase", "vaulted or arched ceiling", |
||||
|
"velvet fabric", "vending machine", "vestment", "viaduct", "violin", "volleyball", |
||||
|
"waffle iron", "wall clock", "wallet", "wardrobe", "military aircraft", "sink", |
||||
|
"washing machine", "water bottle", "water jug", "water tower", "whiskey jug", "whistle", |
||||
|
"hair wig", "window screen", "window shade", "Windsor tie", "wine bottle", "airplane wing", |
||||
|
"wok", "wooden spoon", "wool", "split-rail fence", "shipwreck", "sailboat", "yurt", "website", |
||||
|
"comic book", "crossword", "traffic or street sign", "traffic light", "dust jacket", "menu", |
||||
|
"plate", "guacamole", "consomme", "hot pot", "trifle", "ice cream", "popsicle", "baguette", |
||||
|
"bagel", "pretzel", "cheeseburger", "hot dog", "mashed potatoes", "cabbage", "broccoli", |
||||
|
"cauliflower", "zucchini", "spaghetti squash", "acorn squash", "butternut squash", "cucumber", |
||||
|
"artichoke", "bell pepper", "cardoon", "mushroom", "Granny Smith apple", "strawberry", "orange", |
||||
|
"lemon", "fig", "pineapple", "banana", "jackfruit", "cherimoya (custard apple)", "pomegranate", |
||||
|
"hay", "carbonara", "chocolate syrup", "dough", "meatloaf", "pizza", "pot pie", "burrito", |
||||
|
"red wine", "espresso", "tea cup", "eggnog", "mountain", "bubble", "cliff", "coral reef", |
||||
|
"geyser", "lakeshore", "promontory", "sandbar", "beach", "valley", "volcano", "baseball player", |
||||
|
"bridegroom", "scuba diver", "rapeseed", "daisy", "yellow lady's slipper", "corn", "acorn", |
||||
|
"rose hip", "horse chestnut seed", "coral fungus", "agaric", "gyromitra", "stinkhorn mushroom", |
||||
|
"earth star fungus", "hen of the woods mushroom", "bolete", "corn cob", "toilet paper"] |
||||
|
|
||||
|
imagenet_templates = [ |
||||
|
'a bad photo of a {}.', |
||||
|
'a photo of many {}.', |
||||
|
'a sculpture of a {}.', |
||||
|
'a photo of the hard to see {}.', |
||||
|
'a low resolution photo of the {}.', |
||||
|
'a rendering of a {}.', |
||||
|
'graffiti of a {}.', |
||||
|
'a bad photo of the {}.', |
||||
|
'a cropped photo of the {}.', |
||||
|
'a tattoo of a {}.', |
||||
|
'the embroidered {}.', |
||||
|
'a photo of a hard to see {}.', |
||||
|
'a bright photo of a {}.', |
||||
|
'a photo of a clean {}.', |
||||
|
'a photo of a dirty {}.', |
||||
|
'a dark photo of the {}.', |
||||
|
'a drawing of a {}.', |
||||
|
'a photo of my {}.', |
||||
|
'the plastic {}.', |
||||
|
'a photo of the cool {}.', |
||||
|
'a close-up photo of a {}.', |
||||
|
'a black and white photo of the {}.', |
||||
|
'a painting of the {}.', |
||||
|
'a painting of a {}.', |
||||
|
'a pixelated photo of the {}.', |
||||
|
'a sculpture of the {}.', |
||||
|
'a bright photo of the {}.', |
||||
|
'a cropped photo of a {}.', |
||||
|
'a plastic {}.', |
||||
|
'a photo of the dirty {}.', |
||||
|
'a jpeg corrupted photo of a {}.', |
||||
|
'a blurry photo of the {}.', |
||||
|
'a photo of the {}.', |
||||
|
'a good photo of the {}.', |
||||
|
'a rendering of the {}.', |
||||
|
'a {} in a video game.', |
||||
|
'a photo of one {}.', |
||||
|
'a doodle of a {}.', |
||||
|
'a close-up photo of the {}.', |
||||
|
'a photo of a {}.', |
||||
|
'the origami {}.', |
||||
|
'the {} in a video game.', |
||||
|
'a sketch of a {}.', |
||||
|
'a doodle of the {}.', |
||||
|
'a origami {}.', |
||||
|
'a low resolution photo of a {}.', |
||||
|
'the toy {}.', |
||||
|
'a rendition of the {}.', |
||||
|
'a photo of the clean {}.', |
||||
|
'a photo of a large {}.', |
||||
|
'a rendition of a {}.', |
||||
|
'a photo of a nice {}.', |
||||
|
'a photo of a weird {}.', |
||||
|
'a blurry photo of a {}.', |
||||
|
'a cartoon {}.', |
||||
|
'art of a {}.', |
||||
|
'a sketch of the {}.', |
||||
|
'a embroidered {}.', |
||||
|
'a pixelated photo of a {}.', |
||||
|
'itap of the {}.', |
||||
|
'a jpeg corrupted photo of the {}.', |
||||
|
'a good photo of a {}.', |
||||
|
'a plushie {}.', |
||||
|
'a photo of the nice {}.', |
||||
|
'a photo of the small {}.', |
||||
|
'a photo of the weird {}.', |
||||
|
'the cartoon {}.', |
||||
|
'art of the {}.', |
||||
|
'a drawing of the {}.', |
||||
|
'a photo of the large {}.', |
||||
|
'a black and white photo of a {}.', |
||||
|
'the plushie {}.', |
||||
|
'a dark photo of a {}.', |
||||
|
'itap of a {}.', |
||||
|
'graffiti of the {}.', |
||||
|
'a toy {}.', |
||||
|
'itap of my {}.', |
||||
|
'a photo of a cool {}.', |
||||
|
'a photo of a small {}.', |
||||
|
'a tattoo of the {}.', |
||||
|
] |
@ -0,0 +1,16 @@ |
|||||
|
# coding=utf-8 |
||||
|
# Copyright 2022 rinna Co., Ltd. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
__version__ = '0.2.0' |
@ -0,0 +1,79 @@ |
|||||
|
# Copyright 2021 Zilliz. All rights reserved. |
||||
|
# |
||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); |
||||
|
# you may not use this file except in compliance with the License. |
||||
|
# You may obtain a copy of the License at |
||||
|
# |
||||
|
# http://www.apache.org/licenses/LICENSE-2.0 |
||||
|
# |
||||
|
# Unless required by applicable law or agreed to in writing, software |
||||
|
# distributed under the License is distributed on an "AS IS" BASIS, |
||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
||||
|
# See the License for the specific language governing permissions and |
||||
|
# limitations under the License. |
||||
|
|
||||
|
import sys |
||||
|
from pathlib import Path |
||||
|
|
||||
|
import torch |
||||
|
|
||||
|
from towhee import register |
||||
|
from towhee.operator.base import NNOperator, OperatorFlag |
||||
|
from towhee.types.arg import arg, to_image_color |
||||
|
from towhee.types.image_utils import from_pil, to_pil |
||||
|
|
||||
|
@register(output_schema=['vec']) |
||||
|
class Jaclip(NNOperator): |
||||
|
""" |
||||
|
Japanese CLIP multi-modal embedding operator |
||||
|
""" |
||||
|
def __init__(self, model_name: str, modality: str): |
||||
|
super().__init__() |
||||
|
path = str(Path(__file__).parent) |
||||
|
sys.path.append(path) |
||||
|
import japanese_clip as ja_clip |
||||
|
sys.path.pop() |
||||
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
||||
|
model, preprocess = ja_clip.load("rinna/japanese-clip-vit-b-16", cache_dir="{}/weights/japanese_clip".format(path), device=self.device) |
||||
|
self.model = model |
||||
|
self.tfms = preprocess |
||||
|
self.tokenizer = ja_clip.load_tokenizer() |
||||
|
self.ja_clip = ja_clip |
||||
|
|
||||
|
|
||||
|
def __call__(self, data): |
||||
|
if self._modality == 'image': |
||||
|
vec = self._inference_from_image(data) |
||||
|
elif self._modality == 'text': |
||||
|
vec = self._inference_from_text(data) |
||||
|
else: |
||||
|
raise ValueError("modality[{}] not implemented.".format(self._modality)) |
||||
|
return vec.detach().cpu().numpy().flatten() |
||||
|
|
||||
|
def _inference_from_text(self, text): |
||||
|
encodings = ja_clip.tokenize( |
||||
|
texts=[text], |
||||
|
max_seq_len=77, |
||||
|
device=self.device, |
||||
|
tokenizer=self.tokenizer, # this is optional. if you don't pass, load tokenizer each time |
||||
|
) |
||||
|
text_feature = model.get_text_features(**encodings) |
||||
|
return text_feature |
||||
|
|
||||
|
@arg(1, to_image_color('RGB')) |
||||
|
def _inference_from_image(self, img): |
||||
|
img = self._preprocess(img) |
||||
|
caption = '' |
||||
|
image_feature = self.model.get_image_features(image) |
||||
|
return image_feature |
||||
|
|
||||
|
def _preprocess(self, img): |
||||
|
img = to_pil(img) |
||||
|
processed_img = self.tfms(img).unsqueeze(0).to(self.device) |
||||
|
return processed_img |
||||
|
|
||||
|
def _configs(self): |
||||
|
config = {} |
||||
|
config['blip_base'] = {} |
||||
|
config['blip_base']['weights'] = 'https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth' |
||||
|
return config |
Loading…
Reference in new issue