logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

67 lines
2.5 KiB

4 years ago
import argparse
import logging
import os
import re
import string
import sys
import time
import warnings
from pathlib import Path

import torch
import pandas as pd

import towhee
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register

# Suppress every Python warning for the whole process once this module is
# imported. This is a global side effect: it also affects code outside the
# operator. It runs after the third-party imports above, as it always has.
warnings.filterwarnings('ignore')

# Module-scoped logger; records still propagate to the root logger's handlers.
log = logging.getLogger(__name__)
@register(output_schema=["scorelist"],
          flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Deepfake(NNOperator):
    """
    Deepfake video detector backed by an ensemble of DeepFakeClassifier models.

    Each regular (non-hidden) file in the ``weights/`` directory next to this
    module is loaded as one checkpoint of the ensemble. Every model uses the
    ``tf_efficientnet_b7_ns`` encoder. Calling the operator with a directory
    path scores every ``.mp4`` file directly inside that directory.
    """

    def __init__(self):
        super().__init__()
        operator_dir = str(Path(__file__).parent)
        # kernel_utils / classifiers live next to this file and are imported
        # lazily in __call__. Extend sys.path only once, even when the
        # operator is instantiated repeatedly.
        if operator_dir not in sys.path:
            sys.path.append(operator_dir)
        weights_dir = os.path.join(operator_dir, "weights")
        # Sort for a deterministic ensemble order (os.listdir order is
        # arbitrary). Skip sub-directories and hidden files (e.g. .gitkeep),
        # which torch.load cannot read as checkpoints.
        self.model_paths = [
            os.path.join(weights_dir, name)
            for name in sorted(os.listdir(weights_dir))
            if not name.startswith(".")
            and os.path.isfile(os.path.join(weights_dir, name))
        ]
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Loaded ensemble, populated on the first call and then reused.
        self._models = None

    def _load_models(self) -> list:
        """
        Build one half-precision, eval-mode DeepFakeClassifier per checkpoint.

        Returns:
            The models, in the same order as ``self.model_paths``.

        Raises:
            FileNotFoundError: if ``self.model_paths`` is empty.
        """
        from classifiers import DeepFakeClassifier

        if not self.model_paths:
            raise FileNotFoundError(
                "no model checkpoints found in the operator's weights/ directory")

        models = []
        for path in self.model_paths:
            model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self.device)
            print("loading state dict {}".format(path))
            checkpoint = torch.load(path, map_location="cpu")
            # A checkpoint is either {"state_dict": ...} or a bare state dict.
            state_dict = checkpoint.get("state_dict", checkpoint)
            # Drop a leading literal "module." from parameter names (presumably
            # weights saved from an nn.DataParallel-style wrapper). The dot is
            # escaped so only that exact prefix is stripped.
            model.load_state_dict(
                {re.sub(r"^module\.", "", key): value for key, value in state_dict.items()},
                strict=False)
            model.eval()
            del checkpoint
            models.append(model.half())
        return models

    def __call__(self, filepath: str) -> list:
        """
        Score every ``.mp4`` video in a directory.

        Args:
            filepath: Directory containing the videos. Only names ending in
                ``.mp4`` are scored, in sorted filename order.

        Returns:
            Whatever ``kernel_utils.predict_on_video_set`` returns; presumably
            one score per video, in sorted filename order (confirm against
            kernel_utils).
        """
        from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set

        # Load the ensemble once and reuse it: the operator is REUSEABLE, and
        # reloading every checkpoint from disk on each call is expensive.
        if getattr(self, "_models", None) is None:
            self._models = self._load_models()

        frames_per_video = 32
        input_size = 384
        video_reader = VideoReader()
        face_extractor = FaceExtractor(
            lambda video: video_reader.read_frames(video, num_frames=frames_per_video))

        test_videos = sorted(name for name in os.listdir(filepath) if name.endswith(".mp4"))
        print("Predicting {} videos".format(len(test_videos)))
        # NOTE(review): the meaning of the leading positional False is defined
        # by kernel_utils.predict_on_video_set; confirm there.
        return predict_on_video_set(
            False, face_extractor=face_extractor, input_size=input_size,
            models=self._models, strategy=confident_strategy,
            frames_per_video=frames_per_video, videos=test_videos,
            num_workers=2, test_dir=filepath)
if __name__ == "__main__":
    # Example usage: score every .mp4 file in the directory given on the
    # command line, e.g. `python <this file> /path/to/videos/`.
    # Only runs when executed as a script, never on import.
    parser = argparse.ArgumentParser(
        description="Run the Deepfake operator on a directory of .mp4 videos.")
    parser.add_argument("filepath", help="Directory containing the .mp4 videos to score.")
    args = parser.parse_args()
    op = Deepfake()
    pred = op(filepath=args.filepath)
    print(pred)