You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
318 lines
12 KiB
318 lines
12 KiB
2 years ago
|
from functools import partial
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
from timm.models.efficientnet import tf_efficientnet_b7_ns
|
||
|
from torch import nn
|
||
|
from torch.nn.modules.dropout import Dropout
|
||
|
from torch.nn.modules.linear import Linear
|
||
|
from torch.nn.modules.pooling import AdaptiveAvgPool2d
|
||
|
#from facebook_deit import deit_base_patch16_224, deit_distill_large_patch16_384, deit_distill_large_patch32_384
|
||
|
#from taming_transformer import Decoder, VUNet, ActNorm
|
||
|
import functools
|
||
|
#from vit_pytorch.distill import DistillableViT, DistillWrapper, DistillableEfficientViT
|
||
|
import re
|
||
|
|
||
|
encoder_params = {
|
||
|
"tf_efficientnet_b7_ns": {
|
||
|
"features": 2560,
|
||
|
"init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
class GlobalWeightedAvgPool2d(nn.Module):
|
||
|
"""
|
||
|
Global Weighted Average Pooling from paper "Global Weighted Average
|
||
|
Pooling Bridges Pixel-level Localization and Image-level Classification"
|
||
|
"""
|
||
|
|
||
|
def __init__(self, features: int, flatten=False):
|
||
|
super().__init__()
|
||
|
self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True)
|
||
|
self.flatten = flatten
|
||
|
|
||
|
def fscore(self, x):
|
||
|
m = self.conv(x)
|
||
|
m = m.sigmoid().exp()
|
||
|
return m
|
||
|
|
||
|
def norm(self, x: torch.Tensor):
|
||
|
return x / x.sum(dim=[2, 3], keepdim=True)
|
||
|
|
||
|
def forward(self, x):
|
||
|
input_x = x
|
||
|
x = self.fscore(x)
|
||
|
x = self.norm(x)
|
||
|
x = x * input_x
|
||
|
x = x.sum(dim=[2, 3], keepdim=not self.flatten)
|
||
|
return x
|
||
|
|
||
|
|
||
|
class DeepFakeClassifier(nn.Module):
|
||
|
def __init__(self, encoder, dropout_rate=0.0) -> None:
|
||
|
super().__init__()
|
||
|
self.encoder = encoder_params[encoder]["init_op"]()
|
||
|
self.avg_pool = AdaptiveAvgPool2d((1, 1))
|
||
|
self.dropout = Dropout(dropout_rate)
|
||
|
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.encoder.forward_features(x)
|
||
|
x = self.avg_pool(x).flatten(1)
|
||
|
x = self.dropout(x)
|
||
|
x = self.fc(x)
|
||
|
return x
|
||
|
|
||
|
class NLayerDiscriminator(nn.Module):
|
||
|
"""Defines a PatchGAN discriminator as in Pix2Pix
|
||
|
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
|
||
|
"""
|
||
|
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
|
||
|
"""Construct a PatchGAN discriminator
|
||
|
Parameters:
|
||
|
input_nc (int) -- the number of channels in input images
|
||
|
ndf (int) -- the number of filters in the last conv layer
|
||
|
n_layers (int) -- the number of conv layers in the discriminator
|
||
|
norm_layer -- normalization layer
|
||
|
"""
|
||
|
super(NLayerDiscriminator, self).__init__()
|
||
|
if not use_actnorm:
|
||
|
norm_layer = nn.BatchNorm2d
|
||
|
else:
|
||
|
norm_layer = ActNorm
|
||
|
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters
|
||
|
use_bias = norm_layer.func != nn.BatchNorm2d
|
||
|
else:
|
||
|
use_bias = norm_layer != nn.BatchNorm2d
|
||
|
|
||
|
kw = 4
|
||
|
padw = 1
|
||
|
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
|
||
|
nf_mult = 1
|
||
|
nf_mult_prev = 1
|
||
|
for n in range(1, n_layers): # gradually increase the number of filters
|
||
|
nf_mult_prev = nf_mult
|
||
|
nf_mult = min(2 ** n, 8)
|
||
|
sequence += [
|
||
|
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
||
|
norm_layer(ndf * nf_mult),
|
||
|
nn.LeakyReLU(0.2, True)
|
||
|
]
|
||
|
|
||
|
nf_mult_prev = nf_mult
|
||
|
nf_mult = min(2 ** n_layers, 8)
|
||
|
sequence += [
|
||
|
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
||
|
norm_layer(ndf * nf_mult),
|
||
|
nn.LeakyReLU(0.2, True)
|
||
|
]
|
||
|
|
||
|
sequence += [
|
||
|
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
||
|
self.main = nn.Sequential(*sequence)
|
||
|
|
||
|
def forward(self, input):
|
||
|
"""Standard forward."""
|
||
|
return self.main(input)
|
||
|
|
||
|
class Discriminator(nn.Module):
|
||
|
def __init__(self, channel = 3, n_strided=6):
|
||
|
super(Discriminator, self).__init__()
|
||
|
self.main = nn.Sequential(
|
||
|
nn.Conv2d(channel, 64, 4, 2, 1, bias=False), #384 -> 192
|
||
|
nn.LeakyReLU(0.2, inplace=True),
|
||
|
nn.Conv2d(64, 128, 4, 2, 1, bias=False), #192->96
|
||
|
nn.BatchNorm2d(128),
|
||
|
nn.LeakyReLU(0.2, inplace=True),
|
||
|
nn.Conv2d(128, 256, 4, 2, 1, bias=False), # 96->48
|
||
|
nn.BatchNorm2d(256),
|
||
|
nn.LeakyReLU(0.2, inplace=True),
|
||
|
nn.Conv2d(256, 512, 4, 2, 1, bias=False), #48->24
|
||
|
nn.BatchNorm2d(512),
|
||
|
nn.LeakyReLU(0.2, inplace=True),
|
||
|
nn.Conv2d(512, 1024, 4, 2, 1, bias=False), #24->12
|
||
|
nn.BatchNorm2d(1024),
|
||
|
nn.LeakyReLU(0.2, inplace=True),
|
||
|
nn.Conv2d(1024, 1, 4, 2, 1, bias=False), #12->6
|
||
|
)
|
||
|
self.last = nn.Sequential(
|
||
|
#(B, 6*6)
|
||
|
nn.Linear(6*6, 1),
|
||
|
#nn.Sigmoid()
|
||
|
)
|
||
|
|
||
|
def discriminator_block(in_filters, out_filters):
|
||
|
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1), nn.LeakyReLU(0.01)]
|
||
|
return layers
|
||
|
|
||
|
layers = discriminator_block(channel, 32)
|
||
|
curr_dim = 32
|
||
|
for _ in range(n_strided-1):
|
||
|
layers.extend(discriminator_block(curr_dim, curr_dim*2))
|
||
|
curr_dim *= 2
|
||
|
layers.extend(discriminator_block(curr_dim,curr_dim))
|
||
|
self.model = nn.Sequential(*layers)
|
||
|
self.out1 = nn.Conv2d(curr_dim, 1, 3, stride=1, padding=0, bias=False)
|
||
|
def forward(self, x):
|
||
|
#x = self.main(x).view(-1,6*6)
|
||
|
feature_repr = self.model(x)
|
||
|
x = self.out1(feature_repr)
|
||
|
return x.view(-1, 1)#self.last(x)
|
||
|
|
||
|
##############################
|
||
|
# RESNET
|
||
|
##############################
|
||
|
|
||
|
|
||
|
class ResidualBlock(nn.Module):
|
||
|
def __init__(self, in_features):
|
||
|
super(ResidualBlock, self).__init__()
|
||
|
|
||
|
conv_block = [
|
||
|
nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
|
||
|
nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
|
||
|
nn.ReLU(inplace=True),
|
||
|
nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False),
|
||
|
nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True),
|
||
|
]
|
||
|
|
||
|
self.conv_block = nn.Sequential(*conv_block)
|
||
|
|
||
|
def forward(self, x):
|
||
|
return x + self.conv_block(x)
|
||
|
|
||
|
class Pre_training(nn.Module):
|
||
|
def __init__(self, encoder, channel=3, res_blocks=5, dropout_rate=0.0, patch_size=16) -> None:
|
||
|
super().__init__()
|
||
|
self.encoder = encoder_params[encoder]["init_op"]()
|
||
|
self.emb_ch = encoder_params[encoder]["features"]
|
||
|
|
||
|
'''
|
||
|
self.teacher = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
|
||
|
checkpoint = torch.load('weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36', map_location='cpu')
|
||
|
state_dict = checkpoint.get("state_dict", checkpoint)
|
||
|
self.teacher.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False)
|
||
|
'''
|
||
|
'''
|
||
|
self.deconv = nn.Sequential(
|
||
|
nn.Conv2d(self.emb_ch, self.emb_ch//2, kernel_size=3, stride=1, padding=1),
|
||
|
nn.BatchNorm2d(self.emb_ch // 2),
|
||
|
nn.ReLU(True),
|
||
|
nn.Conv2d(self.emb_ch//2, self.emb_ch //4, kernel_size=3, stride=1, padding=1),
|
||
|
nn.BatchNorm2d(self.emb_ch //4),
|
||
|
nn.ReLU(True),
|
||
|
)
|
||
|
'''
|
||
|
'''
|
||
|
self.deconv = nn.Sequential(
|
||
|
nn.ConvTranspose2d(self.emb_ch, self.emb_ch//2 , kernel_size=4, stride=2, padding=1, bias=False),
|
||
|
nn.BatchNorm2d(self.emb_ch//2),
|
||
|
nn.ReLU(True),
|
||
|
nn.ConvTranspose2d(self.emb_ch//2, self.emb_ch // 4, kernel_size=4, stride=2, padding=1, bias=False),
|
||
|
nn.BatchNorm2d(self.emb_ch // 4),
|
||
|
nn.ReLU(True),
|
||
|
nn.ConvTranspose2d(self.emb_ch//4, self.emb_ch // 8, kernel_size=4, stride=2, padding=1, bias=False),
|
||
|
nn.BatchNorm2d(self.emb_ch // 8),
|
||
|
nn.ReLU(True),
|
||
|
nn.ConvTranspose2d(self.emb_ch//8, channel, kernel_size=4, stride=2, padding=1, bias=False),
|
||
|
nn.Tanh()
|
||
|
)
|
||
|
'''
|
||
|
#self.deconv = nn.ConvTranspose2d(self.emb_ch, 3, kernel_size=16, stride=16)
|
||
|
#self.decoder = Decoder(double_z = False, z_channels = 1024, resolution= 384, in_channels=3, out_ch=3, ch=64
|
||
|
# , ch_mult=[1,1,2,2], num_res_blocks = 0, attn_resolutions=[16], dropout=0.0)
|
||
|
#nn.ConvTranspose2d(encoder_params[encoder]["features"], channel, kernel_size=patch_size, stride=patch_size)
|
||
|
channels = self.emb_ch
|
||
|
model = [
|
||
|
nn.ConvTranspose2d(channels, channels, 7, stride=1, padding=3, bias=False),
|
||
|
nn.InstanceNorm2d(channels, affine=True, track_running_stats=True),
|
||
|
nn.ReLU(inplace=True),
|
||
|
]
|
||
|
curr_dim = channels
|
||
|
|
||
|
for _ in range(2):
|
||
|
model+=[
|
||
|
nn.ConvTranspose2d(curr_dim, curr_dim//2, 4, stride=2, padding=1, bias=False),
|
||
|
nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True),
|
||
|
nn.ReLU(inplace=True),
|
||
|
]
|
||
|
curr_dim //= 2
|
||
|
|
||
|
#Residual blocks
|
||
|
for _ in range(res_blocks):
|
||
|
model += [ResidualBlock(curr_dim)]
|
||
|
#Upsampling
|
||
|
for _ in range(2):
|
||
|
model += [
|
||
|
nn.ConvTranspose2d(curr_dim, curr_dim//2, 4, stride=2, padding=1, bias=False),
|
||
|
nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True),
|
||
|
nn.ReLU(inplace=True),
|
||
|
]
|
||
|
curr_dim = curr_dim //2
|
||
|
#output layer
|
||
|
model += [nn.Conv2d(curr_dim, channel, 7, stride=1, padding=3), nn.Tanh()]
|
||
|
self.model = nn.Sequential(*model)
|
||
|
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
||
|
self.dropout = Dropout(dropout_rate)
|
||
|
'''
|
||
|
def generator(self, x, freeze):
|
||
|
if freeze:
|
||
|
with torch.no_grad():
|
||
|
_, z = self.encoder.pre_training(x)
|
||
|
for param in self.encoder.parameters():
|
||
|
param.requires_grad = False
|
||
|
else:
|
||
|
#with torch.enable_grad():
|
||
|
for param in self.encoder.parameters():
|
||
|
param.requires_grad = True
|
||
|
_, z = self.encoder.pre_training(x)
|
||
|
x = self.model(z)
|
||
|
return x
|
||
|
def discriminator(self, x ,freeze):
|
||
|
if freeze:
|
||
|
with torch.no_grad():
|
||
|
cls_token, _ = self.encoder.pre_training(x)
|
||
|
for param in self.encoder.parameters():
|
||
|
param.requires_grad = False
|
||
|
else:
|
||
|
#with torch.enable_grad():
|
||
|
for param in self.encoder.parameters():
|
||
|
param.requires_grad = True
|
||
|
cls_token, _ = self.encoder.pre_training(x)
|
||
|
x = self.dropout(cls_token)
|
||
|
cls = self.fc(x)
|
||
|
return cls
|
||
|
'''
|
||
|
def get_class(self,x):
|
||
|
for param in self.teacher.parameters():
|
||
|
param.requires_grad = False
|
||
|
teacher_logits = self.teacher(x)
|
||
|
return teacher_logits
|
||
|
|
||
|
def forward(self, x):
|
||
|
cls_token, z = self.encoder.pre_training(x)
|
||
|
#with torch.no_grad():
|
||
|
# teacher_logits = self.teacher(x)
|
||
|
#x = self.deconv(x)
|
||
|
#x = self.decoder(x)
|
||
|
#cls = self.dropout(cls_token)
|
||
|
#cls_token = self.fc(cls)
|
||
|
|
||
|
x = self.model(z)
|
||
|
return x#, cls_token, teacher_logits#, labels
|
||
|
|
||
|
class DeepFakeClassifierGWAP(nn.Module):
|
||
|
def __init__(self, encoder, dropout_rate=0.5) -> None:
|
||
|
super().__init__()
|
||
|
self.encoder = encoder_params[encoder]["init_op"]()
|
||
|
self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"])
|
||
|
self.dropout = Dropout(dropout_rate)
|
||
|
self.fc = Linear(encoder_params[encoder]["features"], 1)
|
||
|
|
||
|
def forward(self, x):
|
||
|
x = self.encoder.forward_features(x)
|
||
|
x = self.avg_pool(x).flatten(1)
|
||
|
x = self.dropout(x)
|
||
|
x = self.fc(x)
|
||
|
return x
|