import argparse
import logging
import os
import re
import string
import time
import warnings

import pandas as pd
import torch

import towhee
from towhee import register
from towhee.operator.base import NNOperator, OperatorFlag

from .classifiers import DeepFakeClassifier
from .kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video_set

# Silence every Python warning for the whole process once this module is imported.
warnings.filterwarnings('ignore')

log = logging.getLogger()


@register(output_schema=["scorelist"], flag=OperatorFlag.STATELESS | OperatorFlag.REUSEABLE)
class Deepfake(NNOperator):
    """Towhee operator that runs DeepFakeClassifier models over a directory of videos.

    Each call loads every checkpoint found in ``weights/`` into a
    ``DeepFakeClassifier`` built with the ``tf_efficientnet_b7_ns`` encoder,
    then hands the ``.mp4`` files of the given directory to
    ``predict_on_video_set`` and returns its result.
    """

    def __init__(self):
        # NOTE(review): the base-class initializer is not invoked here; confirm
        # whether NNOperator expects ``super().__init__()`` to be called.
        pass

    def __call__(self, filepath: str) -> list:
        """Predict on every ``.mp4`` file directly inside ``filepath``.

        Args:
            filepath: Directory containing the videos to score. Only entries
                whose name ends with ``.mp4`` are considered, in sorted order.

        Returns:
            Whatever ``predict_on_video_set`` returns for the selected videos
            (presumably one score per video, in the same sorted order --
            TODO confirm against ``kernel_utils``).
        """
        # NOTE(review): "weights/" is resolved against the current working
        # directory, not this file's location -- confirm this is intended.
        # NOTE(review): the checkpoints are reloaded on every call; caching the
        # models may be worthwhile if the operator is invoked repeatedly.
        models = self._load_models("weights/")

        frames_per_video = 32
        video_reader = VideoReader()

        def read_frames(video_path):
            return video_reader.read_frames(video_path, num_frames=frames_per_video)

        face_extractor = FaceExtractor(read_frames)

        test_videos = sorted(
            name for name in os.listdir(filepath) if name.endswith(".mp4")
        )
        print("Predicting {} videos".format(len(test_videos)))

        # NOTE(review): the meaning of the leading positional ``False`` is
        # defined by predict_on_video_set in kernel_utils; it is passed through
        # exactly as before.
        predictions = predict_on_video_set(
            False,
            face_extractor=face_extractor,
            input_size=384,
            models=models,
            strategy=confident_strategy,
            frames_per_video=frames_per_video,
            videos=test_videos,
            num_workers=2,
            test_dir=filepath,
        )
        return predictions

    def _load_models(self, weights_dir: str) -> list:
        """Build one half-precision, eval-mode classifier per file in ``weights_dir``."""
        models = []
        for checkpoint_name in os.listdir(weights_dir):
            path = os.path.join(weights_dir, checkpoint_name)
            # NOTE(review): the device is hard-coded to "cuda", so this
            # requires a CUDA-capable environment.
            model = DeepFakeClassifier(encoder="tf_efficientnet_b7_ns").to("cuda")
            print("loading state dict {}".format(path))
            # Deserialize onto the CPU first; load_state_dict then copies the
            # values into the model's existing parameters.
            checkpoint = torch.load(path, map_location="cpu")
            # A checkpoint may wrap the weights under "state_dict" or be the
            # state dict itself.
            state_dict = checkpoint.get("state_dict", checkpoint)
            # Strip a literal leading "module." from parameter names
            # (presumably left by a DataParallel wrapper at training time).
            # The dot is escaped so names such as "modulex..." stay untouched.
            cleaned_state = {
                re.sub(r"^module\.", "", key): value
                for key, value in state_dict.items()
            }
            # strict=False: missing or unexpected keys do not raise.
            model.load_state_dict(cleaned_state, strict=False)
            model.eval()
            del checkpoint
            models.append(model.half())
        return models


# Example usage (kept for reference, not executed):
#
# if __name__ == "__main__":
#     filepath = "/home/xuyu/Deepfake/deepfake_detec/"
#     op = Deepfake()
#     pred = op(filepath=filepath)
#     print(pred)