Browse Source
fix gpu related problem.
Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06
2 years ago
2 changed files with
7 additions and
5 deletions
-
__init__.py
-
expansionnet_v2.py
|
|
@ -14,5 +14,5 @@ |
|
|
|
|
|
|
|
from .expansionnet_v2 import ExpansionNetV2 |
|
|
|
|
|
|
|
def expansionnet_v2(model_name: str): |
|
|
|
return ExpansionNetV2(model_name) |
|
|
|
def expansionnet_v2(model_name, device = None): |
|
|
|
return ExpansionNetV2(model_name, device) |
|
|
|
|
|
@ -33,7 +33,7 @@ class ExpansionNetV2(NNOperator): |
|
|
|
""" |
|
|
|
ExpansionNet V2 image captioning operator |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str): |
|
|
|
def __init__(self, model_name: str, device: str = None): |
|
|
|
super().__init__() |
|
|
|
path = str(pathlib.Path(__file__).parent) |
|
|
|
sys.path.append(path) |
|
|
@ -50,7 +50,9 @@ class ExpansionNetV2(NNOperator): |
|
|
|
self.coco_tokens = coco_tokens |
|
|
|
|
|
|
|
img_size = 384 |
|
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
if device == None: |
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
self.device = device |
|
|
|
drop_args = Namespace(enc=0.0, |
|
|
|
dec=0.0, |
|
|
|
enc_input=0.0, |
|
|
@ -84,7 +86,7 @@ class ExpansionNetV2(NNOperator): |
|
|
|
output_word2idx=coco_tokens['word2idx_dict'], |
|
|
|
output_idx2word=coco_tokens['idx2word_list'], |
|
|
|
max_seq_len=max_seq_len, drop_args=model_args.drop_args, |
|
|
|
rank='cpu') |
|
|
|
rank=self.device) |
|
|
|
|
|
|
|
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) |
|
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|