From af149af7c659412b74684c22db799bc1c4b83ac8 Mon Sep 17 00:00:00 2001 From: jinlingxu06 Date: Fri, 24 Mar 2023 16:11:07 +0800 Subject: [PATCH] fix gpu related problem. Signed-off-by: jinlingxu06 --- __init__.py | 4 ++-- expansionnet_v2.py | 8 +++++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/__init__.py b/__init__.py index af9ed2f..3dc5c41 100644 --- a/__init__.py +++ b/__init__.py @@ -14,5 +14,5 @@ from .expansionnet_v2 import ExpansionNetV2 -def expansionnet_v2(model_name: str): - return ExpansionNetV2(model_name) +def expansionnet_v2(model_name, device = None): + return ExpansionNetV2(model_name, device) diff --git a/expansionnet_v2.py b/expansionnet_v2.py index da5e9c9..d7c6cf6 100644 --- a/expansionnet_v2.py +++ b/expansionnet_v2.py @@ -33,7 +33,7 @@ class ExpansionNetV2(NNOperator): """ ExpansionNet V2 image captioning operator """ - def __init__(self, model_name: str): + def __init__(self, model_name: str, device: str = None): super().__init__() path = str(pathlib.Path(__file__).parent) sys.path.append(path) @@ -50,7 +50,9 @@ class ExpansionNetV2(NNOperator): self.coco_tokens = coco_tokens img_size = 384 - self.device = "cuda" if torch.cuda.is_available() else "cpu" + if device == None: + device = "cuda" if torch.cuda.is_available() else "cpu" + self.device = device drop_args = Namespace(enc=0.0, dec=0.0, enc_input=0.0, @@ -84,7 +86,7 @@ class ExpansionNetV2(NNOperator): output_word2idx=coco_tokens['word2idx_dict'], output_idx2word=coco_tokens['idx2word_list'], max_seq_len=max_seq_len, drop_args=model_args.drop_args, - rank='cpu') + rank=self.device) checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) self.model.load_state_dict(checkpoint['model_state_dict'])