logo
Browse Source

add weight_path if

main
ChengZi 3 years ago
parent
commit
9f4eadbaba
  1. 4
      clip4clip.py

4
clip4clip.py

@ -23,6 +23,7 @@ from towhee import register
from towhee.models import clip4clip from towhee.models import clip4clip
from PIL import Image as PILImage from PIL import Image as PILImage
from towhee.types import VideoFrame from towhee.types import VideoFrame
from pathlib import Path
@register(output_schema=['vec']) @register(output_schema=['vec'])
@ -34,6 +35,9 @@ class CLIP4Clip(NNOperator):
def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None): def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None):
super().__init__() super().__init__()
self.modality = modality self.modality = modality
if weight_path is None:
weight_path = str(Path(__file__).parent / 'pytorch_model.bin.1')
print('weight_path is None, use default path: {}'.format(weight_path))
if device is None: if device is None:
self.device = "cuda" if torch.cuda.is_available() else "cpu" self.device = "cuda" if torch.cuda.is_available() else "cpu"
else: else:

Loading…
Cancel
Save