logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

61 lines
2.2 KiB

import argparse
import os
import re
import string
import time
import torch
import pandas as pd
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set
from classifiers import DeepFakeClassifier
import towhee
from towhee.operator.base import Operator, OperatorFlag
from towhee import register
import warnings
warnings.filterwarnings('ignore')
import logging
log = logging.getLogger()
@register(output_schema=["scorelist"],
          flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class deepfake(Operator):
    """Towhee operator that scores the ``.mp4`` videos in a directory for deepfakes.

    Each video is read through ``VideoReader`` / ``FaceExtractor`` and scored
    with an ensemble of half-precision ``DeepFakeClassifier`` models
    (``tf_efficientnet_b7_ns`` encoder), one per checkpoint file found in
    ``weights_dir``. ``confident_strategy`` is passed to
    ``predict_on_video_set`` as its ``strategy``.

    Args:
        weights_dir: Directory holding the checkpoints; every regular file in
            it is loaded as one ensemble member.
        device: Torch device the models are moved to.
        frames_per_video: Number of frames requested from each video.
        input_size: Input size forwarded to ``predict_on_video_set``.
        num_workers: Worker count forwarded to ``predict_on_video_set``.
    """

    def __init__(self, weights_dir: str = "weights/", device: str = "cuda",
                 frames_per_video: int = 32, input_size: int = 384,
                 num_workers: int = 2):
        super().__init__()
        self._weights_dir = weights_dir
        self._device = device
        self._frames_per_video = frames_per_video
        self._input_size = input_size
        self._num_workers = num_workers
        # Ensemble is loaded lazily on the first call and then reused, so
        # constructing the operator stays cheap and repeated calls do not
        # re-read every checkpoint from disk.
        self._models = None

    def _load_models(self) -> list:
        """Build one eval-mode, half-precision classifier per checkpoint file.

        Raises:
            FileNotFoundError: if ``weights_dir`` contains no regular files.
        """
        models = []
        # Sorted so the ensemble order does not depend on the filesystem.
        for name in sorted(os.listdir(self._weights_dir)):
            path = os.path.join(self._weights_dir, name)
            if not os.path.isfile(path):
                continue
            model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to(self._device)
            print("loading state dict {}".format(path))
            # NOTE(review): torch.load unpickles the file — only point
            # weights_dir at trusted checkpoints.
            checkpoint = torch.load(path, map_location="cpu")
            # Checkpoints may wrap the weights under "state_dict" or be the
            # raw state dict itself.
            state_dict = checkpoint.get("state_dict", checkpoint)
            # Strip a leading literal "module." from parameter names
            # (presumably left by (Distributed)DataParallel wrapping).
            # strict=False silently ignores missing/unexpected keys.
            model.load_state_dict(
                {re.sub(r"^module\.", "", k): v for k, v in state_dict.items()},
                strict=False)
            model.eval()
            del checkpoint
            models.append(model.half())
        if not models:
            raise FileNotFoundError(
                "no checkpoint files found in {!r}".format(self._weights_dir))
        return models

    def __call__(self, filepath: str) -> list:
        """Score every ``.mp4`` file directly inside ``filepath``.

        Args:
            filepath: Directory containing the videos to score.

        Returns:
            The result of ``predict_on_video_set`` for the videos, which are
            passed in sorted filename order. NOTE(review): presumably one score
            per video in that order — confirm against kernel_utils.
        """
        if self._models is None:
            self._models = self._load_models()
        frames_per_video = self._frames_per_video
        video_reader = VideoReader()
        face_extractor = FaceExtractor(
            lambda x: video_reader.read_frames(x, num_frames=frames_per_video))
        test_videos = sorted(x for x in os.listdir(filepath) if x.endswith(".mp4"))
        print("Predicting {} videos".format(len(test_videos)))
        return predict_on_video_set(
            False,  # NOTE(review): positional flag defined by kernel_utils — confirm meaning
            face_extractor=face_extractor, input_size=self._input_size,
            models=self._models, strategy=confident_strategy,
            frames_per_video=frames_per_video, videos=test_videos,
            num_workers=self._num_workers, test_dir=filepath)
if __name__ == "__main__":
    # Command-line entry point: score every .mp4 in the given directory with
    # the default operator settings and print the result.
    parser = argparse.ArgumentParser(
        description="Score the .mp4 videos in a directory for deepfake manipulation.")
    parser.add_argument("filepath", help="directory containing the .mp4 videos to score")
    args = parser.parse_args()
    op = deepfake()
    pred = op(filepath=args.filepath)
    print(pred)