diff --git a/README.md b/README.md
index 2f5be52..4cc011e 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,85 @@
-# camel
+# Image Captioning with CaMEL
+
+*author: David Wang*
+
+
+
+
+
+## Description
+
+This operator generates the caption with [CapDec](https://arxiv.org/abs/2211.00575) which describes the content of the given image. ExpansionNet v2 introduces the Block Static Expansion which distributes and processes the input over a heterogeneous and arbitrarily big collection of sequences characterized by a different length compared to the input one. This is an adaptation from [DavidHuji/CapDec](https://github.com/DavidHuji/CapDec).
+
+
+
+
+
+## Code Example
+
+Load an image from path './image.jpg' to generate the caption.
+
+ *Write the pipeline in simplified style*:
+
+```python
+import towhee
+
+towhee.glob('./image.jpg') \
+ .image_decode() \
+ .image_captioning.capdec(model_name='capdec_noise_0') \
+ .show()
+```
+
+
+*Write a same pipeline with explicit inputs/outputs name specifications:*
+
+```python
+import towhee
+
+towhee.glob['path']('./image.jpg') \
+ .image_decode['path', 'img']() \
+ .image_captioning.capdec['img', 'text'](model_name='capdec_noise_0') \
+ .select['img', 'text']() \
+ .show()
+```
+
+
+
+
+
+
+## Factory Constructor
+
+Create the operator via the following factory method
+
+***capdec(model_name)***
+
+**Parameters:**
+
+ ***model_name:*** *str*
+
+ The model name of CapDec. Supported model names:
+- capdec_noise_0
+- capdec_noise_01
+- capdec_noise_001
+- capdec_noise_0001
+
+
+
+## Interface
+
+An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generate the correspoing caption.
+
+
+**Parameters:**
+
+ ***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
+
+ The image to generate caption.
+
+
+
+**Returns:** *str*
+
+ The caption generated by model.
+
diff --git a/__init__.py b/__init__.py
new file mode 100644
index 0000000..ef1f1fd
--- /dev/null
+++ b/__init__.py
@@ -0,0 +1,18 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .camel import Camel
+
+def camel(model_name: str):
+ return Camel(model_name)
diff --git a/camel.py b/camel.py
new file mode 100644
index 0000000..f3683f9
--- /dev/null
+++ b/camel.py
@@ -0,0 +1,115 @@
+# Copyright 2021 Zilliz. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import sys
+import os
+from pathlib import Path
+from easydict import EasyDict as edict
+
+import torch
+from torchvision import transforms
+from transformers import GPT2Tokenizer
+
+from towhee.types.arg import arg, to_image_color
+from towhee.types.image_utils import to_pil
+from towhee.operator.base import NNOperator, OperatorFlag
+from towhee import register
+from towhee.models import clip
+
+class Camel(NNOperator):
+ """
+ Camel image captioning operator
+ """
+ def _gen_args(self):
+ args = edict()
+ args.image_dim =
+ args.N_enc = 3
+ args.d_model = 512
+ args.d_ff = 2048
+ args.head = 8
+ args.m = 40
+ args.disable_mesh = True
+ args.d_model = 512
+ args.with_pe = True
+ return args
+
+ def __init__(self, model_name: str):
+ super().__init__()
+ sys.path.append(str(Path(__file__).parent))
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
+ from models import Captioner
+ from data import ImageField, TextField
+
+ # Pipeline for text
+ self.text_field = TextField()
+ args = self._gen_args()
+
+ self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True)
+ self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4')
+ self.image_model = self.clip_model.visual
+ self.image_model.forward = self.image_model.intermediate_features
+ image_field = ImageField(transform=self.clip_tfms)
+ args.image_dim = self.mage_model.embed_dim
+
+
+ # Create the model
+ self.model = Captioner(args, self.text_field).to(self.device)
+ self.model.forward = self.model.beam_search
+ self.image_model = self.image_model.to(self.device)
+
+ self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
+ self.model = self.model.eval()
+ sys.path.pop()
+
+
+ @arg(1, to_image_color('RGB'))
+ def inference_single_data(self, data):
+ text = self._inference_from_image(data)
+ return text
+
+ def _preprocess(self, img):
+ img = to_pil(img)
+ processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
+ return processed_img
+
+ def __call__(self, data):
+ if not isinstance(data, list):
+ data = [data]
+ else:
+ data = data
+ results = []
+ for single_data in data:
+ result = self.inference_single_data(single_data)
+ results.append(result)
+ if len(data) == 1:
+ return results[0]
+ else:
+ return results
+
+ @arg(1, to_image_color('RGB'))
+ def _inference_from_image(self, img):
+ img = self._preprocess(img)
+ text, _ = self.model.beam_search(img, beam_size=5, out_size=1)
+ return text
+
+ def _configs(self):
+ config = {}
+ config['clipcap_coco'] = {}
+ config['clipcap_coco']['weights'] = 'coco_weights.pt'
+ config['clipcap_conceptual'] = {}
+ config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
+ return config
+
+if __name__ == '__main__':
+ pass
diff --git a/data/.DS_Store b/data/.DS_Store
new file mode 100644
index 0000000..841752f
Binary files /dev/null and b/data/.DS_Store differ
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000..c5191c3
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,10 @@
+from .field import RawField, Merge, ImageField, TextField
+from .dataset import *
+from torch.utils.data import DataLoader as TorchDataLoader
+
+from .dataset import *
+
+
+class DataLoader(TorchDataLoader):
+ def __init__(self, dataset, *args, **kwargs):
+ super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)
diff --git a/data/dataset.py b/data/dataset.py
new file mode 100644
index 0000000..740c492
--- /dev/null
+++ b/data/dataset.py
@@ -0,0 +1,303 @@
+import collections
+import itertools
+import os
+
+import numpy as np
+import torch
+from pycocotools.coco import COCO as pyCOCO
+from torch.utils.data import Dataset as PthDataset
+
+from utils import nostdout
+from .example import Example
+
+
+class Dataset(PthDataset):
+ def __init__(self, examples, fields):
+ self.examples = examples
+ self.fields = dict(fields)
+
+ def collate_fn(self):
+ def collate(batch):
+ if len(self.fields) == 1:
+ batch = [batch, ]
+ else:
+ batch = list(zip(*batch))
+
+ tensors = []
+ for field, data in zip(self.fields.values(), batch):
+ tensor = field.process(data)
+ if isinstance(tensor, collections.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
+ tensors.extend(tensor)
+ else:
+ tensors.append(tensor)
+
+ if len(tensors) > 1:
+ return tensors
+ else:
+ return tensors[0]
+
+ return collate
+
+ def __getitem__(self, i):
+ example = self.examples[i]
+ data = []
+ for field_name, field in self.fields.items():
+ data.append(field.preprocess(getattr(example, field_name, None)))
+
+ if len(data) == 1:
+ data = data[0]
+ return data
+
+ def __len__(self):
+ return len(self.examples)
+
+
+class ValueDataset(Dataset):
+ def __init__(self, examples, fields, dictionary):
+ self.dictionary = dictionary
+ super(ValueDataset, self).__init__(examples, fields)
+
+ def collate_fn(self):
+ def collate(batch):
+ value_batch_flattened = list(itertools.chain(*batch))
+ value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)
+
+ lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
+ if isinstance(value_tensors_flattened, collections.Sequence) \
+ and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
+ value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
+ for vt in value_tensors_flattened]
+ else:
+ value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
+
+ return value_tensors
+ return collate
+
+ def __getitem__(self, i):
+ if i not in self.dictionary:
+ raise IndexError
+
+ values_data = []
+ for idx in self.dictionary[i]:
+ value_data = super(ValueDataset, self).__getitem__(idx)
+ values_data.append(value_data)
+ return values_data
+
+ def __len__(self):
+ return len(self.dictionary)
+
+
+class DictionaryDataset(Dataset):
+ def __init__(self, examples, fields, key_fields):
+ if not isinstance(key_fields, (tuple, list)):
+ key_fields = (key_fields,)
+ for field in key_fields:
+ assert (field in fields)
+
+ dictionary = collections.defaultdict(list)
+ key_fields = {k: fields[k] for k in key_fields}
+ value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields}
+ key_examples = []
+ key_dict = dict()
+ value_examples = []
+
+ for i, e in enumerate(examples):
+ key_example = Example.fromdict({k: getattr(e, k) for k in key_fields})
+ value_example = Example.fromdict({v: getattr(e, v) for v in value_fields})
+ if key_example not in key_dict:
+ key_dict[key_example] = len(key_examples)
+ key_examples.append(key_example)
+
+ value_examples.append(value_example)
+ dictionary[key_dict[key_example]].append(i)
+
+ self.key_dataset = Dataset(key_examples, key_fields)
+ self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)
+ super(DictionaryDataset, self).__init__(examples, fields)
+
+ def collate_fn(self):
+ def collate(batch):
+ key_batch, value_batch = list(zip(*batch))
+ key_tensors = self.key_dataset.collate_fn()(key_batch)
+ value_tensors = self.value_dataset.collate_fn()(value_batch)
+ return key_tensors, value_tensors
+ return collate
+
+ def __getitem__(self, i):
+ return self.key_dataset[i], self.value_dataset[i]
+
+ def __len__(self):
+ return len(self.key_dataset)
+
+
+def unique(sequence):
+ seen = set()
+ if isinstance(sequence[0], list):
+ return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))]
+ else:
+ return [x for x in sequence if not (x in seen or seen.add(x))]
+
+
+# class ImageDataset(Dataset):
+# def __init__(self, examples, image_field):
+# super().__init__(examples, {'image': image_field})
+#
+#
+# class COCO2017Unlabeled(ImageDataset):
+# def __init__(self, image_field, img_root, load_in_tmp=False):
+# tmp_path = os.path.join('/tmp/coco2017unlabeled')
+# if load_in_tmp and sync_path(img_root, tmp_path, size=19*1024**3):
+# img_root = tmp_path
+#
+# examples = [Example.fromdict({'image': os.path.join(img_root, f)}) for f in os.listdir(img_root)]
+# super().__init__(examples, image_field)
+#
+#
+# class Nocaps(ImageDataset):
+# def __init__(self, image_field, img_root, load_in_tmp=False):
+# tmp_path = os.path.join('/tmp/nocaps')
+# if load_in_tmp and sync_path(img_root, tmp_path, size=3.6*1024**3):
+# img_root = tmp_path
+#
+# examples = []
+# for split in ('validation', 'test'):
+# examples += [Example.fromdict({'image': os.path.join(img_root, split, f)}) for f in os.listdir(img_root + '/' + split)]
+# super().__init__(examples, image_field)
+
+
+class PairedDataset(Dataset):
+ def __init__(self, examples, image_field, text_field):
+ super(PairedDataset, self).__init__(examples, {'image': image_field, 'text': text_field})
+ self.image_field = self.fields['image']
+ self.text_field = self.fields['text']
+
+ def image_set(self):
+ img_list = [e.image for e in self.examples]
+ image_set = unique(img_list)
+ examples = [Example.fromdict({'image': i}) for i in image_set]
+ dataset = Dataset(examples, {'image': self.image_field})
+ return dataset
+
+ def text_set(self):
+ text_list = [e.text for e in self.examples]
+ text_list = unique(text_list)
+ examples = [Example.fromdict({'text': t}) for t in text_list]
+ dataset = Dataset(examples, {'text': self.text_field})
+ return dataset
+
+ def image_dictionary(self, fields=None):
+ if not fields:
+ fields = self.fields
+ dataset = DictionaryDataset(self.examples, fields, key_fields='image')
+ return dataset
+
+ def text_dictionary(self, fields=None):
+ if not fields:
+ fields = self.fields
+ dataset = DictionaryDataset(self.examples, fields, key_fields='text')
+ return dataset
+
+ @property
+ def splits(self):
+ raise NotImplementedError
+
+
+class COCO(PairedDataset):
+ def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
+ cut_validation=False):
+
+ roots = {}
+ roots['train'] = {
+ 'img': os.path.join(img_root, 'train2014'),
+ 'cap': os.path.join(ann_root, 'captions_train2014.json')
+ }
+ roots['val'] = {
+ 'img': os.path.join(img_root, 'val2014'),
+ 'cap': os.path.join(ann_root, 'captions_val2014.json')
+ }
+ roots['test'] = {
+ 'img': os.path.join(img_root, 'val2014'),
+ 'cap': os.path.join(ann_root, 'captions_val2014.json')
+ }
+ roots['trainrestval'] = {
+ 'img': (roots['train']['img'], roots['val']['img']),
+ 'cap': (roots['train']['cap'], roots['val']['cap'])
+ }
+
+ if id_root is not None:
+ ids = {}
+ ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy'))
+ ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy'))
+ if cut_validation:
+ ids['val'] = ids['val'][:5000]
+ ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy'))
+ ids['trainrestval'] = (
+ ids['train'],
+ np.load(os.path.join(id_root, 'coco_restval_ids.npy')))
+
+ if use_restval:
+ roots['train'] = roots['trainrestval']
+ ids['train'] = ids['trainrestval']
+ else:
+ ids = None
+
+ with nostdout():
+ self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids)
+ examples = self.train_examples + self.val_examples + self.test_examples
+ super(COCO, self).__init__(examples, image_field, text_field)
+
+ @property
+ def splits(self):
+ train_split = PairedDataset(self.train_examples, self.image_field, self.text_field)
+ val_split = PairedDataset(self.val_examples, self.image_field, self.text_field)
+ test_split = PairedDataset(self.test_examples, self.image_field, self.text_field)
+ return train_split, val_split, test_split
+
+ @classmethod
+ def get_samples(cls, roots, ids_dataset=None):
+ train_samples = []
+ val_samples = []
+ test_samples = []
+
+ for split in ['train', 'val', 'test']:
+ if isinstance(roots[split]['cap'], tuple):
+ coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1]))
+ root = roots[split]['img']
+ else:
+ coco_dataset = (pyCOCO(roots[split]['cap']),)
+ root = (roots[split]['img'],)
+
+ if ids_dataset is None:
+ ids = list(coco_dataset.anns.keys())
+ else:
+ ids = ids_dataset[split]
+
+ if isinstance(ids, tuple):
+ bp = len(ids[0])
+ ids = list(ids[0]) + list(ids[1])
+ else:
+ bp = len(ids)
+
+ for index in range(len(ids)):
+ if index < bp:
+ coco = coco_dataset[0]
+ img_root = root[0]
+ else:
+ coco = coco_dataset[1]
+ img_root = root[1]
+
+ ann_id = ids[index]
+ caption = coco.anns[ann_id]['caption']
+ img_id = coco.anns[ann_id]['image_id']
+ filename = coco.loadImgs(img_id)[0]['file_name']
+
+ example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption})
+
+ if split == 'train':
+ train_samples.append(example)
+ elif split == 'val':
+ val_samples.append(example)
+ elif split == 'test':
+ test_samples.append(example)
+
+ return train_samples, val_samples, test_samples
diff --git a/data/example.py b/data/example.py
new file mode 100644
index 0000000..61d1772
--- /dev/null
+++ b/data/example.py
@@ -0,0 +1,26 @@
+class Example(object):
+ """Defines a single training or test example.
+ Stores each column of the example as an attribute.
+ """
+ @classmethod
+ def fromdict(cls, data):
+ ex = cls(data)
+ return ex
+
+ def __init__(self, data):
+ for key, val in data.items():
+ super(Example, self).__setattr__(key, val)
+
+ def __setattr__(self, key, value):
+ raise AttributeError
+
+ def __hash__(self):
+ return hash(tuple(x for x in self.__dict__.values()))
+
+ def __eq__(self, other):
+ this = tuple(x for x in self.__dict__.values())
+ other = tuple(x for x in other.__dict__.values())
+ return this == other
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
diff --git a/data/field.py b/data/field.py
new file mode 100644
index 0000000..4b8ee5c
--- /dev/null
+++ b/data/field.py
@@ -0,0 +1,127 @@
+# coding: utf8
+from itertools import takewhile
+
+import torch
+from torch.utils.data.dataloader import default_collate
+from torchvision.datasets.folder import default_loader
+
+from .tokenizer.simple_tokenizer import SimpleTokenizer as _Tokenizer
+
+
+class RawField(object):
+ """ Defines a general datatype.
+
+ Every dataset consists of one or more types of data. For instance,
+ a machine translation dataset contains paired examples of text, while
+ an image captioning dataset contains images and texts.
+ Each of these types of data is represented by a RawField object.
+ An RawField object does not assume any property of the data type and
+ it holds parameters relating to how a datatype should be processed.
+
+ Attributes:
+ preprocessing: The Pipeline that will be applied to examples
+ using this field before creating an example.
+ Default: None.
+ postprocessing: A Pipeline that will be applied to a list of examples
+ using this field before assigning to a batch.
+ Function signature: (batch(list)) -> object
+ Default: None.
+ """
+
+ def __init__(self, preprocessing=None, postprocessing=None):
+ self.preprocessing = preprocessing
+ self.postprocessing = postprocessing
+
+ def preprocess(self, x):
+ """ Preprocess an example if the `preprocessing` Pipeline is provided. """
+ if self.preprocessing is not None:
+ return self.preprocessing(x)
+ else:
+ return x
+
+ def process(self, batch, *args, **kwargs):
+ """ Process a list of examples to create a batch.
+
+ Postprocess the batch with user-provided Pipeline.
+
+ Args:
+ batch (list(object)): A list of object from a batch of examples.
+ Returns:
+ object: Processed object given the input and custom
+ postprocessing Pipeline.
+ """
+ if self.postprocessing is not None:
+ batch = self.postprocessing(batch)
+ return default_collate(batch)
+
+
+class Merge(RawField):
+ def __init__(self, *fields):
+ super(Merge, self).__init__()
+ self.fields = fields
+
+ def preprocess(self, x):
+ return tuple(f.preprocess(x) for f in self.fields)
+
+ def process(self, batch, *args, **kwargs):
+ if len(self.fields) == 1:
+ batch = [batch, ]
+ else:
+ batch = list(zip(*batch))
+
+ out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch))
+ return out
+
+
+class ImageField(RawField):
+ def __init__(self, preprocessing=None, postprocessing=None, loader=default_loader, transform=None):
+ self.loader = loader
+ self.transform = transform
+ super().__init__(preprocessing, postprocessing)
+
+ def preprocess(self, x):
+ sample = self.loader(x)
+ if self.transform is not None:
+ sample = self.transform(sample)
+ return sample
+
+
+class TextField(RawField):
+ def __init__(self):
+ self._tokenizer = _Tokenizer()
+ super(TextField, self).__init__()
+
+ def preprocess(self, x):
+ if x is None:
+ return ''
+ return x
+
+ def process(self, texts):
+ if isinstance(texts, str):
+ texts = [texts]
+
+ sot_token = self._tokenizer.bos_idx
+ eot_token = self._tokenizer.eos_idx
+ all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
+ result = torch.zeros(len(all_tokens), max(len(s) for s in all_tokens), dtype=torch.long)
+
+ for i, tokens in enumerate(all_tokens):
+ result[i, :len(tokens)] = torch.tensor(tokens)
+
+ return result
+
+ def decode(self, word_idxs):
+ if isinstance(word_idxs, list) and len(word_idxs) == 0:
+ return self.decode([word_idxs, ])[0]
+ if isinstance(word_idxs, list) and isinstance(word_idxs[0], int):
+ return self.decode([word_idxs, ])[0]
+ elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1:
+ return self.decode(word_idxs.unsqueeze(0))[0]
+
+ captions = []
+ for wis in word_idxs:
+ wis = wis.tolist()
+ wis = list(takewhile(lambda tok: tok != self._tokenizer.eos_idx, wis))
+ caption = self._tokenizer.decode(wis)
+ captions.append(caption)
+ return captions
diff --git a/data/tokenizer/__init__.py b/data/tokenizer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/data/tokenizer/bpe_simple_vocab_16e6.txt.gz b/data/tokenizer/bpe_simple_vocab_16e6.txt.gz
new file mode 100644
index 0000000..36a1585
--- /dev/null
+++ b/data/tokenizer/bpe_simple_vocab_16e6.txt.gz
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:924691ac288e54409236115652ad4aa250f48203de50a9e4722a6ecd48d6804a
+size 1356917
diff --git a/data/tokenizer/simple_tokenizer.py b/data/tokenizer/simple_tokenizer.py
new file mode 100644
index 0000000..1c58a78
--- /dev/null
+++ b/data/tokenizer/simple_tokenizer.py
@@ -0,0 +1,144 @@
+import gzip
+import html
+import os
+from functools import lru_cache
+
+import ftfy
+import regex as re
+
+
+@lru_cache()
+def default_bpe():
+ return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
+
+
+@lru_cache()
+def bytes_to_unicode():
+ """
+ Returns list of utf-8 byte and a corresponding list of unicode strings.
+ The reversible bpe codes work on unicode strings.
+ This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
+ When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
+ This is a signficant percentage of your normal, say, 32K bpe vocab.
+ To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
+ And avoids mapping to whitespace/control characters the bpe code barfs on.
+ """
+ bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
+ cs = bs[:]
+ n = 0
+ for b in range(2**8):
+ if b not in bs:
+ bs.append(b)
+ cs.append(2**8+n)
+ n += 1
+ cs = [chr(n) for n in cs]
+ return dict(zip(bs, cs))
+
+
+def get_pairs(word):
+ """Return set of symbol pairs in a word.
+ Word is represented as tuple of symbols (symbols being variable-length strings).
+ """
+ pairs = set()
+ prev_char = word[0]
+ for char in word[1:]:
+ pairs.add((prev_char, char))
+ prev_char = char
+ return pairs
+
+
+def basic_clean(text):
+ text = ftfy.fix_text(text)
+ text = html.unescape(html.unescape(text))
+ return text.strip()
+
+
+def whitespace_clean(text):
+ text = re.sub(r'\s+', ' ', text)
+ text = text.strip()
+ return text
+
+
+class SimpleTokenizer(object):
+ def __init__(self, bpe_path: str = default_bpe()):
+ self.byte_encoder = bytes_to_unicode()
+ self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
+ merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
+ merges = merges[1:49152-256-2+1]
+ merges = [tuple(merge.split()) for merge in merges]
+ vocab = list(bytes_to_unicode().values())
+ vocab = vocab + [v+'' for v in vocab]
+ for merge in merges:
+ vocab.append(''.join(merge))
+ vocab.extend(['<|startoftext|>', '<|endoftext|>'])
+ self.encoder = dict(zip(vocab, range(len(vocab))))
+ self.decoder = {v: k for k, v in self.encoder.items()}
+ self.bpe_ranks = dict(zip(merges, range(len(merges))))
+ self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
+ self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
+
+ @property
+ def vocab_size(self):
+ return len(self.encoder)
+
+ @property
+ def eos_idx(self):
+ return self.encoder['<|endoftext|>']
+
+ @property
+ def bos_idx(self):
+ return self.encoder['<|startoftext|>']
+
+ def bpe(self, token):
+ if token in self.cache:
+ return self.cache[token]
+ word = tuple(token[:-1]) + ( token[-1] + '',)
+ pairs = get_pairs(word)
+
+ if not pairs:
+ return token+''
+
+ while True:
+ bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
+ if bigram not in self.bpe_ranks:
+ break
+ first, second = bigram
+ new_word = []
+ i = 0
+ while i < len(word):
+ try:
+ j = word.index(first, i)
+ new_word.extend(word[i:j])
+ i = j
+ except:
+ new_word.extend(word[i:])
+ break
+
+ if word[i] == first and i < len(word)-1 and word[i+1] == second:
+ new_word.append(first+second)
+ i += 2
+ else:
+ new_word.append(word[i])
+ i += 1
+ new_word = tuple(new_word)
+ word = new_word
+ if len(word) == 1:
+ break
+ else:
+ pairs = get_pairs(word)
+ word = ' '.join(word)
+ self.cache[token] = word
+ return word
+
+ def encode(self, text):
+ bpe_tokens = []
+ text = whitespace_clean(basic_clean(text)).lower()
+ for token in re.findall(self.pat, text):
+ token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
+ bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
+ return bpe_tokens
+
+ def decode(self, tokens):
+ text = ''.join([self.decoder[token] for token in tokens])
+ text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('', ' ')
+ return text
\ No newline at end of file
diff --git a/models/.DS_Store b/models/.DS_Store
new file mode 100644
index 0000000..eb79451
Binary files /dev/null and b/models/.DS_Store differ
diff --git a/models/__init__.py b/models/__init__.py
new file mode 100644
index 0000000..1585d31
--- /dev/null
+++ b/models/__init__.py
@@ -0,0 +1 @@
+from .transformer.captioner import Captioner
diff --git a/models/beam_search/__init__.py b/models/beam_search/__init__.py
new file mode 100644
index 0000000..21ac612
--- /dev/null
+++ b/models/beam_search/__init__.py
@@ -0,0 +1 @@
+from .beam_search import BeamSearch
diff --git a/models/beam_search/beam_search.py b/models/beam_search/beam_search.py
new file mode 100644
index 0000000..8dfa752
--- /dev/null
+++ b/models/beam_search/beam_search.py
@@ -0,0 +1,149 @@
+import torch
+import utils
+
+
+class BeamSearch(object):
+ def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
+ self.model = model
+ self.max_len = max_len
+ self.eos_idx = eos_idx
+ self.beam_size = beam_size
+ self.b_s = None
+ self.device = None
+ self.seq_mask = None
+ self.seq_logprob = None
+ self.outputs = None
+ self.log_probs = None
+ self.selected_words = None
+ self.all_logits = None
+
+ def _expand_state(self, selected_beam, cur_beam_size):
+ def fn(s):
+ shape = [int(sh) for sh in s.shape]
+ beam = selected_beam
+ for _ in shape[1:]:
+ beam = beam.unsqueeze(-1)
+ s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
+ beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
+ s = s.view(*([-1, ] + shape[1:]))
+ return s
+
+ return fn
+
+ def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
+ if isinstance(visual, torch.Tensor):
+ visual_shape = visual.shape
+ visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
+ visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
+ selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
+ selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
+ visual_exp = visual.view(visual_exp_shape)
+ selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
+ visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
+ else:
+ new_visual = []
+ for im in visual:
+ visual_shape = im.shape
+ visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
+ visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
+ selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
+ selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
+ visual_exp = im.view(visual_exp_shape)
+ selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
+ new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
+ new_visual.append(new_im)
+ visual = tuple(new_visual)
+ return visual
+
+ def apply(self, visual: utils.TensorOrSequence, out_size=1, return_logits=False, **kwargs):
+ self.b_s = utils.get_batch_size(visual)
+ self.device = utils.get_device(visual)
+ self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
+ self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
+ self.log_probs = []
+ self.selected_words = None
+ if return_logits:
+ self.all_logits = []
+
+ outputs = []
+ with self.model.statefulness(self.b_s):
+ for t in range(self.max_len):
+ visual, outputs = self.iter(t, visual, outputs, return_logits, **kwargs)
+
+ # Sort result
+ seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
+ outputs = torch.cat(outputs, -1)
+ outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
+ log_probs = torch.cat(self.log_probs, -1)
+ log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
+ outputs = outputs.contiguous()[:, :out_size]
+ log_probs = log_probs.contiguous()[:, :out_size]
+
+ if return_logits:
+ all_logits = torch.cat(self.all_logits, 2)
+ all_logits = torch.gather(all_logits, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
+ self.max_len,
+ all_logits.shape[-1]))
+ all_logits = all_logits.contiguous()[:, :out_size]
+
+ if out_size == 1:
+ outputs = outputs.squeeze(1)
+ log_probs = log_probs.squeeze(1)
+ if return_logits:
+ all_logits = all_logits.squeeze(1)
+
+ if return_logits:
+ return outputs, log_probs, all_logits
+ else:
+ return outputs, log_probs
+
+ def select(self, t, candidate_logprob, **kwargs):
+ selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
+ selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
+ return selected_idx, selected_logprob
+
+ def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_logits, **kwargs):
+ cur_beam_size = 1 if t == 0 else self.beam_size
+
+ word_logits = self.model.step(t, self.selected_words, visual, **kwargs)
+ word_logits = word_logits.view(self.b_s, cur_beam_size, -1)
+ word_logprob = torch.log_softmax(word_logits, dim=-1)
+ candidate_logprob = self.seq_logprob + word_logprob
+
+ # Mask sequence if it reaches EOS
+ if t > 0:
+ mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).type(visual.dtype).unsqueeze(-1)
+ self.seq_mask = self.seq_mask * mask
+ word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
+ old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
+ old_seq_logprob[:, :, 1:] = -999
+ candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
+
+ selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
+ selected_beam = torch.floor_divide(selected_idx, candidate_logprob.shape[-1])
+ selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
+
+ self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
+ visual = self._expand_visual(visual, cur_beam_size, selected_beam)
+
+ self.seq_logprob = selected_logprob.unsqueeze(-1)
+ self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
+ outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
+ outputs.append(selected_words.unsqueeze(-1))
+
+ if return_logits:
+ if t == 0:
+ self.all_logits.append(word_logits.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
+ else:
+ self.all_logits.append(word_logits.unsqueeze(2))
+
+ this_word_logprob = torch.gather(word_logprob, 1,
+ selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
+ word_logprob.shape[-1]))
+ this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
+ self.log_probs = list(
+ torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
+ self.log_probs.append(this_word_logprob)
+ self.selected_words = selected_words.view(-1, 1)
+
+ return visual, outputs
diff --git a/models/clip.py b/models/clip.py
new file mode 100644
index 0000000..12ac275
--- /dev/null
+++ b/models/clip.py
@@ -0,0 +1,619 @@
+import hashlib
+import math
+import os
+import urllib
+import warnings
+from collections import OrderedDict
+from typing import Tuple
+from typing import Union, List
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
+
+try:
+ from torchvision.transforms import InterpolationMode
+ BICUBIC = InterpolationMode.BICUBIC
+except ImportError:
+ BICUBIC = Image.BICUBIC
+
+from torch import nn
+from tqdm import tqdm
+from models.utils import one_hot_to_index
+
+_MODELS = {
+ "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
+ "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
+ "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
+ "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
+ "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
+ "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
+}
+
+
+def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
+ os.makedirs(root, exist_ok=True)
+ filename = os.path.basename(url)
+
+ expected_sha256 = url.split("/")[-2]
+ download_target = os.path.join(root, filename)
+
+ if os.path.exists(download_target) and not os.path.isfile(download_target):
+ raise RuntimeError(f"{download_target} exists and is not a regular file")
+
+ if os.path.isfile(download_target):
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
+ return download_target
+ else:
+ warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
+
+ with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
+ with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
+ while True:
+ buffer = source.read(8192)
+ if not buffer:
+ break
+
+ output.write(buffer)
+ loop.update(len(buffer))
+
+ if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
+ raise RuntimeError(f"Model has been downloaded but the SHA256 checksum does not not match")
+
+ return download_target
+
+
+def _transform(n_px):
+ return Compose([
+ Resize(n_px, interpolation=BICUBIC),
+ CenterCrop(n_px),
+ lambda image: image.convert("RGB"),
+ ToTensor(),
+ Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
+ ])
+
+
+def available_models() -> List[str]:
+ """Returns the names of available CLIP models"""
+ return list(_MODELS.keys())
+
+
+def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
+ """Load a CLIP model
+
+ Parameters
+ ----------
+ name : str
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
+
+ device : Union[str, torch.device]
+ The device to put the loaded model
+
+ jit : bool
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
+
+ Returns
+ -------
+ model : torch.nn.Module
+ The CLIP model
+
+ preprocess : Callable[[PIL.Image], torch.Tensor]
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
+ """
+ if name in _MODELS:
+ model_path = _download(_MODELS[name])
+ elif os.path.isfile(name):
+ model_path = name
+ else:
+ raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
+
+ try:
+ # loading JIT archive
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
+ state_dict = None
+ except RuntimeError:
+ # loading saved state dict
+ if jit:
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
+ jit = False
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if not jit:
+ model = build_model(state_dict or model.state_dict()).to(device)
+ if str(device) == "cpu":
+ model.float()
+ return model, _transform(model.visual.input_resolution)
+
+ # patch the device names
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
+
+ def patch_device(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("prim::Constant"):
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
+ node.copyAttributes(device_node)
+
+ model.apply(patch_device)
+ patch_device(model.encode_image)
+ patch_device(model.encode_text)
+
+ # patch dtype to float32 on CPU
+ if str(device) == "cpu":
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
+ float_node = float_input.node()
+
+ def patch_float(module):
+ graphs = [module.graph] if hasattr(module, "graph") else []
+ if hasattr(module, "forward1"):
+ graphs.append(module.forward1.graph)
+
+ for graph in graphs:
+ for node in graph.findAllNodes("aten::to"):
+ inputs = list(node.inputs())
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
+ if inputs[i].node()["value"] == 5:
+ inputs[i].node().copyAttributes(float_node)
+
+ model.apply(patch_float)
+ patch_float(model.encode_image)
+ patch_float(model.encode_text)
+
+ model.float()
+
+ return model, _transform(model.input_resolution.item())
+
+
+class LayerNorm(nn.LayerNorm):
+ """Subclass torch's LayerNorm to handle fp16."""
+
+ def forward(self, x: torch.Tensor):
+ orig_type = x.dtype
+ ret = super().forward(x.type(torch.float32))
+ return ret.type(orig_type)
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self, inplanes, planes, stride=1):
+ super().__init__()
+
+ # all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
+ self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
+ self.bn1 = nn.BatchNorm2d(planes)
+
+ self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(planes)
+
+ self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
+
+ self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = None
+ self.stride = stride
+
+ if stride > 1 or inplanes != planes * Bottleneck.expansion:
+ # downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
+ self.downsample = nn.Sequential(OrderedDict([
+ ("-1", nn.AvgPool2d(stride)),
+ ("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
+ ("1", nn.BatchNorm2d(planes * self.expansion))
+ ]))
+
+ def forward(self, x: torch.Tensor):
+ identity = x
+
+ out = self.relu(self.bn1(self.conv1(x)))
+ out = self.relu(self.bn2(self.conv2(out)))
+ out = self.avgpool(out)
+ out = self.bn3(self.conv3(out))
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+ out = self.relu(out)
+ return out
+
+
+class AttentionPool2d(nn.Module):
+ def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
+ super().__init__()
+ self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
+ self.k_proj = nn.Linear(embed_dim, embed_dim)
+ self.q_proj = nn.Linear(embed_dim, embed_dim)
+ self.v_proj = nn.Linear(embed_dim, embed_dim)
+ self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
+ self.num_heads = num_heads
+
+ def forward(self, x):
+ x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
+ x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
+ x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
+ x, _ = F.multi_head_attention_forward(
+ query=x, key=x, value=x,
+ embed_dim_to_check=x.shape[-1],
+ num_heads=self.num_heads,
+ q_proj_weight=self.q_proj.weight,
+ k_proj_weight=self.k_proj.weight,
+ v_proj_weight=self.v_proj.weight,
+ in_proj_weight=None,
+ in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
+ bias_k=None,
+ bias_v=None,
+ add_zero_attn=False,
+ dropout_p=0,
+ out_proj_weight=self.c_proj.weight,
+ out_proj_bias=self.c_proj.bias,
+ use_separate_proj_weight=True,
+ training=self.training,
+ need_weights=False
+ )
+
+ return x[0]
+
+
+class ModifiedResNet(nn.Module):
+ """
+ A ResNet class that is similar to torchvision's but contains the following changes:
+ - There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
+ - The final pooling layer is a QKV attention instead of an average pool
+ """
+
+ def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
+ super().__init__()
+ self.output_dim = output_dim
+ self.input_resolution = input_resolution
+
+ # the 3-layer stem
+ self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
+ self.bn1 = nn.BatchNorm2d(width // 2)
+ self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
+ self.bn2 = nn.BatchNorm2d(width // 2)
+ self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
+ self.bn3 = nn.BatchNorm2d(width)
+ self.avgpool = nn.AvgPool2d(2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # residual layers
+ self._inplanes = width # this is a *mutable* variable used during construction
+ self.layer1 = self._make_layer(width, layers[0])
+ self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
+ self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
+ self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
+
+ self.embed_dim = width * 32 # the ResNet feature dimension
+ self.attnpool = AttentionPool2d(input_resolution // 32, self.embed_dim, heads, output_dim)
+
+ def _make_layer(self, planes, blocks, stride=1):
+ layers = [Bottleneck(self._inplanes, planes, stride)]
+
+ self._inplanes = planes * Bottleneck.expansion
+ for _ in range(1, blocks):
+ layers.append(Bottleneck(self._inplanes, planes))
+
+ return nn.Sequential(*layers)
+
+ def intermediate_features(self, x):
+ def stem(x):
+ for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
+ x = self.relu(bn(conv(x)))
+ x = self.avgpool(x)
+ return x
+
+ x = x.type(self.conv1.weight.dtype)
+ x = stem(x)
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+
+ b, c = x.shape[:2]
+ return x.view(b, c, -1).permute(0, 2, 1)
+
+ def forward(self, x):
+ x = self.intermediate_features(x)
+ x = x.permute(0, 2, 1)
+ l = int(math.sqrt(x.shape[-1]))
+ x = x.view(x.shape[0], x.shape[1], l, l)
+ x = self.attnpool(x)
+ return x
+
+
+class QuickGELU(nn.Module):
+ def forward(self, x: torch.Tensor):
+ return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+ def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+
+ self.attn = nn.MultiheadAttention(d_model, n_head)
+ self.ln_1 = LayerNorm(d_model)
+ self.mlp = nn.Sequential(OrderedDict([
+ ("c_fc", nn.Linear(d_model, d_model * 4)),
+ ("gelu", QuickGELU()),
+ ("c_proj", nn.Linear(d_model * 4, d_model))
+ ]))
+ self.ln_2 = LayerNorm(d_model)
+ self.attn_mask = attn_mask
+
+ def attention(self, x: torch.Tensor):
+ self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+ return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+ def forward(self, x: torch.Tensor):
+ x = x + self.attention(self.ln_1(x))
+ x = x + self.mlp(self.ln_2(x))
+ return x
+
+
+class Transformer(nn.Module):
+ def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
+ super().__init__()
+ self.width = width
+ self.layers = layers
+ self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+ def forward(self, x: torch.Tensor):
+ return self.resblocks(x)
+
+
+class VisualTransformer(nn.Module):
+ def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.embed_dim = width
+ self.output_dim = output_dim
+ self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
+
+ scale = width ** -0.5
+ self.class_embedding = nn.Parameter(scale * torch.randn(width))
+ self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
+ self.ln_pre = LayerNorm(width)
+
+ self.transformer = Transformer(width, layers, heads)
+
+ self.ln_post = LayerNorm(width)
+ self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
+
+ def intermediate_features(self, x):
+ x = self.conv1(x) # shape = [*, width, grid, grid]
+ x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
+ x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
+ x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
+ x = x + self.positional_embedding.to(x.dtype)
+ x = self.ln_pre(x)
+
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_post(x)
+ return x
+
+ def forward(self, x: torch.Tensor):
+ x = self.intermediate_features(x)
+ x_cls = x[:, 0, :]
+
+ if self.proj is not None:
+ x_cls = x_cls @ self.proj
+
+ return x_cls
+
+
+class CLIP(nn.Module):
+ def __init__(self,
+ embed_dim: int,
+ # vision
+ image_resolution: int,
+ vision_layers: Union[Tuple[int, int, int, int], int],
+ vision_width: int,
+ vision_patch_size: int,
+ # text
+ context_length: int,
+ vocab_size: int,
+ transformer_width: int,
+ transformer_heads: int,
+ transformer_layers: int
+ ):
+ super().__init__()
+
+ self.context_length = context_length
+
+ if isinstance(vision_layers, (tuple, list)):
+ vision_heads = vision_width * 32 // 64
+ self.visual = ModifiedResNet(
+ layers=vision_layers,
+ output_dim=embed_dim,
+ heads=vision_heads,
+ input_resolution=image_resolution,
+ width=vision_width
+ )
+ else:
+ vision_heads = vision_width // 64
+ self.visual = VisualTransformer(
+ input_resolution=image_resolution,
+ patch_size=vision_patch_size,
+ width=vision_width,
+ layers=vision_layers,
+ heads=vision_heads,
+ output_dim=embed_dim
+ )
+
+ self.transformer = Transformer(
+ width=transformer_width,
+ layers=transformer_layers,
+ heads=transformer_heads,
+ attn_mask=self.build_attention_mask()
+ )
+
+ self.vocab_size = vocab_size
+ self.transformer_width = transformer_width
+ self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+ self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+ self.ln_final = LayerNorm(transformer_width)
+
+ self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
+ self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # todo remove
+
+ self.initialize_parameters()
+
+ def initialize_parameters(self):
+ nn.init.normal_(self.token_embedding.weight, std=0.02)
+ nn.init.normal_(self.positional_embedding, std=0.01)
+
+ if isinstance(self.visual, ModifiedResNet):
+ if self.visual.attnpool is not None:
+ std = self.visual.attnpool.c_proj.in_features ** -0.5
+ nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
+ nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
+
+ for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
+ for name, param in resnet_block.named_parameters():
+ if name.endswith("bn3.weight"):
+ nn.init.zeros_(param)
+
+ proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
+ attn_std = self.transformer.width ** -0.5
+ fc_std = (2 * self.transformer.width) ** -0.5
+ for block in self.transformer.resblocks:
+ nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
+ nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
+ nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
+ nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
+
+ if self.text_projection is not None:
+ nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
+
+ def build_attention_mask(self):
+ # lazily create causal attention mask, with full attention between the vision tokens
+ # pytorch uses additive attention mask; fill with -inf
+ mask = torch.empty(self.context_length, self.context_length)
+ mask.fill_(float("-inf"))
+ mask.triu_(1) # zero out the lower diagonal
+ return mask
+
+ @property
+ def dtype(self):
+ return self.visual.conv1.weight.dtype
+
+ @property
+ def device(self):
+ return self.visual.conv1.weight.device
+
+ def encode_image(self, image):
+ return self.visual(image.type(self.dtype))
+
+ def encode_text(self, text):
+ if text.dtype in [torch.long, torch.int]:
+ text_idxs = text
+ x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
+ else:
+ text_idxs = one_hot_to_index(text)
+ x = (text @ self.token_embedding.weight).type(self.dtype)
+
+ x = x + self.positional_embedding.type(self.dtype)
+ x = x.permute(1, 0, 2) # NLD -> LND
+ x = self.transformer(x)
+ x = x.permute(1, 0, 2) # LND -> NLD
+ x = self.ln_final(x).type(self.dtype)
+
+ # x.shape = [batch_size, n_ctx, transformer.width]
+ # take features from the eot embedding (eot_token is the highest number in each sequence)
+ x = x[torch.arange(x.shape[0]), text_idxs.argmax(dim=-1)] @ self.text_projection
+
+ return x
+
+ def forward(self, image, text):
+ image_features = self.encode_image(image)
+ text_features = self.encode_text(text)
+
+ # normalized features
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+ # cosine similarity as logits
+ logit_scale = self.logit_scale.exp()
+ logits_per_image = logit_scale * image_features @ text_features.t()
+ logits_per_text = logit_scale * text_features @ image_features.t()
+
+ # shape = [global_batch_size, global_batch_size]
+ return logits_per_image#, logits_per_text
+
+
+def convert_weights(model: nn.Module):
+ """Convert applicable model parameters to fp16"""
+
+ def _convert_weights_to_fp16(l):
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+ if isinstance(l, nn.MultiheadAttention):
+ for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
+ tensor = getattr(l, attr)
+ if tensor is not None:
+ tensor.data = tensor.data.half()
+
+ for name in ["text_projection", "proj"]:
+ if hasattr(l, name):
+ attr = getattr(l, name)
+ if attr is not None:
+ attr.data = attr.data.half()
+
+ model.apply(_convert_weights_to_fp16)
+
+
+def build_model(state_dict: dict):
+ vit = "visual.proj" in state_dict
+
+ if vit:
+ vision_width = state_dict["visual.conv1.weight"].shape[0]
+ vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
+ vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
+ grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
+ image_resolution = vision_patch_size * grid_size
+ else:
+ counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
+ vision_layers = tuple(counts)
+ vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
+ output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
+ vision_patch_size = None
+ assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
+ image_resolution = output_width * 32
+
+ embed_dim = state_dict["text_projection"].shape[1]
+ context_length = state_dict["positional_embedding"].shape[0]
+ vocab_size = state_dict["token_embedding.weight"].shape[0]
+ transformer_width = state_dict["ln_final.weight"].shape[0]
+ transformer_heads = transformer_width // 64
+ transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
+
+ model = CLIP(
+ embed_dim,
+ image_resolution, vision_layers, vision_width, vision_patch_size,
+ context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
+ )
+
+ for key in ["input_resolution", "context_length", "vocab_size"]:
+ if key in state_dict:
+ del state_dict[key]
+
+ # convert_weights(model) todo remove
+ model.load_state_dict(state_dict, strict=False)
+ return model.eval()
diff --git a/models/containers.py b/models/containers.py
new file mode 100644
index 0000000..c5bf573
--- /dev/null
+++ b/models/containers.py
@@ -0,0 +1,81 @@
+from contextlib import contextmanager
+from torch import nn
+from utils.typing import *
+
+
+class Module(nn.Module):
+ def __init__(self):
+ super(Module, self).__init__()
+ self._is_stateful = False
+ self._state_names = []
+ self._state_defaults = dict()
+
+ def register_state(self, name: str, default: TensorOrNone):
+ self._state_names.append(name)
+ if default is None:
+ self._state_defaults[name] = None
+ else:
+ self._state_defaults[name] = default.clone().detach()
+ self.register_buffer(name, default)
+
+ def states(self):
+ for name in self._state_names:
+ yield self._buffers[name]
+ for m in self.children():
+ if isinstance(m, Module):
+ yield from m.states()
+
+ def apply_to_states(self, fn):
+ for name in self._state_names:
+ if self._buffers[name] is not None:
+ self._buffers[name] = fn(self._buffers[name])
+ for m in self.children():
+ if isinstance(m, Module):
+ m.apply_to_states(fn)
+
+ def _init_states(self, batch_size: int):
+ for name in self._state_names:
+ if self._state_defaults[name] is None:
+ self._buffers[name] = None
+ else:
+ self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
+ self._buffers[name] = self._buffers[name].unsqueeze(0)
+ self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
+ self._buffers[name] = self._buffers[name].contiguous()
+
+ def _reset_states(self):
+ for name in self._state_names:
+ if self._state_defaults[name] is None:
+ self._buffers[name] = None
+ else:
+ self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
+
+ def enable_statefulness(self, batch_size: int):
+ for m in self.children():
+ if isinstance(m, Module):
+ m.enable_statefulness(batch_size)
+ self._init_states(batch_size)
+ self._is_stateful = True
+
+ def disable_statefulness(self):
+ for m in self.children():
+ if isinstance(m, Module):
+ m.disable_statefulness()
+ self._reset_states()
+ self._is_stateful = False
+
+ @contextmanager
+ def statefulness(self, batch_size: int):
+ self.enable_statefulness(batch_size)
+ try:
+ yield
+ finally:
+ self.disable_statefulness()
+
+
+class ModuleList(nn.ModuleList, Module):
+ pass
+
+
+class ModuleDict(nn.ModuleDict, Module):
+ pass
diff --git a/models/transformer/__init__.py b/models/transformer/__init__.py
new file mode 100644
index 0000000..e9e33b2
--- /dev/null
+++ b/models/transformer/__init__.py
@@ -0,0 +1,4 @@
+from .attention import *
+from .encoders import *
+from .decoders import *
+from .captioner import *
diff --git a/models/transformer/attention.py b/models/transformer/attention.py
new file mode 100644
index 0000000..039483f
--- /dev/null
+++ b/models/transformer/attention.py
@@ -0,0 +1,196 @@
+import numpy as np
+import torch
+from torch import nn
+from models.containers import Module
+
+
+class ScaledDotProductAttention(nn.Module):
+ """
+ Scaled dot-product attention
+ """
+
+ def __init__(self, d_model, d_k, d_v, h):
+ '''
+ :param d_model: Output dimensionality of the model
+ :param d_k: Dimensionality of queries and keys
+ :param d_v: Dimensionality of values
+ :param h: Number of heads
+ '''
+ super(ScaledDotProductAttention, self).__init__()
+ self.fc_q = nn.Linear(d_model, h * d_k)
+ self.fc_k = nn.Linear(d_model, h * d_k)
+ self.fc_v = nn.Linear(d_model, h * d_v)
+ self.fc_o = nn.Linear(h * d_v, d_model)
+
+ self.d_model = d_model
+ self.d_k = d_k
+ self.d_v = d_v
+ self.h = h
+
+ self.init_weights()
+
+ def init_weights(self):
+ nn.init.xavier_uniform_(self.fc_q.weight)
+ nn.init.xavier_uniform_(self.fc_k.weight)
+ nn.init.xavier_uniform_(self.fc_v.weight)
+ nn.init.xavier_uniform_(self.fc_o.weight)
+ nn.init.constant_(self.fc_q.bias, 0)
+ nn.init.constant_(self.fc_k.bias, 0)
+ nn.init.constant_(self.fc_v.bias, 0)
+ nn.init.constant_(self.fc_o.bias, 0)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ """
+ Computes
+ :param queries: Queries (b_s, nq, d_model)
+ :param keys: Keys (b_s, nk, d_model)
+ :param values: Values (b_s, nk, d_model)
+ :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
+ :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
+ :return:
+ """
+ b_s, nq = queries.shape[:2]
+ nk = keys.shape[1]
+ q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
+ k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
+ v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
+
+ att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
+ if attention_weights is not None:
+ att = att * attention_weights
+ if attention_mask is not None:
+ att = att.masked_fill(attention_mask, -np.inf)
+ att = torch.softmax(att, -1)
+ out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
+ out = self.fc_o(out) # (b_s, nq, d_model)
+ return out
+
+
+class ScaledDotProductAttentionMemory(nn.Module):
+ """
+ Scaled dot-product attention with memory
+ """
+
+ def __init__(self, d_model, d_k, d_v, h, m):
+ """
+ :param d_model: Output dimensionality of the model
+ :param d_k: Dimensionality of queries and keys
+ :param d_v: Dimensionality of values
+ :param h: Number of heads
+ :param m: Number of memory slots
+ """
+ super(ScaledDotProductAttentionMemory, self).__init__()
+ self.fc_q = nn.Linear(d_model, h * d_k)
+ self.fc_k = nn.Linear(d_model, h * d_k)
+ self.fc_v = nn.Linear(d_model, h * d_v)
+ self.fc_o = nn.Linear(h * d_v, d_model)
+ self.d_model = d_model
+ self.d_k = d_k
+ self.d_v = d_v
+ self.h = h
+ self.m = m
+
+ if self.m > 0:
+ self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k))
+ self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v))
+
+ self.init_weights()
+
+ def init_weights(self):
+ nn.init.xavier_uniform_(self.fc_q.weight)
+ nn.init.xavier_uniform_(self.fc_k.weight)
+ nn.init.xavier_uniform_(self.fc_v.weight)
+ nn.init.xavier_uniform_(self.fc_o.weight)
+ nn.init.constant_(self.fc_q.bias, 0)
+ nn.init.constant_(self.fc_k.bias, 0)
+ nn.init.constant_(self.fc_v.bias, 0)
+ nn.init.constant_(self.fc_o.bias, 0)
+
+ if self.m > 0:
+ nn.init.normal_(self.m_k, 0, 1 / self.d_k)
+ nn.init.normal_(self.m_v, 0, 1 / self.m)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ """
+ Computes
+ :param queries: Queries (b_s, nq, d_model)
+ :param keys: Keys (b_s, nk, d_model)
+ :param values: Values (b_s, nk, d_model)
+ :param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
+ :param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
+ :return:
+ """
+ b_s, nq = queries.shape[:2]
+ nk = keys.shape[1]
+
+ q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
+
+ if self.m > 0:
+ m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
+ m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)
+ k = torch.cat([self.fc_k(keys), m_k], 1)
+ v = torch.cat([self.fc_v(values), m_v], 1)
+ else:
+ k = self.fc_k(keys)
+ v = self.fc_v(values)
+
+ k = k.view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
+ v = v.view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
+
+ att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
+ if attention_weights is not None:
+ att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
+ if attention_mask is not None:
+ att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
+ att = torch.softmax(att, -1)
+ out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
+ out = self.fc_o(out) # (b_s, nq, d_model)
+ return out
+
+
+class MultiHeadAttention(Module):
+ """
+ Multi-head attention layer with Dropout and Layer Normalization.
+ """
+
+ def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
+ attention_module=None, attention_module_kwargs=None):
+ super(MultiHeadAttention, self).__init__()
+ self.identity_map_reordering = identity_map_reordering
+ if attention_module is not None:
+ if attention_module_kwargs is not None:
+ self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs)
+ else:
+ self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
+ else:
+ self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
+ self.dropout = nn.Dropout(p=dropout)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ self.can_be_stateful = can_be_stateful
+ if self.can_be_stateful:
+ self.register_state('running_keys', None)
+ self.register_state('running_values', None)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ if self.can_be_stateful and self._is_stateful:
+ if self.running_keys is None:
+ self.running_keys = keys
+ self.running_values = values
+ else:
+ self.running_keys = torch.cat([self.running_keys, keys], 1)
+ self.running_values = torch.cat([self.running_values, values], 1)
+ keys = self.running_keys
+ values = self.running_values
+
+ if self.identity_map_reordering:
+ q_norm = self.layer_norm(queries)
+ k_norm = self.layer_norm(keys)
+ v_norm = self.layer_norm(values)
+ out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
+ out = queries + self.dropout(torch.relu(out))
+ else:
+ out = self.attention(queries, keys, values, attention_mask, attention_weights)
+ out = self.dropout(out)
+ out = self.layer_norm(queries + out)
+ return out
diff --git a/models/transformer/captioner.py b/models/transformer/captioner.py
new file mode 100644
index 0000000..df861e7
--- /dev/null
+++ b/models/transformer/captioner.py
@@ -0,0 +1,93 @@
+import copy
+from pathlib import Path
+
+import torch
+from torch import Tensor
+from torch import nn
+
+from data.field import TextField
+from models.beam_search import *
+from models.containers import ModuleList, Module
+from utils import TensorOrSequence
+from . import Encoder, Decoder, ScaledDotProductAttentionMemory, MeshedDecoder
+
+
+class Captioner(Module):
+ def __init__(self, args, text_field: TextField):
+ super(Captioner, self).__init__()
+
+ self.encoder = Encoder(args.N_enc, 500, args.image_dim, d_model=args.d_model, d_ff=args.d_ff, h=args.head,
+ attention_module=ScaledDotProductAttentionMemory,
+ attention_module_kwargs={'m': args.m},
+ with_pe=args.with_pe, with_mesh=not args.disable_mesh)
+ if args.disable_mesh:
+ self.decoder = Decoder(text_field._tokenizer.vocab_size, 40, args.N_dec, d_model=args.d_model,
+ d_ff=args.d_ff, h=args.head)
+ else:
+ self.decoder = MeshedDecoder(text_field._tokenizer.vocab_size, 40, args.N_dec, args.N_enc,
+ d_model=args.d_model, d_ff=args.d_ff, h=args.head)
+ self.bos_idx = text_field._tokenizer.bos_idx
+ self.eos_idx = text_field._tokenizer.eos_idx
+ self.vocab_size = text_field._tokenizer.vocab_size
+ self.max_generation_length = self.decoder.max_len
+
+ self.register_state('enc_output', None)
+ self.register_state('mask_enc', None)
+ self.init_weights()
+
+ @property
+ def d_model(self):
+ return self.decoder.d_model
+
+ def train(self, mode: bool = True):
+ self.encoder.train(mode)
+ self.decoder.train(mode)
+
+ def init_weights(self):
+ for p in self.encoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ for p in self.decoder.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+
+ def forward(self, images, seq):
+ enc_output, mask_enc = self.encoder(images)
+ dec_output = self.decoder(seq, enc_output, mask_enc)
+ return dec_output
+
+ def step(self, t: int, prev_output: Tensor, visual: Tensor) -> Tensor:
+ if t == 0:
+ self.enc_output, self.mask_enc = self.encoder(visual)
+ input = visual.data.new_full((visual.shape[0], 1), self.bos_idx, dtype=torch.long)
+ else:
+ input = prev_output
+ logits = self.decoder(input, self.enc_output, self.mask_enc)
+ return logits
+
+ def beam_search(self, visual: TensorOrSequence, beam_size: int, out_size=1,
+ return_logits=False, **kwargs):
+ bs = BeamSearch(self, self.max_generation_length, self.eos_idx, beam_size)
+ return bs.apply(visual, out_size, return_logits, **kwargs)
+
+
+class CaptionerEnsemble(Captioner):
+ def __init__(self, model: Captioner, args, text_field, weight_files, weight_folder=None):
+ super(CaptionerEnsemble, self).__init__(args, text_field)
+ self.n = len(weight_files)
+ self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
+ for model_i, weight_file_i in zip(self.models, weight_files):
+ if Path(weight_file_i).is_absolute():
+ fname = Path(weight_file_i)
+ else:
+ fname = Path(weight_folder).joinpath(weight_file_i)
+ state_dict_i = torch.load(fname)['state_dict_t']
+ model_i.load_state_dict(state_dict_i)
+
+ def step(self, t, prev_output, visual):
+ out_ensemble = []
+ for model_i in self.models:
+ out_i = model_i.step(t, prev_output, visual)
+ out_ensemble.append(out_i.unsqueeze(0))
+
+ return torch.mean(torch.cat(out_ensemble, 0), dim=0)
diff --git a/models/transformer/decoders.py b/models/transformer/decoders.py
new file mode 100644
index 0000000..6e60999
--- /dev/null
+++ b/models/transformer/decoders.py
@@ -0,0 +1,199 @@
+import torch
+from torch import nn
+import numpy as np
+
+from models.transformer.attention import MultiHeadAttention
+from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
+from models.containers import Module, ModuleList
+from models.utils import one_hot_to_index
+
+
+class MeshedDecoderLayer(Module):
+ def __init__(self, N_enc, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
+ enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
+ super(MeshedDecoderLayer, self).__init__()
+ self.N_enc = N_enc
+ self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
+ attention_module=self_att_module,
+ attention_module_kwargs=self_att_module_kwargs)
+ self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
+ attention_module=enc_att_module,
+ attention_module_kwargs=enc_att_module_kwargs)
+ self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
+
+ self.fc_alpha = ModuleList([nn.Linear(d_model + d_model, d_model) for _ in range(N_enc)])
+
+ self.init_weights()
+
+ def init_weights(self):
+ for fc_alpha in self.fc_alpha:
+ nn.init.xavier_uniform_(fc_alpha.weight)
+ nn.init.constant_(fc_alpha.bias, 0)
+
+ def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
+ self_att = self.self_att(input, input, input, mask_self_att)
+ self_att = self_att * mask_pad
+
+ enc_att = None
+ for i in range(self.N_enc):
+ enc_att_i = self.enc_att(self_att, enc_output[:, i], enc_output[:, i], mask_enc_att) * mask_pad
+ alpha_i = torch.sigmoid(self.fc_alpha[i](torch.cat([self_att, enc_att_i], -1)))
+ if enc_att is None:
+ enc_att = enc_att_i * alpha_i
+ else:
+ enc_att += enc_att_i * alpha_i
+
+ enc_att /= np.sqrt(self.N_enc)
+ enc_att *= mask_pad
+
+ ff = self.pwff(enc_att)
+ ff = ff * mask_pad
+ return ff
+
+
+class DecoderLayer(Module):
+ def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
+ enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
+ super(DecoderLayer, self).__init__()
+ self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
+ attention_module=self_att_module,
+ attention_module_kwargs=self_att_module_kwargs)
+ self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
+ attention_module=enc_att_module,
+ attention_module_kwargs=enc_att_module_kwargs)
+ self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
+
+ def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
+ self_att = self.self_att(input, input, input, mask_self_att)
+ enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att)
+ ff = self.pwff(enc_att)
+ ff = ff * mask_pad
+ return ff
+
+
+class MeshedDecoder(Module):
+ def __init__(self, vocab_size, max_len, N_dec, N_enc, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048,
+ dropout=.1,
+ self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
+ super(MeshedDecoder, self).__init__()
+ self.d_model = d_model
+ self.vocab_size = vocab_size
+ self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+ self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
+ self.layers = ModuleList(
+ [MeshedDecoderLayer(N_enc, d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
+ enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
+ enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
+ self.fc = nn.Linear(d_model, vocab_size, bias=False)
+ self.max_len = max_len
+ self.padding_idx = padding_idx
+ self.N = N_dec
+
+ self.register_state('running_mask_self_attention', None)
+ self.register_state('running_seq', torch.zeros((1,)).long())
+
+ def forward(self, input, encoder_output_list, mask_encoder):
+ # input (b_s, seq_len)
+ input = input[:, :self.max_len]
+ b_s, seq_len = input.shape[:2]
+
+ if input.dtype in [torch.long, torch.int]:
+ input_index = input
+ else:
+ input_index = one_hot_to_index(input)
+
+ mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) # (b_s, seq_len, 1)
+ mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device),
+ diagonal=1)
+ mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
+ mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool()
+ mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
+ if self._is_stateful:
+ if self.running_mask_self_attention is None:
+ self.running_mask_self_attention = mask_self_attention
+ else:
+ self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention],
+ -1)
+ mask_self_attention = self.running_mask_self_attention
+
+ seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
+ seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
+ if self._is_stateful:
+ self.running_seq.add_(1)
+ seq = self.running_seq
+
+ if input.dtype in [torch.long, torch.int]:
+ out = self.word_emb(input)
+ else:
+ out = input @ self.word_emb.weight
+
+ out = out + self.pos_emb(seq)
+ for i, l in enumerate(self.layers):
+ out = l(out, encoder_output_list, mask_queries, mask_self_attention, mask_encoder)
+
+ out = self.fc(out)
+ return out
+
+
+class Decoder(Module):
+ def __init__(self, vocab_size, max_len, N_dec, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048,
+ dropout=.1, self_att_module=None, enc_att_module=None, self_att_module_kwargs=None,
+ enc_att_module_kwargs=None):
+ super(Decoder, self).__init__()
+ self.d_model = d_model
+ self.vocab_size = vocab_size
+ self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
+ self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
+ self.layers = ModuleList(
+ [DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
+ enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
+ enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
+ self.fc = nn.Linear(d_model, vocab_size, bias=False)
+ self.max_len = max_len
+ self.padding_idx = padding_idx
+ self.N = N_dec
+
+ self.register_state('running_mask_self_attention', None)
+ self.register_state('running_seq', torch.zeros((1,)).long())
+
+ def forward(self, input, encoder_output, mask_encoder):
+ # input (b_s, seq_len)
+ input = input[:, :self.max_len]
+ b_s, seq_len = input.shape[:2]
+
+ if input.dtype in [torch.long, torch.int]:
+ input_index = input
+ else:
+ input_index = one_hot_to_index(input)
+
+ mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) # (b_s, seq_len, 1)
+ mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device),
+ diagonal=1)
+ mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
+ mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool()
+ mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
+ if self._is_stateful:
+ if self.running_mask_self_attention is None:
+ self.running_mask_self_attention = mask_self_attention
+ else:
+ self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention],
+ -1)
+ mask_self_attention = self.running_mask_self_attention
+
+ seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
+ seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
+ if self._is_stateful:
+ self.running_seq.add_(1)
+ seq = self.running_seq
+
+ if input.dtype in [torch.long, torch.int]:
+ out = self.word_emb(input)
+ else:
+ out = input @ self.word_emb.weight
+
+ out = out + self.pos_emb(seq)
+ for i, l in enumerate(self.layers):
+ out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
+
+ out = self.fc(out)
+ return out
diff --git a/models/transformer/encoders.py b/models/transformer/encoders.py
new file mode 100644
index 0000000..35dae4c
--- /dev/null
+++ b/models/transformer/encoders.py
@@ -0,0 +1,65 @@
+from torch.nn import functional as F
+from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
+import torch
+from torch import nn
+from models.transformer.attention import MultiHeadAttention
+
+
+class EncoderLayer(nn.Module):
+ def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
+ attention_module=None, attention_module_kwargs=None):
+ super(EncoderLayer, self).__init__()
+ self.identity_map_reordering = identity_map_reordering
+ self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
+ attention_module=attention_module,
+ attention_module_kwargs=attention_module_kwargs)
+ self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
+
+ def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
+ att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
+ ff = self.pwff(att)
+ return ff
+
+
+class Encoder(nn.Module):
+ def __init__(self, N, max_len, d_in, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
+ identity_map_reordering=False, attention_module=None, attention_module_kwargs=None,
+ with_pe=False, with_mesh=False):
+ super(Encoder, self).__init__()
+ self.d_in = d_in
+ self.d_model = d_model
+ self.dropout = dropout
+ self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
+ identity_map_reordering=identity_map_reordering,
+ attention_module=attention_module,
+ attention_module_kwargs=attention_module_kwargs)
+ for _ in range(N)])
+ self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, self.d_in, 0), freeze=True)
+ self.fc = nn.Linear(d_in, self.d_model)
+ self.dropout = nn.Dropout(p=self.dropout)
+ self.layer_norm = nn.LayerNorm(self.d_model)
+ self.with_pe = with_pe
+ self.with_mesh = with_mesh
+
+ def forward(self, input):
+ # input (b_s, seq_len, d_in)
+ b_s, seq_len = input.shape[:2]
+ seq = torch.arange(1, seq_len + 1, device=input.device).view(1, -1).expand(b_s, -1) # (b_s, seq_len)
+
+ out = input
+ if self.with_pe:
+ out = out + self.pos_emb(seq)
+ out = F.relu(self.fc(out))
+ out = self.dropout(out)
+ out = self.layer_norm(out)
+ outs = list()
+ for l in self.layers:
+ out = l(out, out, out)
+ if self.with_mesh:
+ outs.append(out.unsqueeze(1))
+
+ if self.with_mesh:
+ outs = torch.cat(outs, 1)
+ return outs, None
+ return out, None
+
diff --git a/models/transformer/utils.py b/models/transformer/utils.py
new file mode 100644
index 0000000..c7662c0
--- /dev/null
+++ b/models/transformer/utils.py
@@ -0,0 +1,50 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+
+def position_embedding(input, d_model):
+ input = input.view(-1, 1)
+ dim = torch.arange(d_model // 2, dtype=input.dtype, device=input.device).view(1, -1)
+ sin = torch.sin(input / 10000 ** (2 * dim / d_model))
+ cos = torch.cos(input / 10000 ** (2 * dim / d_model))
+
+ out = torch.zeros((input.shape[0], d_model), device=input.device)
+ out[:, ::2] = sin
+ out[:, 1::2] = cos
+ return out
+
+
+def sinusoid_encoding_table(max_len, d_model, padding_idx=None, dtype=torch.float32):
+ pos = torch.arange(max_len, dtype=dtype)
+ out = position_embedding(pos, d_model)
+
+ if padding_idx is not None:
+ out[padding_idx] = 0
+ return out
+
+
+class PositionWiseFeedForward(nn.Module):
+ """
+ Position-wise feed forward layer
+ """
+
+ def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
+ super(PositionWiseFeedForward, self).__init__()
+ self.identity_map_reordering = identity_map_reordering
+ self.fc1 = nn.Linear(d_model, d_ff)
+ self.fc2 = nn.Linear(d_ff, d_model)
+ self.dropout = nn.Dropout(p=dropout)
+ self.dropout_2 = nn.Dropout(p=dropout)
+ self.layer_norm = nn.LayerNorm(d_model)
+
+ def forward(self, input):
+ if self.identity_map_reordering:
+ out = self.layer_norm(input)
+ out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
+ out = input + self.dropout(torch.relu(out))
+ else:
+ out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
+ out = self.dropout(out)
+ out = self.layer_norm(input + out)
+ return out
diff --git a/models/utils.py b/models/utils.py
new file mode 100644
index 0000000..0e2fbe0
--- /dev/null
+++ b/models/utils.py
@@ -0,0 +1,12 @@
+import torch
+from torch import Tensor
+
+
+def one_hot_to_index(one_hot: Tensor) -> Tensor:
+ """
+ Converts a one-hot tensor into a tensor with corresponding indexes
+ """
+ device, dtype = one_hot.device, one_hot.dtype
+ vocab_size = one_hot.shape[-1]
+ oh2idx = torch.tensor(range(vocab_size), dtype=dtype, device=device)
+ return (one_hot @ oh2idx.unsqueeze(dim=1)).long().squeeze(dim=-1)