haoyuanhuang
2 years ago
commit
fb1a64c7b1
12 changed files with 887 additions and 0 deletions
@@ -0,0 +1,29 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bin.* filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zstandard filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
weights filter=lfs diff=lfs merge=lfs -text
weights/** filter=lfs diff=lfs merge=lfs -text
@@ -0,0 +1,53 @@
# Deepfake

*author: Zhuoran Yu*

<br />

## Description

Deepfake techniques, which present realistic AI-generated videos of people doing and saying fictional things, can significantly affect how people judge the legitimacy of information presented online.

This operator predicts the probability that a given video is fake. It is an adaptation of [DeepfakeDetection](https://github.com/smu-ivpl/DeepfakeDetection).

<br />

## Code Example

Load videos from the path '/home/test_video' and use the deepfake operator to predict the probability that each video is fake.

```python
import towhee

(
    towhee.glob['path']('/home/test_video')
          .deepfake['path', 'scores']()
          .select['path', 'scores']()
          .show()
)
```

<img src="./deepfake.png" height="100px"/>

```shell
[0.9893, 0.9097]
```

<br />

## Interface

A deepfake operator takes video paths as input.
It predicts the probability that each video is fake. The higher the score, the more likely the video is fake; a video scoring above 0.5 can be considered fake.

**Parameters:**

***filepath:*** *str*

Absolute path of the test video(s).

**Returns:** *list*

The probabilities of the videos being fake.
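
The operator class can also be instantiated and called directly. Below is a minimal sketch based on the commented-out test snippet at the bottom of the operator file; the video path is hypothetical, and the import assumes `deepfake.py` is importable from the working directory.

```python
from deepfake import Deepfake  # assumes this repo's deepfake.py is on the Python path

op = Deepfake()
pred = op(filepath='/home/test_video/sample.mp4')  # hypothetical absolute path to a .mp4 file
print(pred)  # a score above 0.5 suggests the video is likely fake
```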
@@ -0,0 +1,13 @@
# Copyright 2021 Zilliz. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,317 @@
|||
from functools import partial |
|||
|
|||
import numpy as np |
|||
import torch |
|||
from timm.models.efficientnet import tf_efficientnet_b7_ns |
|||
from torch import nn |
|||
from torch.nn.modules.dropout import Dropout |
|||
from torch.nn.modules.linear import Linear |
|||
from torch.nn.modules.pooling import AdaptiveAvgPool2d |
|||
#from facebook_deit import deit_base_patch16_224, deit_distill_large_patch16_384, deit_distill_large_patch32_384 |
|||
#from taming_transformer import Decoder, VUNet, ActNorm |
|||
import functools |
|||
#from vit_pytorch.distill import DistillableViT, DistillWrapper, DistillableEfficientViT |
|||
import re |
|||
|
|||
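# Registry of supported backbones: each entry gives the encoder's feature width and a constructor for the pretrained timm model.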
encoder_params = { |
|||
"tf_efficientnet_b7_ns": { |
|||
"features": 2560, |
|||
"init_op": partial(tf_efficientnet_b7_ns, pretrained=True, drop_path_rate=0.2) |
|||
} |
|||
} |
|||
|
|||
class GlobalWeightedAvgPool2d(nn.Module): |
|||
""" |
|||
Global Weighted Average Pooling from paper "Global Weighted Average |
|||
Pooling Bridges Pixel-level Localization and Image-level Classification" |
|||
""" |
|||
|
|||
def __init__(self, features: int, flatten=False): |
|||
super().__init__() |
|||
self.conv = nn.Conv2d(features, 1, kernel_size=1, bias=True) |
|||
self.flatten = flatten |
|||
|
|||
def fscore(self, x): |
|||
m = self.conv(x) |
|||
m = m.sigmoid().exp() |
|||
return m |
|||
|
|||
def norm(self, x: torch.Tensor): |
|||
return x / x.sum(dim=[2, 3], keepdim=True) |
|||
|
|||
def forward(self, x): |
|||
input_x = x |
|||
x = self.fscore(x) |
|||
x = self.norm(x) |
|||
x = x * input_x |
|||
x = x.sum(dim=[2, 3], keepdim=not self.flatten) |
|||
return x |
|||
|
|||
|
|||
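# Binary deepfake classifier: pretrained EfficientNet-B7 (noisy-student) features -> global average pooling -> dropout -> single logit (the caller applies torch.sigmoid, see kernel_utils.predict_on_video).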
class DeepFakeClassifier(nn.Module): |
|||
def __init__(self, encoder, dropout_rate=0.0) -> None: |
|||
super().__init__() |
|||
self.encoder = encoder_params[encoder]["init_op"]() |
|||
self.avg_pool = AdaptiveAvgPool2d((1, 1)) |
|||
self.dropout = Dropout(dropout_rate) |
|||
self.fc = Linear(encoder_params[encoder]["features"], 1) |
|||
|
|||
def forward(self, x): |
|||
x = self.encoder.forward_features(x) |
|||
x = self.avg_pool(x).flatten(1) |
|||
x = self.dropout(x) |
|||
x = self.fc(x) |
|||
return x |
|||
|
|||
class NLayerDiscriminator(nn.Module): |
|||
"""Defines a PatchGAN discriminator as in Pix2Pix |
|||
--> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py |
|||
""" |
|||
def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False): |
|||
"""Construct a PatchGAN discriminator |
|||
Parameters: |
|||
input_nc (int) -- the number of channels in input images |
|||
ndf (int) -- the number of filters in the last conv layer |
|||
n_layers (int) -- the number of conv layers in the discriminator |
|||
norm_layer -- normalization layer |
|||
""" |
|||
super(NLayerDiscriminator, self).__init__() |
|||
if not use_actnorm: |
|||
norm_layer = nn.BatchNorm2d |
|||
else: |
|||
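# ActNorm is defined in the taming_transformer import that is commented out above, so use_actnorm=True currently raises a NameError unless that import is restored.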
norm_layer = ActNorm |
|||
if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters |
|||
use_bias = norm_layer.func != nn.BatchNorm2d |
|||
else: |
|||
use_bias = norm_layer != nn.BatchNorm2d |
|||
|
|||
kw = 4 |
|||
padw = 1 |
|||
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)] |
|||
nf_mult = 1 |
|||
nf_mult_prev = 1 |
|||
for n in range(1, n_layers): # gradually increase the number of filters |
|||
nf_mult_prev = nf_mult |
|||
nf_mult = min(2 ** n, 8) |
|||
sequence += [ |
|||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), |
|||
norm_layer(ndf * nf_mult), |
|||
nn.LeakyReLU(0.2, True) |
|||
] |
|||
|
|||
nf_mult_prev = nf_mult |
|||
nf_mult = min(2 ** n_layers, 8) |
|||
sequence += [ |
|||
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), |
|||
norm_layer(ndf * nf_mult), |
|||
nn.LeakyReLU(0.2, True) |
|||
] |
|||
|
|||
sequence += [ |
|||
nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map |
|||
self.main = nn.Sequential(*sequence) |
|||
|
|||
def forward(self, input): |
|||
"""Standard forward.""" |
|||
return self.main(input) |
|||
|
|||
class Discriminator(nn.Module): |
|||
def __init__(self, channel = 3, n_strided=6): |
|||
super(Discriminator, self).__init__() |
|||
self.main = nn.Sequential( |
|||
nn.Conv2d(channel, 64, 4, 2, 1, bias=False), #384 -> 192 |
|||
nn.LeakyReLU(0.2, inplace=True), |
|||
nn.Conv2d(64, 128, 4, 2, 1, bias=False), #192->96 |
|||
nn.BatchNorm2d(128), |
|||
nn.LeakyReLU(0.2, inplace=True), |
|||
nn.Conv2d(128, 256, 4, 2, 1, bias=False), # 96->48 |
|||
nn.BatchNorm2d(256), |
|||
nn.LeakyReLU(0.2, inplace=True), |
|||
nn.Conv2d(256, 512, 4, 2, 1, bias=False), #48->24 |
|||
nn.BatchNorm2d(512), |
|||
nn.LeakyReLU(0.2, inplace=True), |
|||
nn.Conv2d(512, 1024, 4, 2, 1, bias=False), #24->12 |
|||
nn.BatchNorm2d(1024), |
|||
nn.LeakyReLU(0.2, inplace=True), |
|||
nn.Conv2d(1024, 1, 4, 2, 1, bias=False), #12->6 |
|||
) |
|||
self.last = nn.Sequential( |
|||
#(B, 6*6) |
|||
nn.Linear(6*6, 1), |
|||
#nn.Sigmoid() |
|||
) |
|||
|
|||
def discriminator_block(in_filters, out_filters): |
|||
layers = [nn.Conv2d(in_filters, out_filters, 4, stride=2, padding=1), nn.LeakyReLU(0.01)] |
|||
return layers |
|||
|
|||
layers = discriminator_block(channel, 32) |
|||
curr_dim = 32 |
|||
for _ in range(n_strided-1): |
|||
layers.extend(discriminator_block(curr_dim, curr_dim*2)) |
|||
curr_dim *= 2 |
|||
layers.extend(discriminator_block(curr_dim,curr_dim)) |
|||
self.model = nn.Sequential(*layers) |
|||
self.out1 = nn.Conv2d(curr_dim, 1, 3, stride=1, padding=0, bias=False) |
|||
def forward(self, x): |
|||
#x = self.main(x).view(-1,6*6) |
|||
feature_repr = self.model(x) |
|||
x = self.out1(feature_repr) |
|||
return x.view(-1, 1)#self.last(x) |
|||
|
|||
############################## |
|||
# RESNET |
|||
############################## |
|||
|
|||
|
|||
class ResidualBlock(nn.Module): |
|||
def __init__(self, in_features): |
|||
super(ResidualBlock, self).__init__() |
|||
|
|||
conv_block = [ |
|||
nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False), |
|||
nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True), |
|||
nn.ReLU(inplace=True), |
|||
nn.Conv2d(in_features, in_features, 3, stride=1, padding=1, bias=False), |
|||
nn.InstanceNorm2d(in_features, affine=True, track_running_stats=True), |
|||
] |
|||
|
|||
self.conv_block = nn.Sequential(*conv_block) |
|||
|
|||
def forward(self, x): |
|||
return x + self.conv_block(x) |
|||
|
|||
class Pre_training(nn.Module): |
|||
def __init__(self, encoder, channel=3, res_blocks=5, dropout_rate=0.0, patch_size=16) -> None: |
|||
super().__init__() |
|||
self.encoder = encoder_params[encoder]["init_op"]() |
|||
self.emb_ch = encoder_params[encoder]["features"] |
|||
|
|||
''' |
|||
self.teacher = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda") |
|||
checkpoint = torch.load('weights/final_111_DeepFakeClassifier_tf_efficientnet_b7_ns_0_36', map_location='cpu') |
|||
state_dict = checkpoint.get("state_dict", checkpoint) |
|||
self.teacher.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False) |
|||
''' |
|||
''' |
|||
self.deconv = nn.Sequential( |
|||
nn.Conv2d(self.emb_ch, self.emb_ch//2, kernel_size=3, stride=1, padding=1), |
|||
nn.BatchNorm2d(self.emb_ch // 2), |
|||
nn.ReLU(True), |
|||
nn.Conv2d(self.emb_ch//2, self.emb_ch //4, kernel_size=3, stride=1, padding=1), |
|||
nn.BatchNorm2d(self.emb_ch //4), |
|||
nn.ReLU(True), |
|||
) |
|||
''' |
|||
''' |
|||
self.deconv = nn.Sequential( |
|||
nn.ConvTranspose2d(self.emb_ch, self.emb_ch//2 , kernel_size=4, stride=2, padding=1, bias=False), |
|||
nn.BatchNorm2d(self.emb_ch//2), |
|||
nn.ReLU(True), |
|||
nn.ConvTranspose2d(self.emb_ch//2, self.emb_ch // 4, kernel_size=4, stride=2, padding=1, bias=False), |
|||
nn.BatchNorm2d(self.emb_ch // 4), |
|||
nn.ReLU(True), |
|||
nn.ConvTranspose2d(self.emb_ch//4, self.emb_ch // 8, kernel_size=4, stride=2, padding=1, bias=False), |
|||
nn.BatchNorm2d(self.emb_ch // 8), |
|||
nn.ReLU(True), |
|||
nn.ConvTranspose2d(self.emb_ch//8, channel, kernel_size=4, stride=2, padding=1, bias=False), |
|||
nn.Tanh() |
|||
) |
|||
''' |
|||
#self.deconv = nn.ConvTranspose2d(self.emb_ch, 3, kernel_size=16, stride=16) |
|||
#self.decoder = Decoder(double_z = False, z_channels = 1024, resolution= 384, in_channels=3, out_ch=3, ch=64 |
|||
# , ch_mult=[1,1,2,2], num_res_blocks = 0, attn_resolutions=[16], dropout=0.0) |
|||
#nn.ConvTranspose2d(encoder_params[encoder]["features"], channel, kernel_size=patch_size, stride=patch_size) |
|||
channels = self.emb_ch |
|||
model = [ |
|||
nn.ConvTranspose2d(channels, channels, 7, stride=1, padding=3, bias=False), |
|||
nn.InstanceNorm2d(channels, affine=True, track_running_stats=True), |
|||
nn.ReLU(inplace=True), |
|||
] |
|||
curr_dim = channels |
|||
|
|||
for _ in range(2): |
|||
model+=[ |
|||
nn.ConvTranspose2d(curr_dim, curr_dim//2, 4, stride=2, padding=1, bias=False), |
|||
nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True), |
|||
nn.ReLU(inplace=True), |
|||
] |
|||
curr_dim //= 2 |
|||
|
|||
#Residual blocks |
|||
for _ in range(res_blocks): |
|||
model += [ResidualBlock(curr_dim)] |
|||
#Upsampling |
|||
for _ in range(2): |
|||
model += [ |
|||
nn.ConvTranspose2d(curr_dim, curr_dim//2, 4, stride=2, padding=1, bias=False), |
|||
nn.InstanceNorm2d(curr_dim//2, affine=True, track_running_stats=True), |
|||
nn.ReLU(inplace=True), |
|||
] |
|||
curr_dim = curr_dim //2 |
|||
#output layer |
|||
model += [nn.Conv2d(curr_dim, channel, 7, stride=1, padding=3), nn.Tanh()] |
|||
self.model = nn.Sequential(*model) |
|||
self.fc = Linear(encoder_params[encoder]["features"], 1) |
|||
self.dropout = Dropout(dropout_rate) |
|||
''' |
|||
def generator(self, x, freeze): |
|||
if freeze: |
|||
with torch.no_grad(): |
|||
_, z = self.encoder.pre_training(x) |
|||
for param in self.encoder.parameters(): |
|||
param.requires_grad = False |
|||
else: |
|||
#with torch.enable_grad(): |
|||
for param in self.encoder.parameters(): |
|||
param.requires_grad = True |
|||
_, z = self.encoder.pre_training(x) |
|||
x = self.model(z) |
|||
return x |
|||
def discriminator(self, x ,freeze): |
|||
if freeze: |
|||
with torch.no_grad(): |
|||
cls_token, _ = self.encoder.pre_training(x) |
|||
for param in self.encoder.parameters(): |
|||
param.requires_grad = False |
|||
else: |
|||
#with torch.enable_grad(): |
|||
for param in self.encoder.parameters(): |
|||
param.requires_grad = True |
|||
cls_token, _ = self.encoder.pre_training(x) |
|||
x = self.dropout(cls_token) |
|||
cls = self.fc(x) |
|||
return cls |
|||
''' |
|||
def get_class(self,x): |
|||
for param in self.teacher.parameters(): |
|||
param.requires_grad = False |
|||
teacher_logits = self.teacher(x) |
|||
return teacher_logits |
|||
|
|||
def forward(self, x): |
|||
cls_token, z = self.encoder.pre_training(x) |
|||
#with torch.no_grad(): |
|||
# teacher_logits = self.teacher(x) |
|||
#x = self.deconv(x) |
|||
#x = self.decoder(x) |
|||
#cls = self.dropout(cls_token) |
|||
#cls_token = self.fc(cls) |
|||
|
|||
x = self.model(z) |
|||
return x#, cls_token, teacher_logits#, labels |
|||
|
|||
class DeepFakeClassifierGWAP(nn.Module): |
|||
def __init__(self, encoder, dropout_rate=0.5) -> None: |
|||
super().__init__() |
|||
self.encoder = encoder_params[encoder]["init_op"]() |
|||
self.avg_pool = GlobalWeightedAvgPool2d(encoder_params[encoder]["features"]) |
|||
self.dropout = Dropout(dropout_rate) |
|||
self.fc = Linear(encoder_params[encoder]["features"], 1) |
|||
|
|||
def forward(self, x): |
|||
x = self.encoder.forward_features(x) |
|||
x = self.avg_pool(x).flatten(1) |
|||
x = self.dropout(x) |
|||
x = self.fc(x) |
|||
return x |
Image file added (53 KiB); preview not shown.
@@ -0,0 +1,70 @@
import argparse
import os
import re
import string
import time
import sys
from pathlib import Path
import torch
import pandas as pd

import towhee
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
import warnings
warnings.filterwarnings('ignore')
import logging
log = logging.getLogger()


@register(output_schema=["scorelist"],
          flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Deepfake(NNOperator):
    '''
    Deepfake
    '''
    def __init__(self):
        super().__init__()
        sys.path.append(str(Path(__file__).parent))
        weights_dir = os.path.join(str(Path(__file__).parent), "weights/")
        self.model_paths = [os.path.join(weights_dir, model) for model in os.listdir(weights_dir)]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def __call__(self, filepath: str) -> list:
        from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
        from classifiers import DeepFakeClassifier
        # Load every checkpoint found under weights/ and use the models as an ensemble.
        models = []
        for path in self.model_paths:
            model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
            print("loading state dict {}".format(path))
            checkpoint = torch.load(path, map_location="cpu")
            state_dict = checkpoint.get("state_dict", checkpoint)
            model.load_state_dict({re.sub(r"^module\.", "", k): v for k, v in state_dict.items()}, strict=False)
            model.eval()
            del checkpoint
            models.append(model.half())
        # Sample 32 frames per video, extract faces, and aggregate the ensemble predictions.
        frames_per_video = 32
        video_reader = VideoReader()
        video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
        face_extractor = FaceExtractor(video_read_fn)
        input_size = 384
        strategy = confident_strategy
        #stime = time.time()
        prediction = predict_on_video(False, face_extractor=face_extractor, video_path=filepath,
                                      input_size=input_size, batch_size=frames_per_video, models=models,
                                      strategy=strategy, apply_compression=False)
        '''
        test_videos = sorted([x for x in os.listdir(filepath) if x[-4:] == ".mp4"])
        print("Predicting {} videos".format(len(test_videos)))
        predictions = predict_on_video_set(False, face_extractor=face_extractor, input_size=input_size, models=models,
                                           strategy=strategy, frames_per_video=frames_per_video, videos=test_videos,
                                           num_workers=2, test_dir=filepath)
        '''
        return prediction

'''
if __name__ == "__main__":
    filepath = "/Users/zilliz/Desktop/deepfake_video/test/aagfhgtpmv.mp4"
    op = Deepfake()
    pred = op(filepath=filepath)
    print(pred)
'''
@@ -0,0 +1,390 @@
|||
import os |
|||
import math |
|||
import cv2 |
|||
import numpy as np |
|||
import torch |
|||
from PIL import Image |
|||
from albumentations.augmentations.functional import image_compression |
|||
from facenet_pytorch.models.mtcnn import MTCNN |
|||
from concurrent.futures import ThreadPoolExecutor |
|||
import matplotlib.pyplot as plt |
|||
from torchvision.transforms import Normalize |
|||
import logging |
|||
|
|||
log = logging.getLogger() |
|||
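# ImageNet channel means/stds; face crops are scaled to [0, 1] and normalized with these before inference.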
mean = [0.485, 0.456, 0.406] |
|||
std = [0.229, 0.224, 0.225] |
|||
normalize_transform = Normalize(mean, std) |
|||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|||
|
|||
class VideoReader: |
|||
"""Helper class for reading one or more frames from a video file.""" |
|||
|
|||
def __init__(self, verbose=True, insets=(0, 0)): |
|||
"""Creates a new VideoReader. |
|||
|
|||
Arguments: |
|||
verbose: whether to print warnings and error messages |
|||
insets: amount to inset the image by, as a percentage of |
|||
(width, height). This lets you "zoom in" to an image |
|||
to remove unimportant content around the borders. |
|||
Useful for face detection, which may not work if the |
|||
faces are too small. |
|||
""" |
|||
self.verbose = verbose |
|||
self.insets = insets |
|||
|
|||
def read_frames(self, path, num_frames, jitter=0, seed=None): |
|||
"""Reads frames that are always evenly spaced throughout the video. |
|||
|
|||
Arguments: |
|||
path: the video file |
|||
num_frames: how many frames to read, -1 means the entire video |
|||
(warning: this will take up a lot of memory!) |
|||
jitter: if not 0, adds small random offsets to the frame indices; |
|||
this is useful so we don't always land on even or odd frames |
|||
seed: random seed for jittering; if you set this to a fixed value, |
|||
you probably want to set it only on the first video |
|||
""" |
|||
assert num_frames > 0 |
|||
|
|||
capture = cv2.VideoCapture(path) |
|||
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) |
|||
if frame_count <= 0: return None |
|||
|
|||
frame_idxs = np.linspace(0, frame_count - 1, num_frames, endpoint=True, dtype=int)
|||
if jitter > 0: |
|||
np.random.seed(seed) |
|||
jitter_offsets = np.random.randint(-jitter, jitter, len(frame_idxs)) |
|||
frame_idxs = np.clip(frame_idxs + jitter_offsets, 0, frame_count - 1) |
|||
|
|||
result = self._read_frames_at_indices(path, capture, frame_idxs) |
|||
capture.release() |
|||
return result |
|||
|
|||
def read_random_frames(self, path, num_frames, seed=None): |
|||
"""Picks the frame indices at random. |
|||
|
|||
Arguments: |
|||
path: the video file |
|||
num_frames: how many frames to read, -1 means the entire video |
|||
(warning: this will take up a lot of memory!) |
|||
""" |
|||
assert num_frames > 0 |
|||
np.random.seed(seed) |
|||
|
|||
capture = cv2.VideoCapture(path) |
|||
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) |
|||
if frame_count <= 0: return None |
|||
|
|||
frame_idxs = sorted(np.random.choice(np.arange(0, frame_count), num_frames)) |
|||
result = self._read_frames_at_indices(path, capture, frame_idxs) |
|||
|
|||
capture.release() |
|||
return result |
|||
|
|||
def read_frames_at_indices(self, path, frame_idxs): |
|||
"""Reads frames from a video and puts them into a NumPy array. |
|||
|
|||
Arguments: |
|||
path: the video file |
|||
frame_idxs: a list of frame indices. Important: should be |
|||
sorted from low-to-high! If an index appears multiple |
|||
times, the frame is still read only once. |
|||
|
|||
Returns: |
|||
- a NumPy array of shape (num_frames, height, width, 3) |
|||
- a list of the frame indices that were read |
|||
|
|||
Reading stops if loading a frame fails, in which case the first |
|||
dimension returned may actually be less than num_frames. |
|||
|
|||
Returns None if an exception is thrown for any reason, or if no |
|||
frames were read. |
|||
""" |
|||
assert len(frame_idxs) > 0 |
|||
capture = cv2.VideoCapture(path) |
|||
result = self._read_frames_at_indices(path, capture, frame_idxs) |
|||
capture.release() |
|||
return result |
|||
|
|||
def _read_frames_at_indices(self, path, capture, frame_idxs): |
|||
try: |
|||
frames = [] |
|||
idxs_read = [] |
|||
for frame_idx in range(frame_idxs[0], frame_idxs[-1] + 1): |
|||
# Get the next frame, but don't decode if we're not using it. |
|||
ret = capture.grab() |
|||
if not ret: |
|||
if self.verbose: |
|||
log.error("Error grabbing frame %d from movie %s" % (frame_idx, path)) |
|||
break |
|||
|
|||
# Need to look at this frame? |
|||
current = len(idxs_read) |
|||
if frame_idx == frame_idxs[current]: |
|||
ret, frame = capture.retrieve() |
|||
if not ret or frame is None: |
|||
if self.verbose: |
|||
log.error("Error retrieving frame %d from movie %s" % (frame_idx, path)) |
|||
break |
|||
|
|||
frame = self._postprocess_frame(frame) |
|||
frames.append(frame) |
|||
idxs_read.append(frame_idx) |
|||
|
|||
if len(frames) > 0: |
|||
return np.stack(frames), idxs_read |
|||
if self.verbose: |
|||
log.error("No frames read from movie %s" % path) |
|||
return None |
|||
except Exception:
|||
if self.verbose: |
|||
log.error("Exception while reading movie %s" % path) |
|||
return None |
|||
|
|||
def read_middle_frame(self, path): |
|||
"""Reads the frame from the middle of the video.""" |
|||
capture = cv2.VideoCapture(path) |
|||
frame_count = int(capture.get(cv2.CAP_PROP_FRAME_COUNT)) |
|||
result = self._read_frame_at_index(path, capture, frame_count // 2) |
|||
capture.release() |
|||
return result |
|||
|
|||
def read_frame_at_index(self, path, frame_idx): |
|||
"""Reads a single frame from a video. |
|||
|
|||
If you just want to read a single frame from the video, this is more |
|||
efficient than scanning through the video to find the frame. However, |
|||
for reading multiple frames it's not efficient. |
|||
|
|||
My guess is that a "streaming" approach is more efficient than a |
|||
"random access" approach because, unless you happen to grab a keyframe, |
|||
the decoder still needs to read all the previous frames in order to |
|||
reconstruct the one you're asking for. |
|||
|
|||
Returns a NumPy array of shape (1, H, W, 3) and the index of the frame, |
|||
or None if reading failed. |
|||
""" |
|||
capture = cv2.VideoCapture(path) |
|||
result = self._read_frame_at_index(path, capture, frame_idx) |
|||
capture.release() |
|||
return result |
|||
|
|||
def _read_frame_at_index(self, path, capture, frame_idx): |
|||
capture.set(cv2.CAP_PROP_POS_FRAMES, frame_idx) |
|||
ret, frame = capture.read() |
|||
if not ret or frame is None: |
|||
if self.verbose: |
|||
log.error("Error retrieving frame %d from movie %s" % (frame_idx, path)) |
|||
return None |
|||
else: |
|||
frame = self._postprocess_frame(frame) |
|||
return np.expand_dims(frame, axis=0), [frame_idx] |
|||
|
|||
def _postprocess_frame(self, frame): |
|||
frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) |
|||
|
|||
if self.insets[0] > 0: |
|||
W = frame.shape[1] |
|||
p = int(W * self.insets[0]) |
|||
frame = frame[:, p:-p, :] |
|||
|
|||
if self.insets[1] > 0: |
|||
H = frame.shape[0]
|||
q = int(H * self.insets[1]) |
|||
frame = frame[q:-q, :, :] |
|||
|
|||
return frame |
|||
|
|||
|
|||
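# Runs MTCNN face detection on half-resolution frames and returns padded face crops (plus detection scores) for each sampled frame.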
class FaceExtractor: |
|||
def __init__(self, video_read_fn): |
|||
self.video_read_fn = video_read_fn |
|||
self.detector = MTCNN(margin=0, thresholds=[0.7, 0.8, 0.8], device=device) |
|||
|
|||
def process_videos(self, input_dir, filenames, video_idxs): |
|||
videos_read = [] |
|||
frames_read = [] |
|||
frames = [] |
|||
results = [] |
|||
for video_idx in video_idxs: |
|||
# Read the full-size frames from this video. |
|||
filename = filenames[video_idx] |
|||
video_path = os.path.join(input_dir, filename) |
|||
result = self.video_read_fn(video_path) |
|||
# Error? Then skip this video. |
|||
if result is None: continue |
|||
|
|||
videos_read.append(video_idx) |
|||
|
|||
# Keep track of the original frames (need them later). |
|||
my_frames, my_idxs = result |
|||
|
|||
frames.append(my_frames) |
|||
frames_read.append(my_idxs) |
|||
for i, frame in enumerate(my_frames): |
|||
h, w = frame.shape[:2] |
|||
img = Image.fromarray(frame.astype(np.uint8)) |
|||
img = img.resize(size=[s // 2 for s in img.size]) |
|||
|
|||
batch_boxes, probs = self.detector.detect(img, landmarks=False) |
|||
|
|||
faces = [] |
|||
scores = [] |
|||
if batch_boxes is None: |
|||
continue |
|||
for bbox, score in zip(batch_boxes, probs): |
|||
if bbox is not None: |
|||
xmin, ymin, xmax, ymax = [int(b * 2) for b in bbox] |
|||
w = xmax - xmin |
|||
h = ymax - ymin |
|||
p_h = h // 3 |
|||
p_w = w // 3 |
|||
crop = frame[max(ymin - p_h, 0):ymax + p_h, max(xmin - p_w, 0):xmax + p_w] |
|||
faces.append(crop) |
|||
scores.append(score) |
|||
|
|||
frame_dict = {"video_idx": video_idx, |
|||
"frame_idx": my_idxs[i], |
|||
"frame_w": w, |
|||
"frame_h": h, |
|||
"faces": faces, |
|||
"scores": scores} |
|||
results.append(frame_dict) |
|||
|
|||
return results |
|||
|
|||
def process_video(self, video_path): |
|||
"""Convenience method for doing face extraction on a single video.""" |
|||
input_dir = os.path.dirname(video_path) |
|||
filenames = [os.path.basename(video_path)] |
|||
return self.process_videos(input_dir, filenames, [0]) |
|||
|
|||
|
|||
|
|||
def confident_strategy(pred, t=0.8): |
|||
pred = np.array(pred) |
|||
sz = len(pred) |
|||
fakes = np.count_nonzero(pred > t) |
|||
# Treat the video as fake when enough frames score above t (more than 40% of the frames and more than 11 frames)
|||
if fakes > sz // 2.5 and fakes > 11: |
|||
return np.mean(pred[pred > t]) |
|||
elif np.count_nonzero(pred < 0.2) > 0.9 * sz: |
|||
return np.mean(pred[pred < 0.2]) |
|||
else: |
|||
return np.mean(pred) |
|||
|
|||
strategy = confident_strategy |
|||
|
|||
|
|||
def put_to_center(img, input_size): |
|||
img = img[:input_size, :input_size] |
|||
image = np.zeros((input_size, input_size, 3), dtype=np.uint8) |
|||
start_w = (input_size - img.shape[1]) // 2 |
|||
start_h = (input_size - img.shape[0]) // 2 |
|||
image[start_h:start_h + img.shape[0], start_w: start_w + img.shape[1], :] = img |
|||
return image |
|||
|
|||
|
|||
def isotropically_resize_image(img, size, interpolation_down=cv2.INTER_AREA, interpolation_up=cv2.INTER_CUBIC): |
|||
h, w = img.shape[:2] |
|||
if max(w, h) == size: |
|||
return img |
|||
if w > h: |
|||
scale = size / w |
|||
h = h * scale |
|||
w = size |
|||
else: |
|||
scale = size / h |
|||
w = w * scale |
|||
h = size |
|||
interpolation = interpolation_up if scale > 1 else interpolation_down |
|||
resized = cv2.resize(img, (int(w), int(h)), interpolation=interpolation) |
|||
return resized |
|||
|
|||
def dist(p1, p2): |
|||
return math.sqrt((p1[0] - p2[0]) ** 2 + (p1[1] - p2[1]) ** 2) |
|||
|
|||
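# Module-level MTCNN detector; within this file it is only referenced by the commented-out eye-crop code inside predict_on_video.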
detector = MTCNN(margin=0, thresholds=(0.7, 0.8, 0.8), device=device) |
|||
def predict_on_video(distill, face_extractor, video_path, batch_size, input_size, models, strategy=np.mean, |
|||
apply_compression=False): |
|||
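# Enlarged batch budget, presumably to leave room for multiple detected faces per sampled frame.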
batch_size *= 4 |
|||
try: |
|||
faces = face_extractor.process_video(video_path) |
|||
if len(faces) > 0: |
|||
x = np.zeros((batch_size, input_size, input_size, 3), dtype=np.uint8) |
|||
#e = np.zeros((batch_size, 32, 32, 3), dtype=np.uint8) #eye |
|||
n = 0 |
|||
for frame_data in faces: |
|||
for face in frame_data["faces"]: |
|||
#print(face) |
|||
# _,_,landmark = detector.detect(face, landmarks=True) |
|||
'''# eye 0524 |
|||
try: |
|||
landmark = np.around(landmark[0]).astype(np.int16) |
|||
(x1, y1), (x2, y2) = landmark[:2] |
|||
w = dist((x1, y1), (x2, y2)) |
|||
dilation = int(w // 4) |
|||
eye_image = face[y2 - dilation:y1 + dilation, x1 - dilation:x2 + dilation] |
|||
eye_image = cv2.resize(eye_image, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) |
|||
except Exception as ex: |
|||
eye_image = cv2.resize(face, dsize=(32, 32), interpolation=cv2.INTER_CUBIC) |
|||
'''# |
|||
resized_face = isotropically_resize_image(face, input_size) |
|||
resized_face = put_to_center(resized_face, input_size) |
|||
|
|||
if apply_compression: |
|||
resized_face = image_compression(resized_face, quality=90, image_type=".jpg") |
|||
#eye_image = image_compression(eye_image, quality=90, image_type=".jpg")#eye |
|||
if n + 1 < batch_size: |
|||
x[n] = resized_face |
|||
#e[n] = eye_image#eye |
|||
n += 1 |
|||
else: |
|||
pass |
|||
if n > 0: |
|||
x = torch.tensor(x, device=device).float() |
|||
#e = torch.tensor(e, device="cuda").float() #eye |
|||
# Preprocess the images. |
|||
x = x.permute((0, 3, 1, 2)) |
|||
#e = e.permute((0, 3, 1, 2))#eye |
|||
for i in range(len(x)): |
|||
x[i] = normalize_transform(x[i] / 255.) |
|||
#e[i] = normalize_transform(e[i] / 255.) #eye |
|||
# Make a prediction, then take the average. |
|||
with torch.no_grad(): |
|||
preds = [] |
|||
for model in models: |
|||
if distill: |
|||
_, y_pred, _ = model(x[:n]) #eye , e[:n].half() |
|||
else: |
|||
y_pred = model(x[:n]) |
|||
y_pred = torch.sigmoid(y_pred.squeeze()) |
|||
bpred = y_pred[:n].cpu().numpy() |
|||
preds.append(strategy(bpred)) |
|||
return np.mean(preds) |
|||
except Exception as e: |
|||
log.error("Prediction error on video %s: %s" % (video_path, str(e))) |
|||
|
|||
return 0.5 |
|||
|
|||
|
|||
def predict_on_video_set(distill, face_extractor, videos, input_size, num_workers, test_dir, frames_per_video, models, |
|||
strategy=np.mean, |
|||
apply_compression=False): |
|||
def process_file(i): |
|||
filename = videos[i] |
|||
y_pred = predict_on_video(distill, face_extractor=face_extractor, video_path=os.path.join(test_dir, filename), |
|||
input_size=input_size, |
|||
batch_size=frames_per_video, |
|||
models=models, strategy=strategy, apply_compression=apply_compression) |
|||
return y_pred |
|||
|
|||
with ThreadPoolExecutor(max_workers=num_workers) as ex: |
|||
predictions = ex.map(process_file, range(len(videos))) |
|||
#predictions = [] |
|||
#for i in range(len(videos)): |
|||
# predictions.append(process_file(i)) |
|||
return list(predictions) |
|||
|
@@ -0,0 +1,9 @@
dlib
facenet-pytorch
albumentations
timm
pytorch_toolbelt
tensorboardx
matplotlib
tqdm
pandas
Binary file not shown.
Binary file not shown.