diff --git a/README.md b/README.md index f10f45e..29a9b45 100644 --- a/README.md +++ b/README.md @@ -19,19 +19,21 @@ and use deepfake operator to predict the probabilities of fake videos. ```python -import towhee -( - towhee.glob['path']('/home/test_video') - .deepfake['path', 'scores']() - .select['path', 'scores']() - .show() +from towhee.dc2 import pipe, ops, DataCollection + +p = ( + pipe.input('path') + .map('path', 'scores', ops.towhee.deepfake) + .output('scores') ) + +DataCollection(p('./deepfake_video/test/aagfhgtpmv.mp4').get_dict()).show() ``` ```shell -[0.9893, 0.9097] +[0.99] ```
@@ -50,4 +52,4 @@ Absolute address of the test videos. **Returns:** *list* -The probabilities of videos being fake ones. \ No newline at end of file +The probabilities of videos being fake ones. diff --git a/deepfake.py b/deepfake.py index f93a1fb..cdc133d 100644 --- a/deepfake.py +++ b/deepfake.py @@ -29,8 +29,7 @@ class Deepfake(NNOperator): weights_dir = os.path.join(str(Path(__file__).parent),"weights/") self.model_paths = [os.path.join(weights_dir,model) for model in os.listdir(weights_dir)] self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - - def __call__(self, filepath: string) -> list: + from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video from classifiers import DeepFakeClassifier models = [] @@ -42,17 +41,21 @@ class Deepfake(NNOperator): model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False) model.eval() del checkpoint - models.append(model.half()) - frames_per_video = 32 + models.append(model.float()) + self.frames_per_video = 32 video_reader = VideoReader() - video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video) - face_extractor = FaceExtractor(video_read_fn) - input_size = 384 - strategy = confident_strategy - #stime = time.time() - prediction = predict_on_video(False, face_extractor=face_extractor, video_path=filepath, - input_size=input_size, batch_size=frames_per_video, models=models, - strategy=strategy, apply_compression=False) + video_read_fn = lambda x: video_reader.read_frames(x, num_frames=self.frames_per_video) + + self.face_extractor = FaceExtractor(video_read_fn) + self.input_size = 384 + self.strategy = confident_strategy + self.models = models + self.predict_on_video = predict_on_video + + def __call__(self, filepath: string) -> list: + prediction = self.predict_on_video(False, face_extractor=self.face_extractor, video_path=filepath, + input_size=self.input_size, batch_size=self.frames_per_video, models=self.models, + strategy=self.strategy, apply_compression=False) ''' test_videos = sorted([x for x in os.listdir(filepath) if x[-4:] == ".mp4"]) print("Predicting {} videos".format(len(test_videos)))