|
|
@ -29,8 +29,7 @@ class Deepfake(NNOperator): |
|
|
|
weights_dir = os.path.join(str(Path(__file__).parent),"weights/") |
|
|
|
self.model_paths = [os.path.join(weights_dir,model) for model in os.listdir(weights_dir)] |
|
|
|
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
|
|
def __call__(self, filepath: string) -> list: |
|
|
|
|
|
|
|
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video |
|
|
|
from classifiers import DeepFakeClassifier |
|
|
|
models = [] |
|
|
@ -42,17 +41,21 @@ class Deepfake(NNOperator): |
|
|
|
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False) |
|
|
|
model.eval() |
|
|
|
del checkpoint |
|
|
|
models.append(model.half()) |
|
|
|
frames_per_video = 32 |
|
|
|
models.append(model.float()) |
|
|
|
self.frames_per_video = 32 |
|
|
|
video_reader = VideoReader() |
|
|
|
video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video) |
|
|
|
face_extractor = FaceExtractor(video_read_fn) |
|
|
|
input_size = 384 |
|
|
|
strategy = confident_strategy |
|
|
|
#stime = time.time() |
|
|
|
prediction = predict_on_video(False, face_extractor=face_extractor, video_path=filepath, |
|
|
|
input_size=input_size, batch_size=frames_per_video, models=models, |
|
|
|
strategy=strategy, apply_compression=False) |
|
|
|
video_read_fn = lambda x: video_reader.read_frames(x, num_frames=self.frames_per_video) |
|
|
|
|
|
|
|
self.face_extractor = FaceExtractor(video_read_fn) |
|
|
|
self.input_size = 384 |
|
|
|
self.strategy = confident_strategy |
|
|
|
self.models = models |
|
|
|
self.predict_on_video = predict_on_video |
|
|
|
|
|
|
|
def __call__(self, filepath: string) -> list: |
|
|
|
prediction = self.predict_on_video(False, face_extractor=self.face_extractor, video_path=filepath, |
|
|
|
input_size=self.input_size, batch_size=self.frames_per_video, models=self.models, |
|
|
|
strategy=self.strategy, apply_compression=False) |
|
|
|
''' |
|
|
|
test_videos = sorted([x for x in os.listdir(filepath) if x[-4:] == ".mp4"]) |
|
|
|
print("Predicting {} videos".format(len(test_videos))) |
|
|
|