diff --git a/clip4clip.py b/clip4clip.py index b6c1ea7..2a8fb24 100644 --- a/clip4clip.py +++ b/clip4clip.py @@ -23,6 +23,7 @@ from towhee import register from towhee.models import clip4clip from PIL import Image as PILImage from towhee.types import VideoFrame +from pathlib import Path @register(output_schema=['vec']) @@ -34,6 +35,9 @@ class CLIP4Clip(NNOperator): def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None): super().__init__() self.modality = modality + if weight_path is None: + weight_path = str(Path(__file__).parent / 'pytorch_model.bin.1') + print('weight_path is None, use default path: {}'.format(weight_path)) if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: