diff --git a/__init__.py b/__init__.py index c68e45b..3b16397 100644 --- a/__init__.py +++ b/__init__.py @@ -15,6 +15,6 @@ from .clip4clip import CLIP4Clip -def clip4clip(model_name: str, modality: str, weight_path: str, device: str): - return CLIP4Clip(model_name, modality, weight_path, device) +def clip4clip(model_name: str, modality: str, **kwargs): + return CLIP4Clip(model_name, modality, **kwargs) diff --git a/clip4clip.py b/clip4clip.py index 2a8fb24..3cb1b9b 100644 --- a/clip4clip.py +++ b/clip4clip.py @@ -37,7 +37,7 @@ class CLIP4Clip(NNOperator): self.modality = modality if weight_path is None: weight_path = str(Path(__file__).parent / 'pytorch_model.bin.1') - print('weight_path is None, use default path: {}'.format(weight_path)) + # print('weight_path is None, use default path: {}'.format(weight_path)) if device is None: self.device = "cuda" if torch.cuda.is_available() else "cpu" else: