Browse Source
use the corresponding device in the save method.
Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb
2 years ago
1 changed files with
2 additions and
3 deletions
-
clip.py
|
@ -193,10 +193,10 @@ class Clip(NNOperator): |
|
|
h = sz['height'] |
|
|
h = sz['height'] |
|
|
w = sz['width'] |
|
|
w = sz['width'] |
|
|
dummy_input = Image.new('RGB', (w, h), color = 'red') |
|
|
dummy_input = Image.new('RGB', (w, h), color = 'red') |
|
|
inputs = self.processor(images=dummy_input, return_tensors='pt') # a dictionary |
|
|
|
|
|
|
|
|
inputs = self.processor(images=dummy_input, return_tensors='pt').to(self.device) # a dictionary |
|
|
elif self.modality == 'text': |
|
|
elif self.modality == 'text': |
|
|
dummy_input = 'dummy' |
|
|
dummy_input = 'dummy' |
|
|
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt') # a dictionary |
|
|
|
|
|
|
|
|
inputs = self.tokenizer(dummy_input, padding=True, truncation=True, return_tensors='pt').to(self.device) # a dictionary |
|
|
else: |
|
|
else: |
|
|
raise ValueError("modality[{}] not implemented.".format(self.modality)) |
|
|
raise ValueError("modality[{}] not implemented.".format(self.modality)) |
|
|
|
|
|
|
|
@ -235,6 +235,5 @@ class Clip(NNOperator): |
|
|
opset_version=14, |
|
|
opset_version=14, |
|
|
) |
|
|
) |
|
|
else: |
|
|
else: |
|
|
pass |
|
|
|
|
|
raise NotImplementedError |
|
|
raise NotImplementedError |
|
|
|
|
|
|
|
|