diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 0000000..9bb3ccd Binary files /dev/null and b/.DS_Store differ diff --git a/README.md b/README.md index c909341..a832f4a 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,7 @@ Create the operator via the following factory method ​ ***model_name:*** *str* -​ The model name of BLIP. Supported model names: +​ The model name of CLIPReward. Supported model names: - clipRN50_clips_grammar
diff --git a/__init__.py b/__init__.py index bfd331d..c27ca48 100644 --- a/__init__.py +++ b/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from clip_caption_reward import ClipCaptionReward +from .clip_caption_reward import ClipCaptionReward def clip_caption_reward(model_name: str): return ClipCaptionReward(model_name) diff --git a/cap.png b/cap.png new file mode 100644 index 0000000..1d641a7 Binary files /dev/null and b/cap.png differ diff --git a/clip/__init__.py b/clip/__init__.py deleted file mode 100644 index dcc5619..0000000 --- a/clip/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .clip import * diff --git a/clip/bpe_simple_vocab_16e6.txt.gz b/clip/bpe_simple_vocab_16e6.txt.gz deleted file mode 100644 index 36a1585..0000000 --- a/clip/bpe_simple_vocab_16e6.txt.gz +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a -size 1356917 diff --git a/clip/clip.py b/clip/clip.py deleted file mode 100644 index 76f241b..0000000 --- a/clip/clip.py +++ /dev/null @@ -1,193 +0,0 @@ -import hashlib -import os -import urllib -import warnings -from typing import Union, List - -import torch -from PIL import Image -from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize -from tqdm import tqdm - -from .model import build_model -from .simple_tokenizer import SimpleTokenizer as _Tokenizer - -__all__ = ["available_models", "load", "tokenize"] -_tokenizer = _Tokenizer() - -_MODELS = { - "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt", - "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt", - "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt", - "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", -} - - -def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): - os.makedirs(root, exist_ok=True) - filename = os.path.basename(url) - - expected_sha256 = url.split("/")[-2] - download_target = os.path.join(root, filename) - - if os.path.exists(download_target) and not os.path.isfile(download_target): - raise RuntimeError(f"{download_target} exists and is not a regular file") - - if os.path.isfile(download_target): - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256: - return download_target - else: - warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file") - - with urllib.request.urlopen(url) as source, open(download_target, "wb") as output: - with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop: - while True: - buffer = source.read(8192) - if not buffer: - break - - output.write(buffer) - loop.update(len(buffer)) - - if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256: - raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match") - - return download_target - - -def _transform(n_px): - return Compose([ - Resize(n_px, interpolation=Image.BICUBIC), - CenterCrop(n_px), - lambda image: image.convert("RGB"), - ToTensor(), - Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), - ]) - - -def available_models() -> List[str]: - """Returns the names of available CLIP models""" - return list(_MODELS.keys()) - - -def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True): - """Load a CLIP model - - Parameters - ---------- - name : str - A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict - - device : Union[str, torch.device] - The device to put the loaded model - - jit : bool - Whether to load the optimized JIT model (default) or more hackable non-JIT model. - - Returns - ------- - model : torch.nn.Module - The CLIP model - - preprocess : Callable[[PIL.Image], torch.Tensor] - A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input - """ - if name in _MODELS: - model_path = _download(_MODELS[name]) - elif os.path.isfile(name): - model_path = name - else: - raise RuntimeError(f"Model {name} not found; available models = {available_models()}") - - try: - # loading JIT archive - model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval() - state_dict = None - except RuntimeError: - # loading saved state dict - if jit: - warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead") - jit = False - state_dict = torch.load(model_path, map_location="cpu") - - if not jit: - model = build_model(state_dict or model.state_dict()).to(device) - if str(device) == "cpu": - model.float() - return model, _transform(model.visual.input_resolution) - - # patch the device names - device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]) - device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1] - - def patch_device(module): - graphs = [module.graph] if hasattr(module, "graph") else [] - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("prim::Constant"): - if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"): - node.copyAttributes(device_node) - - model.apply(patch_device) - patch_device(model.encode_image) - patch_device(model.encode_text) - - # patch dtype to float32 on CPU - if str(device) == "cpu": - float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[]) - float_input = list(float_holder.graph.findNode("aten::to").inputs())[1] - float_node = float_input.node() - - def patch_float(module): - graphs = [module.graph] if hasattr(module, "graph") else [] - if hasattr(module, "forward1"): - graphs.append(module.forward1.graph) - - for graph in graphs: - for node in graph.findAllNodes("aten::to"): - inputs = list(node.inputs()) - for i in [1, 2]: # dtype can be the second or third argument to aten::to() - if inputs[i].node()["value"] == 5: - inputs[i].node().copyAttributes(float_node) - - model.apply(patch_float) - patch_float(model.encode_image) - patch_float(model.encode_text) - - model.float() - - return model, _transform(model.input_resolution.item()) - - -def tokenize(texts: Union[str, List[str]], context_length: int = 77) -> torch.LongTensor: - """ - Returns the tokenized representation of given input string(s) - - Parameters - ---------- - texts : Union[str, List[str]] - An input string or a list of input strings to tokenize - - context_length : int - The context length to use; all CLIP models use 77 as the context length - - Returns - ------- - A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] - """ - if isinstance(texts, str): - texts = [texts] - - sot_token = _tokenizer.encoder["<|startoftext|>"] - eot_token = _tokenizer.encoder["<|endoftext|>"] - all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] - result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) - - for i, tokens in enumerate(all_tokens): - if len(tokens) > context_length: - raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}") - result[i, :len(tokens)] = torch.tensor(tokens) - - return result diff --git a/clip/model.py b/clip/model.py deleted file mode 100644 index 049391e..0000000 --- a/clip/model.py +++ /dev/null @@ -1,437 +0,0 @@ -from collections import OrderedDict -from typing import Tuple, Union - -import torch -import torch.nn.functional as F -from torch import nn - - -class Bottleneck(nn.Module): - expansion = 4 - - def __init__(self, inplanes, planes, stride=1): - super().__init__() - - # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1 - self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False) - self.bn1 = nn.BatchNorm2d(planes) - - self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(planes) - - self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity() - - self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False) - self.bn3 = nn.BatchNorm2d(planes * self.expansion) - - self.relu = nn.ReLU(inplace=True) - self.downsample = None - self.stride = stride - - if stride > 1 or inplanes != planes * Bottleneck.expansion: - # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1 - self.downsample = nn.Sequential(OrderedDict([ - ("-1", nn.AvgPool2d(stride)), - ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)), - ("1", nn.BatchNorm2d(planes * self.expansion)) - ])) - - def forward(self, x: torch.Tensor): - identity = x - - out = self.relu(self.bn1(self.conv1(x))) - out = self.relu(self.bn2(self.conv2(out))) - out = self.avgpool(out) - out = self.bn3(self.conv3(out)) - - if self.downsample is not None: - identity = self.downsample(x) - - out += identity - out = self.relu(out) - return out - - -class AttentionPool2d(nn.Module): - def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None): - super().__init__() - self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5) - self.k_proj = nn.Linear(embed_dim, embed_dim) - self.q_proj = nn.Linear(embed_dim, embed_dim) - self.v_proj = nn.Linear(embed_dim, embed_dim) - self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim) - self.num_heads = num_heads - - def forward(self, x): - x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC - # print(x.shape, self.positional_embedding.shape) - x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC - x = x + self.positional_embedding[0, :, None, :].to(x.dtype) # (HW+1)NC - x, _ = F.multi_head_attention_forward( - query=x, key=x, value=x, - embed_dim_to_check=x.shape[-1], - num_heads=self.num_heads, - q_proj_weight=self.q_proj.weight, - k_proj_weight=self.k_proj.weight, - v_proj_weight=self.v_proj.weight, - in_proj_weight=None, - in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]), - bias_k=None, - bias_v=None, - add_zero_attn=False, - dropout_p=0, - out_proj_weight=torch.ones_like(self.q_proj.weight), - out_proj_bias=torch.zeros_like(self.q_proj.bias), - # out_proj_weight=self.c_proj.weight, - # out_proj_bias=self.c_proj.bias, - use_separate_proj_weight=True, - training=self.training, - need_weights=False - ) - - return x[0] - - -class ModifiedResNet(nn.Module): - """ - A ResNet class that is similar to torchvision's but contains the following changes: - - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool. - - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 - - The final pooling layer is a QKV attention instead of an average pool - """ - - def __init__(self, layers, output_dim, heads, input_resolution=224, width=64): - super().__init__() - self.output_dim = output_dim - self.input_resolution = input_resolution - - # the 3-layer stem - self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False) - self.bn1 = nn.BatchNorm2d(width // 2) - self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False) - self.bn2 = nn.BatchNorm2d(width // 2) - self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) - self.bn3 = nn.BatchNorm2d(width) - self.avgpool = nn.AvgPool2d(2) - self.relu = nn.ReLU(inplace=True) - - # residual layers - self._inplanes = width # this is a *mutable* variable used during construction - self.layer1 = self._make_layer(width, layers[0]) - self.layer2 = self._make_layer(width * 2, layers[1], stride=2) - self.layer3 = self._make_layer(width * 4, layers[2], stride=2) - self.layer4 = self._make_layer(width * 8, layers[3], stride=2) - - embed_dim = width * 32 # the ResNet feature dimension - self.attnpool = AttentionPool2d(input_resolution // 32, embed_dim, heads, output_dim) - - def _make_layer(self, planes, blocks, stride=1): - layers = [Bottleneck(self._inplanes, planes, stride)] - - self._inplanes = planes * Bottleneck.expansion - for _ in range(1, blocks): - layers.append(Bottleneck(self._inplanes, planes)) - - return nn.Sequential(*layers) - - def forward(self, x): - def stem(x): - for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]: - x = self.relu(bn(conv(x))) - x = self.avgpool(x) - return x - - x = x.type(self.conv1.weight.dtype) - x = stem(x) - x = self.layer1(x) - x = self.layer2(x) - x = self.layer3(x) - x = self.layer4(x) - # print(x.shape) - # x = self.attnpool(x) - attnpool = self.attnpool(x) - - return (x, attnpool) - - -class LayerNorm(nn.LayerNorm): - """Subclass torch's LayerNorm to handle fp16.""" - - def forward(self, x: torch.Tensor): - orig_type = x.dtype - ret = super().forward(x.type(torch.float32)) - return ret.type(orig_type) - - -class QuickGELU(nn.Module): - def forward(self, x: torch.Tensor): - return x * torch.sigmoid(1.702 * x) - - -class ResidualAttentionBlock(nn.Module): - def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None): - super().__init__() - - self.attn = nn.MultiheadAttention(d_model, n_head) - self.ln_1 = LayerNorm(d_model) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, d_model * 4)), - ("gelu", QuickGELU()), - ("c_proj", nn.Linear(d_model * 4, d_model)) - ])) - self.ln_2 = LayerNorm(d_model) - self.attn_mask = attn_mask - - def attention(self, x: torch.Tensor): - self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None - return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0] - - def forward(self, x: torch.Tensor): - x = x + self.attention(self.ln_1(x)) - x = x + self.mlp(self.ln_2(x)) - return x - - -class Transformer(nn.Module): - def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None): - super().__init__() - self.width = width - self.layers = layers - self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)]) - - def forward(self, x: torch.Tensor): - return self.resblocks(x) - - -class VisualTransformer(nn.Module): - def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int): - super().__init__() - self.input_resolution = input_resolution - self.output_dim = output_dim - self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) - - scale = width ** -0.5 - self.class_embedding = nn.Parameter(scale * torch.randn(width)) - self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width)) - self.ln_pre = LayerNorm(width) - - self.transformer = Transformer(width, layers, heads) - - self.ln_post = LayerNorm(width) - self.proj = nn.Parameter(scale * torch.randn(width, output_dim)) - - def forward(self, x: torch.Tensor): - x = self.conv1(x) # shape = [*, width, grid, grid] - x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2] - x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width] - x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width] - x = x + self.positional_embedding.to(x.dtype) - x = self.ln_pre(x) - - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - - # x = self.ln_post(x[:, 0, :]) - - x = self.ln_post(x) - # if self.proj is not None: - # x = x @ self.proj - - return x - - -class CLIP(nn.Module): - def __init__(self, - embed_dim: int, - # vision - image_resolution: int, - vision_layers: Union[Tuple[int, int, int, int], int], - vision_width: int, - vision_patch_size: int, - # text - context_length: int, - vocab_size: int, - transformer_width: int, - transformer_heads: int, - transformer_layers: int - ): - super().__init__() - - self.context_length = context_length - - if isinstance(vision_layers, (tuple, list)): - vision_heads = vision_width * 32 // 64 - self.visual = ModifiedResNet( - layers=vision_layers, - output_dim=embed_dim, - heads=vision_heads, - input_resolution=image_resolution, - width=vision_width - ) - else: - vision_heads = vision_width // 64 - self.visual = VisualTransformer( - input_resolution=image_resolution, - patch_size=vision_patch_size, - width=vision_width, - layers=vision_layers, - heads=vision_heads, - output_dim=embed_dim - ) - - self.transformer = Transformer( - width=transformer_width, - layers=transformer_layers, - heads=transformer_heads, - attn_mask=self.build_attention_mask() - ) - - self.vocab_size = vocab_size - self.token_embedding = nn.Embedding(vocab_size, transformer_width) - self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width)) - self.ln_final = LayerNorm(transformer_width) - - self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim)) - self.logit_scale = nn.Parameter(torch.ones([])) - - self.initialize_parameters() - - def initialize_parameters(self): - nn.init.normal_(self.token_embedding.weight, std=0.02) - nn.init.normal_(self.positional_embedding, std=0.01) - - if isinstance(self.visual, ModifiedResNet): - if self.visual.attnpool is not None: - std = self.visual.attnpool.c_proj.in_features ** -0.5 - nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std) - nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std) - - for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]: - for name, param in resnet_block.named_parameters(): - if name.endswith("bn3.weight"): - nn.init.zeros_(param) - - proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5) - attn_std = self.transformer.width ** -0.5 - fc_std = (2 * self.transformer.width) ** -0.5 - for block in self.transformer.resblocks: - nn.init.normal_(block.attn.in_proj_weight, std=attn_std) - nn.init.normal_(block.attn.out_proj.weight, std=proj_std) - nn.init.normal_(block.mlp.c_fc.weight, std=fc_std) - nn.init.normal_(block.mlp.c_proj.weight, std=proj_std) - - if self.text_projection is not None: - nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5) - - def build_attention_mask(self): - # lazily create causal attention mask, with full attention between the vision tokens - # pytorch uses additive attention mask; fill with -inf - mask = torch.empty(self.context_length, self.context_length) - mask.fill_(float("-inf")) - mask.triu_(1) # zero out the lower diagonal - return mask - - @property - def dtype(self): - return self.visual.conv1.weight.dtype - - def encode_image(self, image): - return self.visual(image.type(self.dtype)) - - def encode_text(self, text): - x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model] - - x = x + self.positional_embedding.type(self.dtype) - x = x.permute(1, 0, 2) # NLD -> LND - x = self.transformer(x) - x = x.permute(1, 0, 2) # LND -> NLD - x = self.ln_final(x).type(self.dtype) - - # x.shape = [batch_size, n_ctx, transformer.width] - # take features from the eot embedding (eot_token is the highest number in each sequence) - x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection - - return x - - def forward(self, image, text): - image_features = self.encode_image(image) - text_features = self.encode_text(text) - - # normalized features - image_features = image_features / image_features.norm(dim=-1, keepdim=True) - text_features = text_features / text_features.norm(dim=-1, keepdim=True) - - # cosine similarity as logits - logit_scale = self.logit_scale.exp() - logits_per_image = logit_scale * image_features @ text_features.t() - logits_per_text = logit_scale * text_features @ image_features.t() - - # shape = [global_batch_size, global_batch_size] - return logits_per_image, logits_per_text - - -def convert_weights(model: nn.Module): - """Convert applicable model parameters to fp16""" - - def _convert_weights_to_fp16(l): - if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)): - l.weight.data = l.weight.data.half() - if l.bias is not None: - l.bias.data = l.bias.data.half() - - if isinstance(l, nn.MultiheadAttention): - for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]: - tensor = getattr(l, attr) - if tensor is not None: - tensor.data = tensor.data.half() - - for name in ["text_projection", "proj"]: - if hasattr(l, name): - attr = getattr(l, name) - if attr is not None: - attr.data = attr.data.half() - - model.apply(_convert_weights_to_fp16) - - -def build_model(state_dict: dict): - vit = "visual.proj" in state_dict - - if vit: - vision_width = state_dict["visual.conv1.weight"].shape[0] - vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")]) - vision_patch_size = state_dict["visual.conv1.weight"].shape[-1] - grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5) - image_resolution = vision_patch_size * grid_size - else: - counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]] - vision_layers = tuple(counts) - vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0] - output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5) - vision_patch_size = None - assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0] - image_resolution = output_width * 32 - - embed_dim = state_dict["text_projection"].shape[1] - context_length = state_dict["positional_embedding"].shape[0] - vocab_size = state_dict["token_embedding.weight"].shape[0] - transformer_width = state_dict["ln_final.weight"].shape[0] - transformer_heads = transformer_width // 64 - transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks"))) - - model = CLIP( - embed_dim, - image_resolution, vision_layers, vision_width, vision_patch_size, - context_length, vocab_size, transformer_width, transformer_heads, transformer_layers - ) - - for key in ["input_resolution", "context_length", "vocab_size"]: - if key in state_dict: - del state_dict[key] - - convert_weights(model) - model.load_state_dict(state_dict) - return model.eval() diff --git a/clip/simple_tokenizer.py b/clip/simple_tokenizer.py deleted file mode 100644 index 0a66286..0000000 --- a/clip/simple_tokenizer.py +++ /dev/null @@ -1,132 +0,0 @@ -import gzip -import html -import os -from functools import lru_cache - -import ftfy -import regex as re - - -@lru_cache() -def default_bpe(): - return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz") - - -@lru_cache() -def bytes_to_unicode(): - """ - Returns list of utf-8 byte and a corresponding list of unicode strings. - The reversible bpe codes work on unicode strings. - This means you need a large # of unicode characters in your vocab if you want to avoid UNKs. - When you're at something like a 10B token dataset you end up needing around 5K for decent coverage. - This is a signficant percentage of your normal, say, 32K bpe vocab. - To avoid that, we want lookup tables between utf-8 bytes and unicode strings. - And avoids mapping to whitespace/control characters the bpe code barfs on. - """ - bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1)) - cs = bs[:] - n = 0 - for b in range(2**8): - if b not in bs: - bs.append(b) - cs.append(2**8+n) - n += 1 - cs = [chr(n) for n in cs] - return dict(zip(bs, cs)) - - -def get_pairs(word): - """Return set of symbol pairs in a word. - Word is represented as tuple of symbols (symbols being variable-length strings). - """ - pairs = set() - prev_char = word[0] - for char in word[1:]: - pairs.add((prev_char, char)) - prev_char = char - return pairs - - -def basic_clean(text): - text = ftfy.fix_text(text) - text = html.unescape(html.unescape(text)) - return text.strip() - - -def whitespace_clean(text): - text = re.sub(r'\s+', ' ', text) - text = text.strip() - return text - - -class SimpleTokenizer(object): - def __init__(self, bpe_path: str = default_bpe()): - self.byte_encoder = bytes_to_unicode() - self.byte_decoder = {v: k for k, v in self.byte_encoder.items()} - merges = gzip.open(bpe_path).read().decode("utf-8").split('\n') - merges = merges[1:49152-256-2+1] - merges = [tuple(merge.split()) for merge in merges] - vocab = list(bytes_to_unicode().values()) - vocab = vocab + [v+'' for v in vocab] - for merge in merges: - vocab.append(''.join(merge)) - vocab.extend(['<|startoftext|>', '<|endoftext|>']) - self.encoder = dict(zip(vocab, range(len(vocab)))) - self.decoder = {v: k for k, v in self.encoder.items()} - self.bpe_ranks = dict(zip(merges, range(len(merges)))) - self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'} - self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE) - - def bpe(self, token): - if token in self.cache: - return self.cache[token] - word = tuple(token[:-1]) + ( token[-1] + '',) - pairs = get_pairs(word) - - if not pairs: - return token+'' - - while True: - bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf'))) - if bigram not in self.bpe_ranks: - break - first, second = bigram - new_word = [] - i = 0 - while i < len(word): - try: - j = word.index(first, i) - new_word.extend(word[i:j]) - i = j - except: - new_word.extend(word[i:]) - break - - if word[i] == first and i < len(word)-1 and word[i+1] == second: - new_word.append(first+second) - i += 2 - else: - new_word.append(word[i]) - i += 1 - new_word = tuple(new_word) - word = new_word - if len(word) == 1: - break - else: - pairs = get_pairs(word) - word = ' '.join(word) - self.cache[token] = word - return word - - def encode(self, text): - bpe_tokens = [] - text = whitespace_clean(basic_clean(text)).lower() - for token in re.findall(self.pat, text): - token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8')) - bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' ')) - return bpe_tokens - - def decode(self, tokens): - text = ''.join([self.decoder[token] for token in tokens]) - text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ') - return text diff --git a/clip_caption_reward.py b/clip_caption_reward.py index e528172..aba1a33 100644 --- a/clip_caption_reward.py +++ b/clip_caption_reward.py @@ -11,11 +11,20 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import sys import pathlib - +import json +from pathlib import Path +from PIL import Image +import numpy as np +import torch from torch import nn +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from torchvision.transforms.functional import InterpolationMode from timm.models.vision_transformer import resize_pos_embed from towhee.types.image_utils import to_pil +from towhee.types.arg import arg, to_image_color +from towhee.operator.base import NNOperator, OperatorFlag class ClipCaptionReward(NNOperator): @@ -26,14 +35,22 @@ class ClipCaptionReward(NNOperator): super().__init__() sys.path.append(str(Path(__file__).parent)) from utils import opts - import clip - opt = opts.parse_opt(parse=False, cfg=cfg) + from transformer_model import TransformerModel + from captioning.models.model_utils import decode_sequence + self.decode_sequence = decode_sequence + import mclip + sys.path.pop() + path = pathlib.Path(__file__).parent + + cfg = self._configs()[model_name] + config = str(path) + cfg['config'] + opt = opts.parse_opt(parse=False, cfg=(config)) dict_json = json.load(open("{}/data/cocotalk.json".format(path))) ix_to_word = dict_json["ix_to_word"] self.device = "cuda" if torch.cuda.is_available() else "cpu" - clip_model, clip_transform = clip.load("RN50", jit=False, device=self.device) + clip_model, clip_transform = mclip.load("RN50", jit=False, device=self.device) self.clip_model = clip_model self.clip_transform = clip_transform @@ -59,7 +76,12 @@ class ClipCaptionReward(NNOperator): ) self.clip_model.visual.attnpool.positional_embedding = pos_embed + ckpt_path = str(path) + cfg['weights'] + raw_state_dict = torch.load(ckpt_path, map_location=torch.device('cpu')) + self.model = TransformerModel(opt) + self.model.load_state_dict(raw_state_dict) + self.image_mean = ( torch.Tensor([0.48145466, 0.4578275, 0.40821073]) .to(self.device) @@ -70,17 +92,25 @@ class ClipCaptionReward(NNOperator): .to(self.device) .reshape(3, 1, 1) ) + self._preprocess = Compose( + [ + Resize((448, 448), interpolation= InterpolationMode.BILINEAR), + CenterCrop((448, 448)), + ToTensor(), + ] + ) + self.eval_kwargs = {} + self.eval_kwargs.update(vars(opt)) @arg(1, to_image_color('RGB')) def inference_single_data(self, data): text = self._inference_from_image(data) return text - @arg(1, to_image_color('RGB')) def _inference_from_image(self, img): img = to_pil(img) img = self._preprocess(img) - self._inference_from_image(img) + img = torch.tensor(np.stack([img])).to(self.device) img -= self.image_mean img /= self.image_std tmp_att, tmp_fc = self.clip_model.encode_image(img) @@ -88,9 +118,27 @@ class ClipCaptionReward(NNOperator): att_feat = tmp_att - return att_feat + + with torch.no_grad(): + fc_feats = torch.zeros((1, 0)).to(self.device) + att_feats = att_feat.view(1, 196, 2048).float().to(self.device) + att_masks = None + + # forward the model to also get generated samples for each image + # Only leave one feature for each image, in case duplicate sample + tmp_eval_kwargs = self.eval_kwargs.copy() + tmp_eval_kwargs.update({"sample_n": 1}) + seq, seq_logprobs = self.model( + fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode="sample" + ) + seq = seq.data + + sents = self.decode_sequence(self.model.vocab, seq) + + return sents[0] def __call__(self, data): + results = [] if not isinstance(data, list): data = [data] else: @@ -103,3 +151,10 @@ class ClipCaptionReward(NNOperator): else: return results + def _configs(self): + config = {} + config['clipRN50_clips_grammar'] = {} + config['clipRN50_clips_grammar']['weights'] = '/weights/clipRN50_clips_grammar-last.pth' + config['clipRN50_clips_grammar']['config'] = '/configs/phase2/clipRN50_clips_grammar.yml' + return config + diff --git a/mclip/clip.py b/mclip/clip.py index 76f241b..262fa02 100644 --- a/mclip/clip.py +++ b/mclip/clip.py @@ -7,6 +7,7 @@ from typing import Union, List import torch from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize +from torchvision.transforms.functional import InterpolationMode from tqdm import tqdm from .model import build_model @@ -57,7 +58,7 @@ def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")): def _transform(n_px): return Compose([ - Resize(n_px, interpolation=Image.BICUBIC), + Resize(n_px, interpolation=InterpolationMode.BILINEAR), CenterCrop(n_px), lambda image: image.convert("RGB"), ToTensor(), diff --git a/tabular.png b/tabular.png new file mode 100644 index 0000000..81a55ee Binary files /dev/null and b/tabular.png differ