logo
Browse Source

add weight_path if

main
ChengZi 2 years ago
parent
commit
9f4eadbaba
  1. 4
      clip4clip.py

4
clip4clip.py

@ -23,6 +23,7 @@ from towhee import register
from towhee.models import clip4clip
from PIL import Image as PILImage
from towhee.types import VideoFrame
from pathlib import Path
@register(output_schema=['vec'])
@ -34,6 +35,9 @@ class CLIP4Clip(NNOperator):
def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None):
super().__init__()
self.modality = modality
if weight_path is None:
weight_path = str(Path(__file__).parent / 'pytorch_model.bin.1')
print('weight_path is None, use default path: {}'.format(weight_path))
if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu"
else:

Loading…
Cancel
Save