
init the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main · wxywb · 3 years ago · parent · commit 898f49fa86
  1. README.md · 85
  2. __init__.py · 18
  3. camel.py · 115
  4. data/.DS_Store · BIN
  5. data/__init__.py · 10
  6. data/dataset.py · 303
  7. data/example.py · 26
  8. data/field.py · 127
  9. data/tokenizer/__init__.py · 0
  10. data/tokenizer/bpe_simple_vocab_16e6.txt.gz · BIN
  11. data/tokenizer/simple_tokenizer.py · 144
  12. models/.DS_Store · BIN
  13. models/__init__.py · 1
  14. models/beam_search/__init__.py · 1
  15. models/beam_search/beam_search.py · 149
  16. models/clip.py · 619
  17. models/containers.py · 81
  18. models/transformer/__init__.py · 4
  19. models/transformer/attention.py · 196
  20. models/transformer/captioner.py · 93
  21. models/transformer/decoders.py · 199
  22. models/transformer/encoders.py · 65
  23. models/transformer/utils.py · 50
  24. models/utils.py · 12

85
README.md

@@ -1,2 +1,85 @@
# camel
# Image Captioning with CaMEL
*author: David Wang*
<br />
## Description
This operator generates a caption describing the content of the given image with CaMEL (Mean Teacher Learning for Image Captioning), a Transformer-based captioning architecture in which two interconnected language models learn from each other while building on CLIP visual features. This is an adaptation from [aimagelab/camel](https://github.com/aimagelab/camel).
<br />
## Code Example
Load an image from path './image.jpg' and generate its caption.
*Write the pipeline in simplified style*:
```python
import towhee
towhee.glob('./image.jpg') \
.image_decode() \
.image_captioning.camel(model_name='capdec_noise_0') \
.show()
```
<img src="./cap.png" alt="result1" style="height:20px;"/>
*Write the same pipeline with explicit input/output name specifications:*
```python
import towhee
towhee.glob['path']('./image.jpg') \
.image_decode['path', 'img']() \
.image_captioning.camel['img', 'text'](model_name='capdec_noise_0') \
.select['img', 'text']() \
.show()
```
<img src="./tabular.png" alt="result2" style="height:60px;"/>
<br />
## Factory Constructor
Create the operator via the following factory method
***camel(model_name)***
**Parameters:**
***model_name:*** *str*
​ The model name to load. Supported model names:
- capdec_noise_0
- capdec_noise_01
- capdec_noise_001
- capdec_noise_0001
<br />
## Interface
An image captioning operator takes a [towhee image](link/to/towhee/image/api/doc) as input and generates the corresponding caption.
**Parameters:**
***data:*** *towhee.types.Image (a sub-class of numpy.ndarray)*
​ The image to generate the caption for.
**Returns:** *str*
​ The caption generated by the model.

18
__init__.py

@@ -0,0 +1,18 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .camel import Camel
def camel(model_name: str):
return Camel(model_name)

115
camel.py

@@ -0,0 +1,115 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import os
from pathlib import Path
from easydict import EasyDict as edict
import torch
from torchvision import transforms
from transformers import GPT2Tokenizer
from towhee.types.arg import arg, to_image_color
from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
class Camel(NNOperator):
"""
Camel image captioning operator
"""
def _gen_args(self):
args = edict()
args.image_dim = None  # set in __init__ from the CLIP visual encoder's embed_dim
args.N_enc = 3
args.d_model = 512
args.d_ff = 2048
args.head = 8
args.m = 40
args.disable_mesh = True
args.with_pe = True
return args
def __init__(self, model_name: str):
super().__init__()
sys.path.append(str(Path(__file__).parent))
self.device = "cuda" if torch.cuda.is_available() else "cpu"
from models import Captioner
from data import ImageField, TextField
# Pipeline for text
self.text_field = TextField()
args = self._gen_args()
self.clip_model = clip.create_model(model_name='clip_resnet_r50x4', pretrained=True, jit=True)
self.clip_tfms = clip.get_transforms(model_name='clip_resnet_r50x4')
self.image_model = self.clip_model.visual
self.image_model.forward = self.image_model.intermediate_features
image_field = ImageField(transform=self.clip_tfms)
args.image_dim = self.image_model.embed_dim
# Create the model
self.model = Captioner(args, self.text_field).to(self.device)
self.model.forward = self.model.beam_search
self.image_model = self.image_model.to(self.device)
# NOTE: model_path is expected to point to the downloaded checkpoint for model_name;
# it is not resolved anywhere in this file.
self.model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
self.model = self.model.eval()
sys.path.pop()
@arg(1, to_image_color('RGB'))
def inference_single_data(self, data):
text = self._inference_from_image(data)
return text
def _preprocess(self, img):
img = to_pil(img)
processed_img = self.clip_tfms(img).unsqueeze(0).to(self.device)
return processed_img
def __call__(self, data):
if not isinstance(data, list):
data = [data]
results = []
for single_data in data:
result = self.inference_single_data(single_data)
results.append(result)
if len(data) == 1:
return results[0]
else:
return results
@arg(1, to_image_color('RGB'))
def _inference_from_image(self, img):
img = self._preprocess(img)
text, _ = self.model.beam_search(img, beam_size=5, out_size=1)
return text
def _configs(self):
config = {}
config['clipcap_coco'] = {}
config['clipcap_coco']['weights'] = 'coco_weights.pt'
config['clipcap_conceptual'] = {}
config['clipcap_conceptual']['weights'] = 'conceptual_weights.pt'
return config
if __name__ == '__main__':
pass

BIN
data/.DS_Store

Binary file not shown.

10
data/__init__.py

@@ -0,0 +1,10 @@
from .field import RawField, Merge, ImageField, TextField
from .dataset import *
from torch.utils.data import DataLoader as TorchDataLoader
class DataLoader(TorchDataLoader):
def __init__(self, dataset, *args, **kwargs):
super(DataLoader, self).__init__(dataset, *args, collate_fn=dataset.collate_fn(), **kwargs)

303
data/dataset.py

@@ -0,0 +1,303 @@
import collections
import collections.abc
import itertools
import os
import numpy as np
import torch
from pycocotools.coco import COCO as pyCOCO
from torch.utils.data import Dataset as PthDataset
from utils import nostdout
from .example import Example
class Dataset(PthDataset):
def __init__(self, examples, fields):
self.examples = examples
self.fields = dict(fields)
def collate_fn(self):
def collate(batch):
if len(self.fields) == 1:
batch = [batch, ]
else:
batch = list(zip(*batch))
tensors = []
for field, data in zip(self.fields.values(), batch):
tensor = field.process(data)
if isinstance(tensor, collections.abc.Sequence) and any(isinstance(t, torch.Tensor) for t in tensor):
tensors.extend(tensor)
else:
tensors.append(tensor)
if len(tensors) > 1:
return tensors
else:
return tensors[0]
return collate
def __getitem__(self, i):
example = self.examples[i]
data = []
for field_name, field in self.fields.items():
data.append(field.preprocess(getattr(example, field_name, None)))
if len(data) == 1:
data = data[0]
return data
def __len__(self):
return len(self.examples)
class ValueDataset(Dataset):
def __init__(self, examples, fields, dictionary):
self.dictionary = dictionary
super(ValueDataset, self).__init__(examples, fields)
def collate_fn(self):
def collate(batch):
value_batch_flattened = list(itertools.chain(*batch))
value_tensors_flattened = super(ValueDataset, self).collate_fn()(value_batch_flattened)
lengths = [0, ] + list(itertools.accumulate([len(x) for x in batch]))
if isinstance(value_tensors_flattened, collections.abc.Sequence) \
and any(isinstance(t, torch.Tensor) for t in value_tensors_flattened):
value_tensors = [[vt[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
for vt in value_tensors_flattened]
else:
value_tensors = [value_tensors_flattened[s:e] for (s, e) in zip(lengths[:-1], lengths[1:])]
return value_tensors
return collate
def __getitem__(self, i):
if i not in self.dictionary:
raise IndexError
values_data = []
for idx in self.dictionary[i]:
value_data = super(ValueDataset, self).__getitem__(idx)
values_data.append(value_data)
return values_data
def __len__(self):
return len(self.dictionary)
class DictionaryDataset(Dataset):
def __init__(self, examples, fields, key_fields):
if not isinstance(key_fields, (tuple, list)):
key_fields = (key_fields,)
for field in key_fields:
assert (field in fields)
dictionary = collections.defaultdict(list)
key_fields = {k: fields[k] for k in key_fields}
value_fields = {k: fields[k] for k in fields.keys() if k not in key_fields}
key_examples = []
key_dict = dict()
value_examples = []
for i, e in enumerate(examples):
key_example = Example.fromdict({k: getattr(e, k) for k in key_fields})
value_example = Example.fromdict({v: getattr(e, v) for v in value_fields})
if key_example not in key_dict:
key_dict[key_example] = len(key_examples)
key_examples.append(key_example)
value_examples.append(value_example)
dictionary[key_dict[key_example]].append(i)
self.key_dataset = Dataset(key_examples, key_fields)
self.value_dataset = ValueDataset(value_examples, value_fields, dictionary)
super(DictionaryDataset, self).__init__(examples, fields)
def collate_fn(self):
def collate(batch):
key_batch, value_batch = list(zip(*batch))
key_tensors = self.key_dataset.collate_fn()(key_batch)
value_tensors = self.value_dataset.collate_fn()(value_batch)
return key_tensors, value_tensors
return collate
def __getitem__(self, i):
return self.key_dataset[i], self.value_dataset[i]
def __len__(self):
return len(self.key_dataset)
def unique(sequence):
seen = set()
if isinstance(sequence[0], list):
return [x for x in sequence if not (tuple(x) in seen or seen.add(tuple(x)))]
else:
return [x for x in sequence if not (x in seen or seen.add(x))]
# class ImageDataset(Dataset):
# def __init__(self, examples, image_field):
# super().__init__(examples, {'image': image_field})
#
#
# class COCO2017Unlabeled(ImageDataset):
# def __init__(self, image_field, img_root, load_in_tmp=False):
# tmp_path = os.path.join('/tmp/coco2017unlabeled')
# if load_in_tmp and sync_path(img_root, tmp_path, size=19*1024**3):
# img_root = tmp_path
#
# examples = [Example.fromdict({'image': os.path.join(img_root, f)}) for f in os.listdir(img_root)]
# super().__init__(examples, image_field)
#
#
# class Nocaps(ImageDataset):
# def __init__(self, image_field, img_root, load_in_tmp=False):
# tmp_path = os.path.join('/tmp/nocaps')
# if load_in_tmp and sync_path(img_root, tmp_path, size=3.6*1024**3):
# img_root = tmp_path
#
# examples = []
# for split in ('validation', 'test'):
# examples += [Example.fromdict({'image': os.path.join(img_root, split, f)}) for f in os.listdir(img_root + '/' + split)]
# super().__init__(examples, image_field)
class PairedDataset(Dataset):
def __init__(self, examples, image_field, text_field):
super(PairedDataset, self).__init__(examples, {'image': image_field, 'text': text_field})
self.image_field = self.fields['image']
self.text_field = self.fields['text']
def image_set(self):
img_list = [e.image for e in self.examples]
image_set = unique(img_list)
examples = [Example.fromdict({'image': i}) for i in image_set]
dataset = Dataset(examples, {'image': self.image_field})
return dataset
def text_set(self):
text_list = [e.text for e in self.examples]
text_list = unique(text_list)
examples = [Example.fromdict({'text': t}) for t in text_list]
dataset = Dataset(examples, {'text': self.text_field})
return dataset
def image_dictionary(self, fields=None):
if not fields:
fields = self.fields
dataset = DictionaryDataset(self.examples, fields, key_fields='image')
return dataset
def text_dictionary(self, fields=None):
if not fields:
fields = self.fields
dataset = DictionaryDataset(self.examples, fields, key_fields='text')
return dataset
@property
def splits(self):
raise NotImplementedError
class COCO(PairedDataset):
def __init__(self, image_field, text_field, img_root, ann_root, id_root=None, use_restval=True,
cut_validation=False):
roots = {}
roots['train'] = {
'img': os.path.join(img_root, 'train2014'),
'cap': os.path.join(ann_root, 'captions_train2014.json')
}
roots['val'] = {
'img': os.path.join(img_root, 'val2014'),
'cap': os.path.join(ann_root, 'captions_val2014.json')
}
roots['test'] = {
'img': os.path.join(img_root, 'val2014'),
'cap': os.path.join(ann_root, 'captions_val2014.json')
}
roots['trainrestval'] = {
'img': (roots['train']['img'], roots['val']['img']),
'cap': (roots['train']['cap'], roots['val']['cap'])
}
if id_root is not None:
ids = {}
ids['train'] = np.load(os.path.join(id_root, 'coco_train_ids.npy'))
ids['val'] = np.load(os.path.join(id_root, 'coco_dev_ids.npy'))
if cut_validation:
ids['val'] = ids['val'][:5000]
ids['test'] = np.load(os.path.join(id_root, 'coco_test_ids.npy'))
ids['trainrestval'] = (
ids['train'],
np.load(os.path.join(id_root, 'coco_restval_ids.npy')))
if use_restval:
roots['train'] = roots['trainrestval']
ids['train'] = ids['trainrestval']
else:
ids = None
with nostdout():
self.train_examples, self.val_examples, self.test_examples = self.get_samples(roots, ids)
examples = self.train_examples + self.val_examples + self.test_examples
super(COCO, self).__init__(examples, image_field, text_field)
@property
def splits(self):
train_split = PairedDataset(self.train_examples, self.image_field, self.text_field)
val_split = PairedDataset(self.val_examples, self.image_field, self.text_field)
test_split = PairedDataset(self.test_examples, self.image_field, self.text_field)
return train_split, val_split, test_split
@classmethod
def get_samples(cls, roots, ids_dataset=None):
train_samples = []
val_samples = []
test_samples = []
for split in ['train', 'val', 'test']:
if isinstance(roots[split]['cap'], tuple):
coco_dataset = (pyCOCO(roots[split]['cap'][0]), pyCOCO(roots[split]['cap'][1]))
root = roots[split]['img']
else:
coco_dataset = (pyCOCO(roots[split]['cap']),)
root = (roots[split]['img'],)
if ids_dataset is None:
ids = list(coco_dataset.anns.keys())
else:
ids = ids_dataset[split]
if isinstance(ids, tuple):
bp = len(ids[0])
ids = list(ids[0]) + list(ids[1])
else:
bp = len(ids)
for index in range(len(ids)):
if index < bp:
coco = coco_dataset[0]
img_root = root[0]
else:
coco = coco_dataset[1]
img_root = root[1]
ann_id = ids[index]
caption = coco.anns[ann_id]['caption']
img_id = coco.anns[ann_id]['image_id']
filename = coco.loadImgs(img_id)[0]['file_name']
example = Example.fromdict({'image': os.path.join(img_root, filename), 'text': caption})
if split == 'train':
train_samples.append(example)
elif split == 'val':
val_samples.append(example)
elif split == 'test':
test_samples.append(example)
return train_samples, val_samples, test_samples
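For reference, a minimal sketch of how the collate machinery above composes fields into batches. It assumes this runs from the operator root and that the imports pulled in by `data/dataset.py` (pycocotools and the upstream `utils` helpers, which are not part of this commit) are available.

```python
# Sketch only: requires pycocotools and the upstream `utils` helpers on the path.
import torch

from data import DataLoader
from data.dataset import Dataset
from data.example import Example
from data.field import RawField

# Two fields: one turns its column into a tensor, one passes values through unchanged.
to_tensor = RawField(preprocessing=lambda x: torch.tensor(x, dtype=torch.float32))
identity = RawField()

examples = [Example.fromdict({'feat': [float(i), float(i + 1)], 'name': 'ex%d' % i})
            for i in range(4)]
dataset = Dataset(examples, {'feat': to_tensor, 'name': identity})

# DataLoader (data/__init__.py) wires in dataset.collate_fn() automatically.
loader = DataLoader(dataset, batch_size=2)
feats, names = next(iter(loader))
print(feats.shape)  # torch.Size([2, 2]) -- tensor field stacked by default_collate
print(names)        # the raw 'name' strings for this batch
```

Each field collates its own column: tensor-producing fields are stacked by `default_collate`, while non-tensor fields come back as plain Python sequences.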

26
data/example.py

@@ -0,0 +1,26 @@
class Example(object):
"""Defines a single training or test example.
Stores each column of the example as an attribute.
"""
@classmethod
def fromdict(cls, data):
ex = cls(data)
return ex
def __init__(self, data):
for key, val in data.items():
super(Example, self).__setattr__(key, val)
def __setattr__(self, key, value):
raise AttributeError
def __hash__(self):
return hash(tuple(x for x in self.__dict__.values()))
def __eq__(self, other):
this = tuple(x for x in self.__dict__.values())
other = tuple(x for x in other.__dict__.values())
return this == other
def __ne__(self, other):
return not self.__eq__(other)
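A short illustration of the `Example` container, assuming the operator package and the dependencies its `data/__init__.py` pulls in are importable:

```python
# Sketch only: importing any `data` submodule runs data/__init__.py and its imports.
from data.example import Example

ex1 = Example.fromdict({'image': 'img_001.jpg', 'text': 'a dog on the grass'})
ex2 = Example.fromdict({'image': 'img_001.jpg', 'text': 'a dog on the grass'})

print(ex1.image)               # 'img_001.jpg' -- each column becomes an attribute
print(ex1 == ex2)              # True  -- equality compares attribute values
print(hash(ex1) == hash(ex2))  # True  -- hashable, so it can key the DictionaryDataset lookup

try:
    ex1.text = 'something else'
except AttributeError:
    print('Example instances are read-only after construction')
```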

127
data/field.py

@@ -0,0 +1,127 @@
# coding: utf8
from itertools import takewhile
import torch
from torch.utils.data.dataloader import default_collate
from torchvision.datasets.folder import default_loader
from .tokenizer.simple_tokenizer import SimpleTokenizer as _Tokenizer
class RawField(object):
""" Defines a general datatype.
Every dataset consists of one or more types of data. For instance,
a machine translation dataset contains paired examples of text, while
an image captioning dataset contains images and texts.
Each of these types of data is represented by a RawField object.
A RawField object does not assume any property of the data type and
it holds parameters relating to how a datatype should be processed.
Attributes:
preprocessing: The Pipeline that will be applied to examples
using this field before creating an example.
Default: None.
postprocessing: A Pipeline that will be applied to a list of examples
using this field before assigning to a batch.
Function signature: (batch(list)) -> object
Default: None.
"""
def __init__(self, preprocessing=None, postprocessing=None):
self.preprocessing = preprocessing
self.postprocessing = postprocessing
def preprocess(self, x):
""" Preprocess an example if the `preprocessing` Pipeline is provided. """
if self.preprocessing is not None:
return self.preprocessing(x)
else:
return x
def process(self, batch, *args, **kwargs):
""" Process a list of examples to create a batch.
Postprocess the batch with user-provided Pipeline.
Args:
batch (list(object)): A list of object from a batch of examples.
Returns:
object: Processed object given the input and custom
postprocessing Pipeline.
"""
if self.postprocessing is not None:
batch = self.postprocessing(batch)
return default_collate(batch)
class Merge(RawField):
def __init__(self, *fields):
super(Merge, self).__init__()
self.fields = fields
def preprocess(self, x):
return tuple(f.preprocess(x) for f in self.fields)
def process(self, batch, *args, **kwargs):
if len(self.fields) == 1:
batch = [batch, ]
else:
batch = list(zip(*batch))
out = list(f.process(b, *args, **kwargs) for f, b in zip(self.fields, batch))
return out
class ImageField(RawField):
def __init__(self, preprocessing=None, postprocessing=None, loader=default_loader, transform=None):
self.loader = loader
self.transform = transform
super().__init__(preprocessing, postprocessing)
def preprocess(self, x):
sample = self.loader(x)
if self.transform is not None:
sample = self.transform(sample)
return sample
class TextField(RawField):
def __init__(self):
self._tokenizer = _Tokenizer()
super(TextField, self).__init__()
def preprocess(self, x):
if x is None:
return ''
return x
def process(self, texts):
if isinstance(texts, str):
texts = [texts]
sot_token = self._tokenizer.bos_idx
eot_token = self._tokenizer.eos_idx
all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
result = torch.zeros(len(all_tokens), max(len(s) for s in all_tokens), dtype=torch.long)
for i, tokens in enumerate(all_tokens):
result[i, :len(tokens)] = torch.tensor(tokens)
return result
def decode(self, word_idxs):
if isinstance(word_idxs, list) and len(word_idxs) == 0:
return self.decode([word_idxs, ])[0]
if isinstance(word_idxs, list) and isinstance(word_idxs[0], int):
return self.decode([word_idxs, ])[0]
elif isinstance(word_idxs, torch.Tensor) and word_idxs.ndimension() == 1:
return self.decode(word_idxs.unsqueeze(0))[0]
captions = []
for wis in word_idxs:
wis = wis.tolist()
wis = list(takewhile(lambda tok: tok != self._tokenizer.eos_idx, wis))
caption = self._tokenizer.decode(wis)
captions.append(caption)
return captions
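A minimal sketch of `TextField` tokenization, assuming the BPE vocabulary under `data/tokenizer/` has been fetched (it is stored with Git LFS) and the `data` package's other imports resolve:

```python
# Sketch only: needs the LFS-tracked BPE vocabulary and the package's dependencies.
from data.field import TextField

field = TextField()
batch = field.process(['a dog on the grass', 'a cat'])
print(batch.dtype, batch.shape)  # torch.int64, (2, longest caption length), zero-padded on the right

# decode() walks each row until the EOS token; the start-of-text marker is kept as-is.
print(field.decode(batch))
```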

0
data/tokenizer/__init__.py

BIN
data/tokenizer/bpe_simple_vocab_16e6.txt.gz (Stored with Git LFS)

Binary file not shown.

144
data/tokenizer/simple_tokenizer.py

@@ -0,0 +1,144 @@
import gzip
import html
import os
from functools import lru_cache
import ftfy
import regex as re
@lru_cache()
def default_bpe():
return os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpe_simple_vocab_16e6.txt.gz")
@lru_cache()
def bytes_to_unicode():
"""
Returns list of utf-8 byte and a corresponding list of unicode strings.
The reversible bpe codes work on unicode strings.
This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
This is a significant percentage of your normal, say, 32K bpe vocab.
To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
And avoids mapping to whitespace/control characters the bpe code barfs on.
"""
bs = list(range(ord("!"), ord("~")+1))+list(range(ord("¡"), ord("¬")+1))+list(range(ord("®"), ord("ÿ")+1))
cs = bs[:]
n = 0
for b in range(2**8):
if b not in bs:
bs.append(b)
cs.append(2**8+n)
n += 1
cs = [chr(n) for n in cs]
return dict(zip(bs, cs))
def get_pairs(word):
"""Return set of symbol pairs in a word.
Word is represented as tuple of symbols (symbols being variable-length strings).
"""
pairs = set()
prev_char = word[0]
for char in word[1:]:
pairs.add((prev_char, char))
prev_char = char
return pairs
def basic_clean(text):
text = ftfy.fix_text(text)
text = html.unescape(html.unescape(text))
return text.strip()
def whitespace_clean(text):
text = re.sub(r'\s+', ' ', text)
text = text.strip()
return text
class SimpleTokenizer(object):
def __init__(self, bpe_path: str = default_bpe()):
self.byte_encoder = bytes_to_unicode()
self.byte_decoder = {v: k for k, v in self.byte_encoder.items()}
merges = gzip.open(bpe_path).read().decode("utf-8").split('\n')
merges = merges[1:49152-256-2+1]
merges = [tuple(merge.split()) for merge in merges]
vocab = list(bytes_to_unicode().values())
vocab = vocab + [v+'</w>' for v in vocab]
for merge in merges:
vocab.append(''.join(merge))
vocab.extend(['<|startoftext|>', '<|endoftext|>'])
self.encoder = dict(zip(vocab, range(len(vocab))))
self.decoder = {v: k for k, v in self.encoder.items()}
self.bpe_ranks = dict(zip(merges, range(len(merges))))
self.cache = {'<|startoftext|>': '<|startoftext|>', '<|endoftext|>': '<|endoftext|>'}
self.pat = re.compile(r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""", re.IGNORECASE)
@property
def vocab_size(self):
return len(self.encoder)
@property
def eos_idx(self):
return self.encoder['<|endoftext|>']
@property
def bos_idx(self):
return self.encoder['<|startoftext|>']
def bpe(self, token):
if token in self.cache:
return self.cache[token]
word = tuple(token[:-1]) + ( token[-1] + '</w>',)
pairs = get_pairs(word)
if not pairs:
return token+'</w>'
while True:
bigram = min(pairs, key = lambda pair: self.bpe_ranks.get(pair, float('inf')))
if bigram not in self.bpe_ranks:
break
first, second = bigram
new_word = []
i = 0
while i < len(word):
try:
j = word.index(first, i)
new_word.extend(word[i:j])
i = j
except ValueError:
new_word.extend(word[i:])
break
if word[i] == first and i < len(word)-1 and word[i+1] == second:
new_word.append(first+second)
i += 2
else:
new_word.append(word[i])
i += 1
new_word = tuple(new_word)
word = new_word
if len(word) == 1:
break
else:
pairs = get_pairs(word)
word = ' '.join(word)
self.cache[token] = word
return word
def encode(self, text):
bpe_tokens = []
text = whitespace_clean(basic_clean(text)).lower()
for token in re.findall(self.pat, text):
token = ''.join(self.byte_encoder[b] for b in token.encode('utf-8'))
bpe_tokens.extend(self.encoder[bpe_token] for bpe_token in self.bpe(token).split(' '))
return bpe_tokens
def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text
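A quick BPE round trip with the tokenizer above, assuming `ftfy` and `regex` are installed, the LFS-tracked vocabulary file sits next to this module, and the `data` package's other imports resolve:

```python
# Sketch only: requires ftfy, regex, and the BPE vocabulary fetched via Git LFS.
from data.tokenizer.simple_tokenizer import SimpleTokenizer

tok = SimpleTokenizer()
ids = tok.encode("A dog runs on the grass!")
print(ids)                       # BPE token ids (input is lower-cased and whitespace-normalized)
print(tok.decode(ids).strip())   # 'a dog runs on the grass !'
print(tok.bos_idx, tok.eos_idx)  # indices of <|startoftext|> / <|endoftext|>
```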

BIN
models/.DS_Store

Binary file not shown.

1
models/__init__.py

@@ -0,0 +1 @@
from .transformer.captioner import Captioner

1
models/beam_search/__init__.py

@@ -0,0 +1 @@
from .beam_search import BeamSearch

149
models/beam_search/beam_search.py

@@ -0,0 +1,149 @@
import torch
import utils
class BeamSearch(object):
def __init__(self, model, max_len: int, eos_idx: int, beam_size: int):
self.model = model
self.max_len = max_len
self.eos_idx = eos_idx
self.beam_size = beam_size
self.b_s = None
self.device = None
self.seq_mask = None
self.seq_logprob = None
self.outputs = None
self.log_probs = None
self.selected_words = None
self.all_logits = None
def _expand_state(self, selected_beam, cur_beam_size):
def fn(s):
shape = [int(sh) for sh in s.shape]
beam = selected_beam
for _ in shape[1:]:
beam = beam.unsqueeze(-1)
s = torch.gather(s.view(*([self.b_s, cur_beam_size] + shape[1:])), 1,
beam.expand(*([self.b_s, self.beam_size] + shape[1:])))
s = s.view(*([-1, ] + shape[1:]))
return s
return fn
def _expand_visual(self, visual: utils.TensorOrSequence, cur_beam_size: int, selected_beam: torch.Tensor):
if isinstance(visual, torch.Tensor):
visual_shape = visual.shape
visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
visual_exp = visual.view(visual_exp_shape)
selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
visual = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
else:
new_visual = []
for im in visual:
visual_shape = im.shape
visual_exp_shape = (self.b_s, cur_beam_size) + visual_shape[1:]
visual_red_shape = (self.b_s * self.beam_size,) + visual_shape[1:]
selected_beam_red_size = (self.b_s, self.beam_size) + tuple(1 for _ in range(len(visual_exp_shape) - 2))
selected_beam_exp_size = (self.b_s, self.beam_size) + visual_exp_shape[2:]
visual_exp = im.view(visual_exp_shape)
selected_beam_exp = selected_beam.view(selected_beam_red_size).expand(selected_beam_exp_size)
new_im = torch.gather(visual_exp, 1, selected_beam_exp).view(visual_red_shape)
new_visual.append(new_im)
visual = tuple(new_visual)
return visual
def apply(self, visual: utils.TensorOrSequence, out_size=1, return_logits=False, **kwargs):
self.b_s = utils.get_batch_size(visual)
self.device = utils.get_device(visual)
self.seq_mask = torch.ones((self.b_s, self.beam_size, 1), device=self.device)
self.seq_logprob = torch.zeros((self.b_s, 1, 1), device=self.device)
self.log_probs = []
self.selected_words = None
if return_logits:
self.all_logits = []
outputs = []
with self.model.statefulness(self.b_s):
for t in range(self.max_len):
visual, outputs = self.iter(t, visual, outputs, return_logits, **kwargs)
# Sort result
seq_logprob, sort_idxs = torch.sort(self.seq_logprob, 1, descending=True)
outputs = torch.cat(outputs, -1)
outputs = torch.gather(outputs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
log_probs = torch.cat(self.log_probs, -1)
log_probs = torch.gather(log_probs, 1, sort_idxs.expand(self.b_s, self.beam_size, self.max_len))
outputs = outputs.contiguous()[:, :out_size]
log_probs = log_probs.contiguous()[:, :out_size]
if return_logits:
all_logits = torch.cat(self.all_logits, 2)
all_logits = torch.gather(all_logits, 1, sort_idxs.unsqueeze(-1).expand(self.b_s, self.beam_size,
self.max_len,
all_logits.shape[-1]))
all_logits = all_logits.contiguous()[:, :out_size]
if out_size == 1:
outputs = outputs.squeeze(1)
log_probs = log_probs.squeeze(1)
if return_logits:
all_logits = all_logits.squeeze(1)
if return_logits:
return outputs, log_probs, all_logits
else:
return outputs, log_probs
def select(self, t, candidate_logprob, **kwargs):
selected_logprob, selected_idx = torch.sort(candidate_logprob.view(self.b_s, -1), -1, descending=True)
selected_logprob, selected_idx = selected_logprob[:, :self.beam_size], selected_idx[:, :self.beam_size]
return selected_idx, selected_logprob
def iter(self, t: int, visual: utils.TensorOrSequence, outputs, return_logits, **kwargs):
cur_beam_size = 1 if t == 0 else self.beam_size
word_logits = self.model.step(t, self.selected_words, visual, **kwargs)
word_logits = word_logits.view(self.b_s, cur_beam_size, -1)
word_logprob = torch.log_softmax(word_logits, dim=-1)
candidate_logprob = self.seq_logprob + word_logprob
# Mask sequence if it reaches EOS
if t > 0:
mask = (self.selected_words.view(self.b_s, cur_beam_size) != self.eos_idx).type(visual.dtype).unsqueeze(-1)
self.seq_mask = self.seq_mask * mask
word_logprob = word_logprob * self.seq_mask.expand_as(word_logprob)
old_seq_logprob = self.seq_logprob.expand_as(candidate_logprob).contiguous()
old_seq_logprob[:, :, 1:] = -999
candidate_logprob = self.seq_mask * candidate_logprob + old_seq_logprob * (1 - self.seq_mask)
selected_idx, selected_logprob = self.select(t, candidate_logprob, **kwargs)
selected_beam = torch.floor_divide(selected_idx, candidate_logprob.shape[-1])
selected_words = selected_idx - selected_beam * candidate_logprob.shape[-1]
self.model.apply_to_states(self._expand_state(selected_beam, cur_beam_size))
visual = self._expand_visual(visual, cur_beam_size, selected_beam)
self.seq_logprob = selected_logprob.unsqueeze(-1)
self.seq_mask = torch.gather(self.seq_mask, 1, selected_beam.unsqueeze(-1))
outputs = list(torch.gather(o, 1, selected_beam.unsqueeze(-1)) for o in outputs)
outputs.append(selected_words.unsqueeze(-1))
if return_logits:
if t == 0:
self.all_logits.append(word_logits.expand((self.b_s, self.beam_size, -1)).unsqueeze(2))
else:
self.all_logits.append(word_logits.unsqueeze(2))
this_word_logprob = torch.gather(word_logprob, 1,
selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size,
word_logprob.shape[-1]))
this_word_logprob = torch.gather(this_word_logprob, 2, selected_words.unsqueeze(-1))
self.log_probs = list(
torch.gather(o, 1, selected_beam.unsqueeze(-1).expand(self.b_s, self.beam_size, 1)) for o in self.log_probs)
self.log_probs.append(this_word_logprob)
self.selected_words = selected_words.view(-1, 1)
return visual, outputs
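To make the driver above concrete, here is a toy run with a stateless "language model" whose `step()` just returns random logits. It assumes the upstream `utils` helpers that `beam_search.py` imports (`TensorOrSequence`, `get_batch_size`, `get_device`) and the rest of the package's dependencies are available.

```python
# Toy sketch only: assumes the upstream `utils` helpers are importable.
import torch

from models.beam_search import BeamSearch
from models.containers import Module


class ToyLM(Module):
    """A stand-in decoder: one step() call per time step, no internal state."""

    def __init__(self, vocab_size: int = 12):
        super().__init__()
        self.vocab_size = vocab_size

    def step(self, t, prev_words, visual, **kwargs):
        # Return unnormalized logits for every (expanded) sequence in the batch.
        return torch.randn(visual.shape[0], self.vocab_size)


features = torch.randn(2, 5, 8)   # (batch, regions, feature dim)
searcher = BeamSearch(ToyLM(), max_len=6, eos_idx=0, beam_size=3)
tokens, logprobs = searcher.apply(features, out_size=1)
print(tokens.shape, logprobs.shape)  # torch.Size([2, 6]) torch.Size([2, 6])
```

After the first iteration the visual features are re-gathered per selected beam, so `step()` sees a batch of size `batch * beam_size` from then on.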

619
models/clip.py

@@ -0,0 +1,619 @@
import hashlib
import math
import os
import urllib
import warnings
from collections import OrderedDict
from typing import Tuple
from typing import Union, List
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
try:
from torchvision.transforms import InterpolationMode
BICUBIC = InterpolationMode.BICUBIC
except ImportError:
BICUBIC = Image.BICUBIC
from torch import nn
from tqdm import tqdm
from models.utils import one_hot_to_index
_MODELS = {
"RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
"RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
"RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
"RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
"ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
"ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
}
def _download(url: str, root: str = os.path.expanduser("~/.cache/clip")):
os.makedirs(root, exist_ok=True)
filename = os.path.basename(url)
expected_sha256 = url.split("/")[-2]
download_target = os.path.join(root, filename)
if os.path.exists(download_target) and not os.path.isfile(download_target):
raise RuntimeError(f"{download_target} exists and is not a regular file")
if os.path.isfile(download_target):
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() == expected_sha256:
return download_target
else:
warnings.warn(f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file")
with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
with tqdm(total=int(source.info().get("Content-Length")), ncols=80, unit='iB', unit_scale=True) as loop:
while True:
buffer = source.read(8192)
if not buffer:
break
output.write(buffer)
loop.update(len(buffer))
if hashlib.sha256(open(download_target, "rb").read()).hexdigest() != expected_sha256:
raise RuntimeError("Model has been downloaded but the SHA256 checksum does not match")
return download_target
def _transform(n_px):
return Compose([
Resize(n_px, interpolation=BICUBIC),
CenterCrop(n_px),
lambda image: image.convert("RGB"),
ToTensor(),
Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
def available_models() -> List[str]:
"""Returns the names of available CLIP models"""
return list(_MODELS.keys())
def load(name: str, device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu", jit=True):
"""Load a CLIP model
Parameters
----------
name : str
A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
device : Union[str, torch.device]
The device to put the loaded model
jit : bool
Whether to load the optimized JIT model (default) or more hackable non-JIT model.
Returns
-------
model : torch.nn.Module
The CLIP model
preprocess : Callable[[PIL.Image], torch.Tensor]
A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
"""
if name in _MODELS:
model_path = _download(_MODELS[name])
elif os.path.isfile(name):
model_path = name
else:
raise RuntimeError(f"Model {name} not found; available models = {available_models()}")
try:
# loading JIT archive
model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
state_dict = None
except RuntimeError:
# loading saved state dict
if jit:
warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
jit = False
state_dict = torch.load(model_path, map_location="cpu")
if not jit:
model = build_model(state_dict or model.state_dict()).to(device)
if str(device) == "cpu":
model.float()
return model, _transform(model.visual.input_resolution)
# patch the device names
device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
def patch_device(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("prim::Constant"):
if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
node.copyAttributes(device_node)
model.apply(patch_device)
patch_device(model.encode_image)
patch_device(model.encode_text)
# patch dtype to float32 on CPU
if str(device) == "cpu":
float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
float_node = float_input.node()
def patch_float(module):
graphs = [module.graph] if hasattr(module, "graph") else []
if hasattr(module, "forward1"):
graphs.append(module.forward1.graph)
for graph in graphs:
for node in graph.findAllNodes("aten::to"):
inputs = list(node.inputs())
for i in [1, 2]: # dtype can be the second or third argument to aten::to()
if inputs[i].node()["value"] == 5:
inputs[i].node().copyAttributes(float_node)
model.apply(patch_float)
patch_float(model.encode_image)
patch_float(model.encode_text)
model.float()
return model, _transform(model.input_resolution.item())
class LayerNorm(nn.LayerNorm):
"""Subclass torch's LayerNorm to handle fp16."""
def forward(self, x: torch.Tensor):
orig_type = x.dtype
ret = super().forward(x.type(torch.float32))
return ret.type(orig_type)
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1):
super().__init__()
# all conv layers have stride 1. an avgpool is performed after the second convolution when stride > 1
self.conv1 = nn.Conv2d(inplanes, planes, 1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, 3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.avgpool = nn.AvgPool2d(stride) if stride > 1 else nn.Identity()
self.conv3 = nn.Conv2d(planes, planes * self.expansion, 1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = None
self.stride = stride
if stride > 1 or inplanes != planes * Bottleneck.expansion:
# downsampling layer is prepended with an avgpool, and the subsequent convolution has stride 1
self.downsample = nn.Sequential(OrderedDict([
("-1", nn.AvgPool2d(stride)),
("0", nn.Conv2d(inplanes, planes * self.expansion, 1, stride=1, bias=False)),
("1", nn.BatchNorm2d(planes * self.expansion))
]))
def forward(self, x: torch.Tensor):
identity = x
out = self.relu(self.bn1(self.conv1(x)))
out = self.relu(self.bn2(self.conv2(out)))
out = self.avgpool(out)
out = self.bn3(self.conv3(out))
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class AttentionPool2d(nn.Module):
def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None):
super().__init__()
self.positional_embedding = nn.Parameter(torch.randn(spacial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
self.k_proj = nn.Linear(embed_dim, embed_dim)
self.q_proj = nn.Linear(embed_dim, embed_dim)
self.v_proj = nn.Linear(embed_dim, embed_dim)
self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
self.num_heads = num_heads
def forward(self, x):
x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3]).permute(2, 0, 1) # NCHW -> (HW)NC
x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0) # (HW+1)NC
x = x + self.positional_embedding[:, None, :].to(x.dtype) # (HW+1)NC
x, _ = F.multi_head_attention_forward(
query=x, key=x, value=x,
embed_dim_to_check=x.shape[-1],
num_heads=self.num_heads,
q_proj_weight=self.q_proj.weight,
k_proj_weight=self.k_proj.weight,
v_proj_weight=self.v_proj.weight,
in_proj_weight=None,
in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
bias_k=None,
bias_v=None,
add_zero_attn=False,
dropout_p=0,
out_proj_weight=self.c_proj.weight,
out_proj_bias=self.c_proj.bias,
use_separate_proj_weight=True,
training=self.training,
need_weights=False
)
return x[0]
class ModifiedResNet(nn.Module):
"""
A ResNet class that is similar to torchvision's but contains the following changes:
- There are now 3 "stem" convolutions as opposed to 1, with an average pool instead of a max pool.
- Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1
- The final pooling layer is a QKV attention instead of an average pool
"""
def __init__(self, layers, output_dim, heads, input_resolution=224, width=64):
super().__init__()
self.output_dim = output_dim
self.input_resolution = input_resolution
# the 3-layer stem
self.conv1 = nn.Conv2d(3, width // 2, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(width // 2)
self.conv2 = nn.Conv2d(width // 2, width // 2, kernel_size=3, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(width // 2)
self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False)
self.bn3 = nn.BatchNorm2d(width)
self.avgpool = nn.AvgPool2d(2)
self.relu = nn.ReLU(inplace=True)
# residual layers
self._inplanes = width # this is a *mutable* variable used during construction
self.layer1 = self._make_layer(width, layers[0])
self.layer2 = self._make_layer(width * 2, layers[1], stride=2)
self.layer3 = self._make_layer(width * 4, layers[2], stride=2)
self.layer4 = self._make_layer(width * 8, layers[3], stride=2)
self.embed_dim = width * 32 # the ResNet feature dimension
self.attnpool = AttentionPool2d(input_resolution // 32, self.embed_dim, heads, output_dim)
def _make_layer(self, planes, blocks, stride=1):
layers = [Bottleneck(self._inplanes, planes, stride)]
self._inplanes = planes * Bottleneck.expansion
for _ in range(1, blocks):
layers.append(Bottleneck(self._inplanes, planes))
return nn.Sequential(*layers)
def intermediate_features(self, x):
def stem(x):
for conv, bn in [(self.conv1, self.bn1), (self.conv2, self.bn2), (self.conv3, self.bn3)]:
x = self.relu(bn(conv(x)))
x = self.avgpool(x)
return x
x = x.type(self.conv1.weight.dtype)
x = stem(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
b, c = x.shape[:2]
return x.view(b, c, -1).permute(0, 2, 1)
def forward(self, x):
x = self.intermediate_features(x)
x = x.permute(0, 2, 1)
l = int(math.sqrt(x.shape[-1]))
x = x.view(x.shape[0], x.shape[1], l, l)
x = self.attnpool(x)
return x
class QuickGELU(nn.Module):
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
class ResidualAttentionBlock(nn.Module):
def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
super().__init__()
self.attn = nn.MultiheadAttention(d_model, n_head)
self.ln_1 = LayerNorm(d_model)
self.mlp = nn.Sequential(OrderedDict([
("c_fc", nn.Linear(d_model, d_model * 4)),
("gelu", QuickGELU()),
("c_proj", nn.Linear(d_model * 4, d_model))
]))
self.ln_2 = LayerNorm(d_model)
self.attn_mask = attn_mask
def attention(self, x: torch.Tensor):
self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
def forward(self, x: torch.Tensor):
x = x + self.attention(self.ln_1(x))
x = x + self.mlp(self.ln_2(x))
return x
class Transformer(nn.Module):
def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):
super().__init__()
self.width = width
self.layers = layers
self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
def forward(self, x: torch.Tensor):
return self.resblocks(x)
class VisualTransformer(nn.Module):
def __init__(self, input_resolution: int, patch_size: int, width: int, layers: int, heads: int, output_dim: int):
super().__init__()
self.input_resolution = input_resolution
self.embed_dim = width
self.output_dim = output_dim
self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False)
scale = width ** -0.5
self.class_embedding = nn.Parameter(scale * torch.randn(width))
self.positional_embedding = nn.Parameter(scale * torch.randn((input_resolution // patch_size) ** 2 + 1, width))
self.ln_pre = LayerNorm(width)
self.transformer = Transformer(width, layers, heads)
self.ln_post = LayerNorm(width)
self.proj = nn.Parameter(scale * torch.randn(width, output_dim))
def intermediate_features(self, x):
x = self.conv1(x) # shape = [*, width, grid, grid]
x = x.reshape(x.shape[0], x.shape[1], -1) # shape = [*, width, grid ** 2]
x = x.permute(0, 2, 1) # shape = [*, grid ** 2, width]
x = torch.cat([self.class_embedding.to(x.dtype) + torch.zeros(x.shape[0], 1, x.shape[-1], dtype=x.dtype, device=x.device), x], dim=1) # shape = [*, grid ** 2 + 1, width]
x = x + self.positional_embedding.to(x.dtype)
x = self.ln_pre(x)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_post(x)
return x
def forward(self, x: torch.Tensor):
x = self.intermediate_features(x)
x_cls = x[:, 0, :]
if self.proj is not None:
x_cls = x_cls @ self.proj
return x_cls
class CLIP(nn.Module):
def __init__(self,
embed_dim: int,
# vision
image_resolution: int,
vision_layers: Union[Tuple[int, int, int, int], int],
vision_width: int,
vision_patch_size: int,
# text
context_length: int,
vocab_size: int,
transformer_width: int,
transformer_heads: int,
transformer_layers: int
):
super().__init__()
self.context_length = context_length
if isinstance(vision_layers, (tuple, list)):
vision_heads = vision_width * 32 // 64
self.visual = ModifiedResNet(
layers=vision_layers,
output_dim=embed_dim,
heads=vision_heads,
input_resolution=image_resolution,
width=vision_width
)
else:
vision_heads = vision_width // 64
self.visual = VisualTransformer(
input_resolution=image_resolution,
patch_size=vision_patch_size,
width=vision_width,
layers=vision_layers,
heads=vision_heads,
output_dim=embed_dim
)
self.transformer = Transformer(
width=transformer_width,
layers=transformer_layers,
heads=transformer_heads,
attn_mask=self.build_attention_mask()
)
self.vocab_size = vocab_size
self.transformer_width = transformer_width
self.token_embedding = nn.Embedding(vocab_size, transformer_width)
self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
self.ln_final = LayerNorm(transformer_width)
self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) # todo remove
self.initialize_parameters()
def initialize_parameters(self):
nn.init.normal_(self.token_embedding.weight, std=0.02)
nn.init.normal_(self.positional_embedding, std=0.01)
if isinstance(self.visual, ModifiedResNet):
if self.visual.attnpool is not None:
std = self.visual.attnpool.c_proj.in_features ** -0.5
nn.init.normal_(self.visual.attnpool.q_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.k_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.v_proj.weight, std=std)
nn.init.normal_(self.visual.attnpool.c_proj.weight, std=std)
for resnet_block in [self.visual.layer1, self.visual.layer2, self.visual.layer3, self.visual.layer4]:
for name, param in resnet_block.named_parameters():
if name.endswith("bn3.weight"):
nn.init.zeros_(param)
proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)
attn_std = self.transformer.width ** -0.5
fc_std = (2 * self.transformer.width) ** -0.5
for block in self.transformer.resblocks:
nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
if self.text_projection is not None:
nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)
def build_attention_mask(self):
# lazily create causal attention mask, with full attention between the vision tokens
# pytorch uses additive attention mask; fill with -inf
mask = torch.empty(self.context_length, self.context_length)
mask.fill_(float("-inf"))
mask.triu_(1) # zero out the lower diagonal
return mask
@property
def dtype(self):
return self.visual.conv1.weight.dtype
@property
def device(self):
return self.visual.conv1.weight.device
def encode_image(self, image):
return self.visual(image.type(self.dtype))
def encode_text(self, text):
if text.dtype in [torch.long, torch.int]:
text_idxs = text
x = self.token_embedding(text).type(self.dtype) # [batch_size, n_ctx, d_model]
else:
text_idxs = one_hot_to_index(text)
x = (text @ self.token_embedding.weight).type(self.dtype)
x = x + self.positional_embedding.type(self.dtype)
x = x.permute(1, 0, 2) # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x).type(self.dtype)
# x.shape = [batch_size, n_ctx, transformer.width]
# take features from the eot embedding (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]), text_idxs.argmax(dim=-1)] @ self.text_projection
return x
def forward(self, image, text):
image_features = self.encode_image(image)
text_features = self.encode_text(text)
# normalized features
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# cosine similarity as logits
logit_scale = self.logit_scale.exp()
logits_per_image = logit_scale * image_features @ text_features.t()
logits_per_text = logit_scale * text_features @ image_features.t()
# shape = [global_batch_size, global_batch_size]
return logits_per_image#, logits_per_text
def convert_weights(model: nn.Module):
"""Convert applicable model parameters to fp16"""
def _convert_weights_to_fp16(l):
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Linear)):
l.weight.data = l.weight.data.half()
if l.bias is not None:
l.bias.data = l.bias.data.half()
if isinstance(l, nn.MultiheadAttention):
for attr in [*[f"{s}_proj_weight" for s in ["in", "q", "k", "v"]], "in_proj_bias", "bias_k", "bias_v"]:
tensor = getattr(l, attr)
if tensor is not None:
tensor.data = tensor.data.half()
for name in ["text_projection", "proj"]:
if hasattr(l, name):
attr = getattr(l, name)
if attr is not None:
attr.data = attr.data.half()
model.apply(_convert_weights_to_fp16)
def build_model(state_dict: dict):
vit = "visual.proj" in state_dict
if vit:
vision_width = state_dict["visual.conv1.weight"].shape[0]
vision_layers = len([k for k in state_dict.keys() if k.startswith("visual.") and k.endswith(".attn.in_proj_weight")])
vision_patch_size = state_dict["visual.conv1.weight"].shape[-1]
grid_size = round((state_dict["visual.positional_embedding"].shape[0] - 1) ** 0.5)
image_resolution = vision_patch_size * grid_size
else:
counts: list = [len(set(k.split(".")[2] for k in state_dict if k.startswith(f"visual.layer{b}"))) for b in [1, 2, 3, 4]]
vision_layers = tuple(counts)
vision_width = state_dict["visual.layer1.0.conv1.weight"].shape[0]
output_width = round((state_dict["visual.attnpool.positional_embedding"].shape[0] - 1) ** 0.5)
vision_patch_size = None
assert output_width ** 2 + 1 == state_dict["visual.attnpool.positional_embedding"].shape[0]
image_resolution = output_width * 32
embed_dim = state_dict["text_projection"].shape[1]
context_length = state_dict["positional_embedding"].shape[0]
vocab_size = state_dict["token_embedding.weight"].shape[0]
transformer_width = state_dict["ln_final.weight"].shape[0]
transformer_heads = transformer_width // 64
transformer_layers = len(set(k.split(".")[2] for k in state_dict if k.startswith(f"transformer.resblocks")))
model = CLIP(
embed_dim,
image_resolution, vision_layers, vision_width, vision_patch_size,
context_length, vocab_size, transformer_width, transformer_heads, transformer_layers
)
for key in ["input_resolution", "context_length", "vocab_size"]:
if key in state_dict:
del state_dict[key]
# convert_weights(model) todo remove
model.load_state_dict(state_dict, strict=False)
return model.eval()
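A small usage sketch for the loader above. It downloads the OpenAI RN50x4 checkpoint into `~/.cache/clip` on first use (a few hundred MB), and `jit=False` routes through `build_model()` to get the plain `nn.Module`; it also assumes the rest of the package's imports (pulled in via `models/__init__.py`) resolve.

```python
# Sketch only: downloads weights on first run; assumes the package's imports resolve.
import torch
from PIL import Image

from models.clip import available_models, load

print(available_models())   # ['RN50', 'RN101', 'RN50x4', 'RN50x16', 'ViT-B/32', 'ViT-B/16']
model, preprocess = load("RN50x4", jit=False)

image = preprocess(Image.new("RGB", (640, 480))).unsqueeze(0).to(model.device)
with torch.no_grad():
    grid = model.visual.intermediate_features(image)  # the grid features the captioner consumes
print(grid.shape)           # (1, number of grid regions, embed_dim)
```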

81
models/containers.py

@@ -0,0 +1,81 @@
from contextlib import contextmanager
from torch import nn
from utils.typing import *
class Module(nn.Module):
def __init__(self):
super(Module, self).__init__()
self._is_stateful = False
self._state_names = []
self._state_defaults = dict()
def register_state(self, name: str, default: TensorOrNone):
self._state_names.append(name)
if default is None:
self._state_defaults[name] = None
else:
self._state_defaults[name] = default.clone().detach()
self.register_buffer(name, default)
def states(self):
for name in self._state_names:
yield self._buffers[name]
for m in self.children():
if isinstance(m, Module):
yield from m.states()
def apply_to_states(self, fn):
for name in self._state_names:
if self._buffers[name] is not None:
self._buffers[name] = fn(self._buffers[name])
for m in self.children():
if isinstance(m, Module):
m.apply_to_states(fn)
def _init_states(self, batch_size: int):
for name in self._state_names:
if self._state_defaults[name] is None:
self._buffers[name] = None
else:
self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
self._buffers[name] = self._buffers[name].unsqueeze(0)
self._buffers[name] = self._buffers[name].expand([batch_size, ] + list(self._buffers[name].shape[1:]))
self._buffers[name] = self._buffers[name].contiguous()
def _reset_states(self):
for name in self._state_names:
if self._state_defaults[name] is None:
self._buffers[name] = None
else:
self._buffers[name] = self._state_defaults[name].clone().detach().to(self._buffers[name].device)
def enable_statefulness(self, batch_size: int):
for m in self.children():
if isinstance(m, Module):
m.enable_statefulness(batch_size)
self._init_states(batch_size)
self._is_stateful = True
def disable_statefulness(self):
for m in self.children():
if isinstance(m, Module):
m.disable_statefulness()
self._reset_states()
self._is_stateful = False
@contextmanager
def statefulness(self, batch_size: int):
self.enable_statefulness(batch_size)
try:
yield
finally:
self.disable_statefulness()
class ModuleList(nn.ModuleList, Module):
pass
class ModuleDict(nn.ModuleDict, Module):
pass
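A small sketch of the stateful-buffer mechanism above, which is what lets `MultiHeadAttention` cache running keys/values during beam search. It assumes the `utils.typing` helpers star-imported by `containers.py` (not part of this commit) are on the path.

```python
# Sketch only: assumes the upstream utils.typing helpers are importable.
import torch

from models.containers import Module


class Counter(Module):
    def __init__(self):
        super().__init__()
        self.register_state('calls', torch.zeros(1))  # per-sample running state

    def forward(self, x):
        if self._is_stateful:
            self.calls = self.calls + 1  # updates the registered buffer in place of reassignment
        return x


m = Counter()
with m.statefulness(batch_size=4):   # expands 'calls' to shape (4, 1)
    for _ in range(3):
        m(torch.randn(4, 8))
    print(m.calls.squeeze(-1))        # tensor([3., 3., 3., 3.])
print(m._is_stateful)                 # False: states are reset when the context exits
```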

4
models/transformer/__init__.py

@@ -0,0 +1,4 @@
from .attention import *
from .encoders import *
from .decoders import *
from .captioner import *

196
models/transformer/attention.py

@@ -0,0 +1,196 @@
import numpy as np
import torch
from torch import nn
from models.containers import Module
class ScaledDotProductAttention(nn.Module):
"""
Scaled dot-product attention
"""
def __init__(self, d_model, d_k, d_v, h):
'''
:param d_model: Output dimensionality of the model
:param d_k: Dimensionality of queries and keys
:param d_v: Dimensionality of values
:param h: Number of heads
'''
super(ScaledDotProductAttention, self).__init__()
self.fc_q = nn.Linear(d_model, h * d_k)
self.fc_k = nn.Linear(d_model, h * d_k)
self.fc_v = nn.Linear(d_model, h * d_v)
self.fc_o = nn.Linear(h * d_v, d_model)
self.d_model = d_model
self.d_k = d_k
self.d_v = d_v
self.h = h
self.init_weights()
def init_weights(self):
nn.init.xavier_uniform_(self.fc_q.weight)
nn.init.xavier_uniform_(self.fc_k.weight)
nn.init.xavier_uniform_(self.fc_v.weight)
nn.init.xavier_uniform_(self.fc_o.weight)
nn.init.constant_(self.fc_q.bias, 0)
nn.init.constant_(self.fc_k.bias, 0)
nn.init.constant_(self.fc_v.bias, 0)
nn.init.constant_(self.fc_o.bias, 0)
def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
"""
Computes
:param queries: Queries (b_s, nq, d_model)
:param keys: Keys (b_s, nk, d_model)
:param values: Values (b_s, nk, d_model)
:param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
:param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
:return:
"""
b_s, nq = queries.shape[:2]
nk = keys.shape[1]
q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
k = self.fc_k(keys).view(b_s, nk, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk)
v = self.fc_v(values).view(b_s, nk, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk, d_v)
att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk)
if attention_weights is not None:
att = att * attention_weights
if attention_mask is not None:
att = att.masked_fill(attention_mask, -np.inf)
att = torch.softmax(att, -1)
out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
out = self.fc_o(out) # (b_s, nq, d_model)
return out
class ScaledDotProductAttentionMemory(nn.Module):
"""
Scaled dot-product attention with memory
"""
def __init__(self, d_model, d_k, d_v, h, m):
"""
:param d_model: Output dimensionality of the model
:param d_k: Dimensionality of queries and keys
:param d_v: Dimensionality of values
:param h: Number of heads
:param m: Number of memory slots
"""
super(ScaledDotProductAttentionMemory, self).__init__()
self.fc_q = nn.Linear(d_model, h * d_k)
self.fc_k = nn.Linear(d_model, h * d_k)
self.fc_v = nn.Linear(d_model, h * d_v)
self.fc_o = nn.Linear(h * d_v, d_model)
self.d_model = d_model
self.d_k = d_k
self.d_v = d_v
self.h = h
self.m = m
if self.m > 0:
self.m_k = nn.Parameter(torch.FloatTensor(1, m, h * d_k))
self.m_v = nn.Parameter(torch.FloatTensor(1, m, h * d_v))
self.init_weights()
def init_weights(self):
nn.init.xavier_uniform_(self.fc_q.weight)
nn.init.xavier_uniform_(self.fc_k.weight)
nn.init.xavier_uniform_(self.fc_v.weight)
nn.init.xavier_uniform_(self.fc_o.weight)
nn.init.constant_(self.fc_q.bias, 0)
nn.init.constant_(self.fc_k.bias, 0)
nn.init.constant_(self.fc_v.bias, 0)
nn.init.constant_(self.fc_o.bias, 0)
if self.m > 0:
nn.init.normal_(self.m_k, 0, 1 / self.d_k)
nn.init.normal_(self.m_v, 0, 1 / self.m)
def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
"""
Computes
:param queries: Queries (b_s, nq, d_model)
:param keys: Keys (b_s, nk, d_model)
:param values: Values (b_s, nk, d_model)
:param attention_mask: Mask over attention values (b_s, h, nq, nk). True indicates masking.
:param attention_weights: Multiplicative weights for attention values (b_s, h, nq, nk).
:return: Attention output (b_s, nq, d_model)
"""
b_s, nq = queries.shape[:2]
nk = keys.shape[1]
q = self.fc_q(queries).view(b_s, nq, self.h, self.d_k).permute(0, 2, 1, 3) # (b_s, h, nq, d_k)
if self.m > 0:
m_k = np.sqrt(self.d_k) * self.m_k.expand(b_s, self.m, self.h * self.d_k)
m_v = np.sqrt(self.m) * self.m_v.expand(b_s, self.m, self.h * self.d_v)
k = torch.cat([self.fc_k(keys), m_k], 1)
v = torch.cat([self.fc_v(values), m_v], 1)
else:
k = self.fc_k(keys)
v = self.fc_v(values)
k = k.view(b_s, nk + self.m, self.h, self.d_k).permute(0, 2, 3, 1) # (b_s, h, d_k, nk+m)
v = v.view(b_s, nk + self.m, self.h, self.d_v).permute(0, 2, 1, 3) # (b_s, h, nk+m, d_v)
att = torch.matmul(q, k) / np.sqrt(self.d_k) # (b_s, h, nq, nk+m)
if attention_weights is not None:
att = torch.cat([att[:, :, :, :nk] * attention_weights, att[:, :, :, nk:]], -1)
if attention_mask is not None:
att[:, :, :, :nk] = att[:, :, :, :nk].masked_fill(attention_mask, -np.inf)
att = torch.softmax(att, -1)
out = torch.matmul(att, v).permute(0, 2, 1, 3).contiguous().view(b_s, nq, self.h * self.d_v) # (b_s, nq, h*d_v)
out = self.fc_o(out) # (b_s, nq, d_model)
return out
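The same check for the memory variant: queries keep their length, while attention is computed over nk + m key/value slots (m learned memory vectors are appended). Again a sketch with random, untrained weights:

```python
import torch
from models.transformer.attention import ScaledDotProductAttentionMemory

att_mem = ScaledDotProductAttentionMemory(d_model=512, d_k=64, d_v=64, h=8, m=40)
x = torch.rand(2, 10, 512)     # self-attention over 10 tokens plus 40 memory slots
print(att_mem(x, x, x).shape)  # torch.Size([2, 10, 512])
```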
class MultiHeadAttention(Module):
"""
Multi-head attention layer with Dropout and Layer Normalization.
"""
def __init__(self, d_model, d_k, d_v, h, dropout=.1, identity_map_reordering=False, can_be_stateful=False,
attention_module=None, attention_module_kwargs=None):
super(MultiHeadAttention, self).__init__()
self.identity_map_reordering = identity_map_reordering
if attention_module is not None:
if attention_module_kwargs is not None:
self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h, **attention_module_kwargs)
else:
self.attention = attention_module(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
else:
self.attention = ScaledDotProductAttention(d_model=d_model, d_k=d_k, d_v=d_v, h=h)
self.dropout = nn.Dropout(p=dropout)
self.layer_norm = nn.LayerNorm(d_model)
self.can_be_stateful = can_be_stateful
if self.can_be_stateful:
self.register_state('running_keys', None)
self.register_state('running_values', None)
def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
if self.can_be_stateful and self._is_stateful:
if self.running_keys is None:
self.running_keys = keys
self.running_values = values
else:
self.running_keys = torch.cat([self.running_keys, keys], 1)
self.running_values = torch.cat([self.running_values, values], 1)
keys = self.running_keys
values = self.running_values
if self.identity_map_reordering:
q_norm = self.layer_norm(queries)
k_norm = self.layer_norm(keys)
v_norm = self.layer_norm(values)
out = self.attention(q_norm, k_norm, v_norm, attention_mask, attention_weights)
out = queries + self.dropout(torch.relu(out))
else:
out = self.attention(queries, keys, values, attention_mask, attention_weights)
out = self.dropout(out)
out = self.layer_norm(queries + out)
return out
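MultiHeadAttention wraps any of the attention cores above with dropout and a residual LayerNorm; a minimal self-attention sketch that plugs in the memory variant, mirroring how the encoder configures it (sizes are arbitrary):

```python
import torch
from models.transformer.attention import MultiHeadAttention, ScaledDotProductAttentionMemory

mha = MultiHeadAttention(d_model=512, d_k=64, d_v=64, h=8, dropout=0.1,
                         attention_module=ScaledDotProductAttentionMemory,
                         attention_module_kwargs={'m': 40})
x = torch.rand(2, 10, 512)
print(mha(x, x, x).shape)  # torch.Size([2, 10, 512])
```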

93
models/transformer/captioner.py

@ -0,0 +1,93 @@
import copy
from pathlib import Path
import torch
from torch import Tensor
from torch import nn
from data.field import TextField
from models.beam_search import *
from models.containers import ModuleList, Module
from utils import TensorOrSequence
from . import Encoder, Decoder, ScaledDotProductAttentionMemory, MeshedDecoder
class Captioner(Module):
def __init__(self, args, text_field: TextField):
super(Captioner, self).__init__()
self.encoder = Encoder(args.N_enc, 500, args.image_dim, d_model=args.d_model, d_ff=args.d_ff, h=args.head,
attention_module=ScaledDotProductAttentionMemory,
attention_module_kwargs={'m': args.m},
with_pe=args.with_pe, with_mesh=not args.disable_mesh)
if args.disable_mesh:
self.decoder = Decoder(text_field._tokenizer.vocab_size, 40, args.N_dec, d_model=args.d_model,
d_ff=args.d_ff, h=args.head)
else:
self.decoder = MeshedDecoder(text_field._tokenizer.vocab_size, 40, args.N_dec, args.N_enc,
d_model=args.d_model, d_ff=args.d_ff, h=args.head)
self.bos_idx = text_field._tokenizer.bos_idx
self.eos_idx = text_field._tokenizer.eos_idx
self.vocab_size = text_field._tokenizer.vocab_size
self.max_generation_length = self.decoder.max_len
self.register_state('enc_output', None)
self.register_state('mask_enc', None)
self.init_weights()
@property
def d_model(self):
return self.decoder.d_model
def train(self, mode: bool = True):
self.encoder.train(mode)
self.decoder.train(mode)
return self
def init_weights(self):
for p in self.encoder.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
for p in self.decoder.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, images, seq):
enc_output, mask_enc = self.encoder(images)
dec_output = self.decoder(seq, enc_output, mask_enc)
return dec_output
def step(self, t: int, prev_output: Tensor, visual: Tensor) -> Tensor:
if t == 0:
self.enc_output, self.mask_enc = self.encoder(visual)
input = visual.data.new_full((visual.shape[0], 1), self.bos_idx, dtype=torch.long)
else:
input = prev_output
logits = self.decoder(input, self.enc_output, self.mask_enc)
return logits
def beam_search(self, visual: TensorOrSequence, beam_size: int, out_size=1,
return_logits=False, **kwargs):
bs = BeamSearch(self, self.max_generation_length, self.eos_idx, beam_size)
return bs.apply(visual, out_size, return_logits, **kwargs)
class CaptionerEnsemble(Captioner):
def __init__(self, model: Captioner, args, text_field, weight_files, weight_folder=None):
super(CaptionerEnsemble, self).__init__(args, text_field)
self.n = len(weight_files)
self.models = ModuleList([copy.deepcopy(model) for _ in range(self.n)])
for model_i, weight_file_i in zip(self.models, weight_files):
if Path(weight_file_i).is_absolute():
fname = Path(weight_file_i)
else:
fname = Path(weight_folder).joinpath(weight_file_i)
state_dict_i = torch.load(fname)['state_dict_t']
model_i.load_state_dict(state_dict_i)
def step(self, t, prev_output, visual):
out_ensemble = []
for model_i in self.models:
out_i = model_i.step(t, prev_output, visual)
out_ensemble.append(out_i.unsqueeze(0))
return torch.mean(torch.cat(out_ensemble, 0), dim=0)
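CaptionerEnsemble averages the per-step logits of its member models; a toy sketch of that reduction with random tensors standing in for the members' outputs (batch and vocabulary sizes are arbitrary):

```python
import torch

step_logits = [torch.randn(4, 1, 1000) for _ in range(3)]  # 3 members, each (b_s, 1, vocab_size)
out_ensemble = torch.cat([l.unsqueeze(0) for l in step_logits], 0)
print(torch.mean(out_ensemble, dim=0).shape)               # torch.Size([4, 1, 1000])
```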

199
models/transformer/decoders.py

@ -0,0 +1,199 @@
import torch
from torch import nn
import numpy as np
from models.transformer.attention import MultiHeadAttention
from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
from models.containers import Module, ModuleList
from models.utils import one_hot_to_index
class MeshedDecoderLayer(Module):
def __init__(self, N_enc, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
super(MeshedDecoderLayer, self).__init__()
self.N_enc = N_enc
self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
attention_module=self_att_module,
attention_module_kwargs=self_att_module_kwargs)
self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
attention_module=enc_att_module,
attention_module_kwargs=enc_att_module_kwargs)
self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
self.fc_alpha = ModuleList([nn.Linear(d_model + d_model, d_model) for _ in range(N_enc)])
self.init_weights()
def init_weights(self):
for fc_alpha in self.fc_alpha:
nn.init.xavier_uniform_(fc_alpha.weight)
nn.init.constant_(fc_alpha.bias, 0)
def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
self_att = self.self_att(input, input, input, mask_self_att)
self_att = self_att * mask_pad
enc_att = None
for i in range(self.N_enc):
enc_att_i = self.enc_att(self_att, enc_output[:, i], enc_output[:, i], mask_enc_att) * mask_pad
alpha_i = torch.sigmoid(self.fc_alpha[i](torch.cat([self_att, enc_att_i], -1)))
if enc_att is None:
enc_att = enc_att_i * alpha_i
else:
enc_att += enc_att_i * alpha_i
enc_att /= np.sqrt(self.N_enc)
enc_att *= mask_pad
ff = self.pwff(enc_att)
ff = ff * mask_pad
return ff
class DecoderLayer(Module):
def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, self_att_module=None,
enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
super(DecoderLayer, self).__init__()
self.self_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=True,
attention_module=self_att_module,
attention_module_kwargs=self_att_module_kwargs)
self.enc_att = MultiHeadAttention(d_model, d_k, d_v, h, dropout, can_be_stateful=False,
attention_module=enc_att_module,
attention_module_kwargs=enc_att_module_kwargs)
self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout)
def forward(self, input, enc_output, mask_pad, mask_self_att, mask_enc_att):
self_att = self.self_att(input, input, input, mask_self_att)
enc_att = self.enc_att(self_att, enc_output, enc_output, mask_enc_att)
ff = self.pwff(enc_att)
ff = ff * mask_pad
return ff
class MeshedDecoder(Module):
def __init__(self, vocab_size, max_len, N_dec, N_enc, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048,
dropout=.1,
self_att_module=None, enc_att_module=None, self_att_module_kwargs=None, enc_att_module_kwargs=None):
super(MeshedDecoder, self).__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
self.layers = ModuleList(
[MeshedDecoderLayer(N_enc, d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
self.fc = nn.Linear(d_model, vocab_size, bias=False)
self.max_len = max_len
self.padding_idx = padding_idx
self.N = N_dec
self.register_state('running_mask_self_attention', None)
self.register_state('running_seq', torch.zeros((1,)).long())
def forward(self, input, encoder_output_list, mask_encoder):
# input (b_s, seq_len)
input = input[:, :self.max_len]
b_s, seq_len = input.shape[:2]
if input.dtype in [torch.long, torch.int]:
input_index = input
else:
input_index = one_hot_to_index(input)
mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) # (b_s, seq_len, 1)
mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device),
diagonal=1)
mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool()
mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
if self._is_stateful:
if self.running_mask_self_attention is None:
self.running_mask_self_attention = mask_self_attention
else:
self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention],
-1)
mask_self_attention = self.running_mask_self_attention
seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
if self._is_stateful:
self.running_seq.add_(1)
seq = self.running_seq
if input.dtype in [torch.long, torch.int]:
out = self.word_emb(input)
else:
out = input @ self.word_emb.weight
out = out + self.pos_emb(seq)
for i, l in enumerate(self.layers):
out = l(out, encoder_output_list, mask_queries, mask_self_attention, mask_encoder)
out = self.fc(out)
return out
class Decoder(Module):
def __init__(self, vocab_size, max_len, N_dec, padding_idx=0, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048,
dropout=.1, self_att_module=None, enc_att_module=None, self_att_module_kwargs=None,
enc_att_module_kwargs=None):
super(Decoder, self).__init__()
self.d_model = d_model
self.vocab_size = vocab_size
self.word_emb = nn.Embedding(vocab_size, d_model, padding_idx=padding_idx)
self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, d_model, 0), freeze=True)
self.layers = ModuleList(
[DecoderLayer(d_model, d_k, d_v, h, d_ff, dropout, self_att_module=self_att_module,
enc_att_module=enc_att_module, self_att_module_kwargs=self_att_module_kwargs,
enc_att_module_kwargs=enc_att_module_kwargs) for _ in range(N_dec)])
self.fc = nn.Linear(d_model, vocab_size, bias=False)
self.max_len = max_len
self.padding_idx = padding_idx
self.N = N_dec
self.register_state('running_mask_self_attention', None)
self.register_state('running_seq', torch.zeros((1,)).long())
def forward(self, input, encoder_output, mask_encoder):
# input (b_s, seq_len)
input = input[:, :self.max_len]
b_s, seq_len = input.shape[:2]
if input.dtype in [torch.long, torch.int]:
input_index = input
else:
input_index = one_hot_to_index(input)
mask_queries = (input_index != self.padding_idx).unsqueeze(-1).type(input.dtype) # (b_s, seq_len, 1)
mask_self_attention = torch.triu(torch.ones((seq_len, seq_len), dtype=torch.bool, device=input.device),
diagonal=1)
mask_self_attention = mask_self_attention.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, seq_len)
mask_self_attention = mask_self_attention + (input_index == self.padding_idx).unsqueeze(1).unsqueeze(1).bool()
mask_self_attention = mask_self_attention.gt(0) # (b_s, 1, seq_len, seq_len)
if self._is_stateful:
if self.running_mask_self_attention is None:
self.running_mask_self_attention = mask_self_attention
else:
self.running_mask_self_attention = torch.cat([self.running_mask_self_attention, mask_self_attention],
-1)
mask_self_attention = self.running_mask_self_attention
seq = torch.arange(1, seq_len + 1).view(1, -1).expand(b_s, -1).to(input.device) # (b_s, seq_len)
seq = seq.masked_fill(mask_queries.squeeze(-1) == 0, 0)
if self._is_stateful:
self.running_seq.add_(1)
seq = self.running_seq
if input.dtype in [torch.long, torch.int]:
out = self.word_emb(input)
else:
out = input @ self.word_emb.weight
out = out + self.pos_emb(seq)
for i, l in enumerate(self.layers):
out = l(out, encoder_output, mask_queries, mask_self_attention, mask_encoder)
out = self.fc(out)
return out
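Both decoders build the same self-attention mask: an upper-triangular causal mask combined with a padding mask, where True means the position is blocked. A small standalone sketch of that construction with padding_idx = 0 and a toy batch (the decoders combine the two masks with + followed by .gt(0), which is equivalent to a logical OR):

```python
import torch

seq = torch.tensor([[5, 9, 0]])                                      # (b_s, seq_len), last token is padding
causal = torch.triu(torch.ones(3, 3, dtype=torch.bool), diagonal=1)  # (seq_len, seq_len)
pad = (seq == 0).unsqueeze(1).unsqueeze(1)                           # (b_s, 1, 1, seq_len)
mask_self_attention = causal.unsqueeze(0).unsqueeze(0) | pad         # (b_s, 1, seq_len, seq_len)
print(mask_self_attention[0, 0])
# tensor([[False,  True,  True],
#         [False, False,  True],
#         [False, False,  True]])
```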

65
models/transformer/encoders.py

@ -0,0 +1,65 @@
from torch.nn import functional as F
from models.transformer.utils import sinusoid_encoding_table, PositionWiseFeedForward
import torch
from torch import nn
from models.transformer.attention import MultiHeadAttention
class EncoderLayer(nn.Module):
def __init__(self, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1, identity_map_reordering=False,
attention_module=None, attention_module_kwargs=None):
super(EncoderLayer, self).__init__()
self.identity_map_reordering = identity_map_reordering
self.mhatt = MultiHeadAttention(d_model, d_k, d_v, h, dropout, identity_map_reordering=identity_map_reordering,
attention_module=attention_module,
attention_module_kwargs=attention_module_kwargs)
self.pwff = PositionWiseFeedForward(d_model, d_ff, dropout, identity_map_reordering=identity_map_reordering)
def forward(self, queries, keys, values, attention_mask=None, attention_weights=None):
att = self.mhatt(queries, keys, values, attention_mask, attention_weights)
ff = self.pwff(att)
return ff
class Encoder(nn.Module):
def __init__(self, N, max_len, d_in, d_model=512, d_k=64, d_v=64, h=8, d_ff=2048, dropout=.1,
identity_map_reordering=False, attention_module=None, attention_module_kwargs=None,
with_pe=False, with_mesh=False):
super(Encoder, self).__init__()
self.d_in = d_in
self.d_model = d_model
self.dropout = dropout
self.layers = nn.ModuleList([EncoderLayer(d_model, d_k, d_v, h, d_ff, dropout,
identity_map_reordering=identity_map_reordering,
attention_module=attention_module,
attention_module_kwargs=attention_module_kwargs)
for _ in range(N)])
self.pos_emb = nn.Embedding.from_pretrained(sinusoid_encoding_table(max_len + 1, self.d_in, 0), freeze=True)
self.fc = nn.Linear(d_in, self.d_model)
self.dropout = nn.Dropout(p=self.dropout)
self.layer_norm = nn.LayerNorm(self.d_model)
self.with_pe = with_pe
self.with_mesh = with_mesh
def forward(self, input):
# input (b_s, seq_len, d_in)
b_s, seq_len = input.shape[:2]
seq = torch.arange(1, seq_len + 1, device=input.device).view(1, -1).expand(b_s, -1) # (b_s, seq_len)
out = input
if self.with_pe:
out = out + self.pos_emb(seq)
out = F.relu(self.fc(out))
out = self.dropout(out)
out = self.layer_norm(out)
outs = list()
for l in self.layers:
out = l(out, out, out)
if self.with_mesh:
outs.append(out.unsqueeze(1))
if self.with_mesh:
outs = torch.cat(outs, 1)
return outs, None
return out, None
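A minimal sketch of the encoder's two output modes, with random features standing in for the image features the captioner feeds it (all sizes are illustrative; the repository root is assumed to be on the import path):

```python
import torch
from models.transformer.encoders import Encoder

feats = torch.rand(2, 10, 640)                                   # (b_s, seq_len, d_in)
meshed = Encoder(N=3, max_len=100, d_in=640, with_pe=True, with_mesh=True)
out, mask = meshed(feats)
print(out.shape, mask)                                           # torch.Size([2, 3, 10, 512]) None
flat = Encoder(N=3, max_len=100, d_in=640)
print(flat(feats)[0].shape)                                      # torch.Size([2, 10, 512])
```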

50
models/transformer/utils.py

@ -0,0 +1,50 @@
import torch
from torch import nn
from torch.nn import functional as F
def position_embedding(input, d_model):
input = input.view(-1, 1)
dim = torch.arange(d_model // 2, dtype=input.dtype, device=input.device).view(1, -1)
sin = torch.sin(input / 10000 ** (2 * dim / d_model))
cos = torch.cos(input / 10000 ** (2 * dim / d_model))
out = torch.zeros((input.shape[0], d_model), device=input.device)
out[:, ::2] = sin
out[:, 1::2] = cos
return out
def sinusoid_encoding_table(max_len, d_model, padding_idx=None, dtype=torch.float32):
pos = torch.arange(max_len, dtype=dtype)
out = position_embedding(pos, d_model)
if padding_idx is not None:
out[padding_idx] = 0
return out
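A quick check of the helper above: the table has one row per position, sin in the even columns, cos in the odd columns, and the padding row zeroed so padded tokens carry no positional signal:

```python
import torch
from models.transformer.utils import sinusoid_encoding_table

table = sinusoid_encoding_table(max_len=6, d_model=8, padding_idx=0)
print(table.shape)  # torch.Size([6, 8])
print(table[0])     # all zeros: reserved for the padding position
```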
class PositionWiseFeedForward(nn.Module):
"""
Position-wise feed forward layer
"""
def __init__(self, d_model=512, d_ff=2048, dropout=.1, identity_map_reordering=False):
super(PositionWiseFeedForward, self).__init__()
self.identity_map_reordering = identity_map_reordering
self.fc1 = nn.Linear(d_model, d_ff)
self.fc2 = nn.Linear(d_ff, d_model)
self.dropout = nn.Dropout(p=dropout)
self.dropout_2 = nn.Dropout(p=dropout)
self.layer_norm = nn.LayerNorm(d_model)
def forward(self, input):
if self.identity_map_reordering:
out = self.layer_norm(input)
out = self.fc2(self.dropout_2(F.relu(self.fc1(out))))
out = input + self.dropout(torch.relu(out))
else:
out = self.fc2(self.dropout_2(F.relu(self.fc1(input))))
out = self.dropout(out)
out = self.layer_norm(input + out)
return out
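PositionWiseFeedForward is applied independently at every position; a minimal shape check with random inputs and the default sizes:

```python
import torch
from models.transformer.utils import PositionWiseFeedForward

pwff = PositionWiseFeedForward(d_model=512, d_ff=2048, dropout=0.1)
x = torch.rand(2, 10, 512)
print(pwff(x).shape)  # torch.Size([2, 10, 512])
```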

12
models/utils.py

@ -0,0 +1,12 @@
import torch
from torch import Tensor
def one_hot_to_index(one_hot: Tensor) -> Tensor:
"""
Converts a one-hot tensor into a tensor with corresponding indexes
"""
device, dtype = one_hot.device, one_hot.dtype
vocab_size = one_hot.shape[-1]
oh2idx = torch.tensor(range(vocab_size), dtype=dtype, device=device)
return (one_hot @ oh2idx.unsqueeze(dim=1)).long().squeeze(dim=-1)
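A toy example of the conversion, assuming each row is an exact one-hot vector:

```python
import torch
from models.utils import one_hot_to_index

one_hot = torch.eye(5)[[2, 0, 4]]  # (3, 5) one-hot rows
print(one_hot_to_index(one_hot))   # tensor([2, 0, 4])
```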