logo
Browse Source

modify factory

main
ChengZi 2 years ago
parent
commit
1a61aee75a
  1. 4
      __init__.py
  2. 2
      clip4clip.py

4
__init__.py

@ -15,6 +15,6 @@
from .clip4clip import CLIP4Clip
def clip4clip(model_name: str, modality: str, weight_path: str, device: str):
return CLIP4Clip(model_name, modality, weight_path, device)
def clip4clip(model_name: str, modality: str, **kwargs):
return CLIP4Clip(model_name, modality, **kwargs)

2
clip4clip.py

@ -37,7 +37,7 @@ class CLIP4Clip(NNOperator):
self.modality = modality
if weight_path is None:
weight_path = str(Path(__file__).parent / 'pytorch_model.bin.1')
print('weight_path is None, use default path: {}'.format(weight_path))
# print('weight_path is None, use default path: {}'.format(weight_path))
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:

Loading…
Cancel
Save