diff --git a/README.md b/README.md
index f10f45e..29a9b45 100644
--- a/README.md
+++ b/README.md
@@ -19,19 +19,21 @@ and use deepfake operator to predict the probabilities of fake videos.
```python
-import towhee
-(
- towhee.glob['path']('/home/test_video')
- .deepfake['path', 'scores']()
- .select['path', 'scores']()
- .show()
+from towhee.dc2 import pipe, ops, DataCollection
+
+p = (
+ pipe.input('path')
+ .map('path', 'scores', ops.towhee.deepfake)
+ .output('scores')
)
+
+DataCollection(p('./deepfake_video/test/aagfhgtpmv.mp4').get_dict()).show()
```
```shell
-[0.9893, 0.9097]
+[0.99]
```
@@ -50,4 +52,4 @@ Absolute address of the test videos.
**Returns:** *list*
-The probabilities of videos being fake ones.
\ No newline at end of file
+The probabilities of videos being fake ones.
diff --git a/deepfake.py b/deepfake.py
index f93a1fb..cdc133d 100644
--- a/deepfake.py
+++ b/deepfake.py
@@ -29,8 +29,7 @@ class Deepfake(NNOperator):
weights_dir = os.path.join(str(Path(__file__).parent),"weights/")
self.model_paths = [os.path.join(weights_dir,model) for model in os.listdir(weights_dir)]
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
- def __call__(self, filepath: string) -> list:
+
from kernel_utils import VideoReader, FaceExtractor, confident_strategy, predict_on_video
from classifiers import DeepFakeClassifier
models = []
@@ -42,17 +41,21 @@ class Deepfake(NNOperator):
model.load_state_dict({re.sub("^module.", "", k): v for k, v in state_dict.items()}, strict=False)
model.eval()
del checkpoint
- models.append(model.half())
- frames_per_video = 32
+ models.append(model.float())
+ self.frames_per_video = 32
video_reader = VideoReader()
- video_read_fn = lambda x: video_reader.read_frames(x, num_frames=frames_per_video)
- face_extractor = FaceExtractor(video_read_fn)
- input_size = 384
- strategy = confident_strategy
- #stime = time.time()
- prediction = predict_on_video(False, face_extractor=face_extractor, video_path=filepath,
- input_size=input_size, batch_size=frames_per_video, models=models,
- strategy=strategy, apply_compression=False)
+ video_read_fn = lambda x: video_reader.read_frames(x, num_frames=self.frames_per_video)
+
+ self.face_extractor = FaceExtractor(video_read_fn)
+ self.input_size = 384
+ self.strategy = confident_strategy
+ self.models = models
+ self.predict_on_video = predict_on_video
+
+ def __call__(self, filepath: string) -> list:
+ prediction = self.predict_on_video(False, face_extractor=self.face_extractor, video_path=filepath,
+ input_size=self.input_size, batch_size=self.frames_per_video, models=self.models,
+ strategy=self.strategy, apply_compression=False)
'''
test_videos = sorted([x for x in os.listdir(filepath) if x[-4:] == ".mp4"])
print("Predicting {} videos".format(len(test_videos)))