|
@ -23,6 +23,7 @@ from towhee import register |
|
|
from towhee.models import clip4clip |
|
|
from towhee.models import clip4clip |
|
|
from PIL import Image as PILImage |
|
|
from PIL import Image as PILImage |
|
|
from towhee.types import VideoFrame |
|
|
from towhee.types import VideoFrame |
|
|
|
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@register(output_schema=['vec']) |
|
|
@register(output_schema=['vec']) |
|
@ -34,6 +35,9 @@ class CLIP4Clip(NNOperator): |
|
|
def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None): |
|
|
def __init__(self, model_name: str, modality: str, weight_path: str = None, device: str = None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
self.modality = modality |
|
|
self.modality = modality |
|
|
|
|
|
if weight_path is None: |
|
|
|
|
|
weight_path = str(Path(__file__).parent / 'pytorch_model.bin.1') |
|
|
|
|
|
print('weight_path is None, use default path: {}'.format(weight_path)) |
|
|
if device is None: |
|
|
if device is None: |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
else: |
|
|
else: |
|
|