diff --git a/__init__.py b/__init__.py
index af9ed2f..3dc5c41 100644
--- a/__init__.py
+++ b/__init__.py
@@ -14,5 +14,6 @@
 
 from .expansionnet_v2 import ExpansionNetV2
 
-def expansionnet_v2(model_name: str):
-    return ExpansionNetV2(model_name)
+def expansionnet_v2(model_name: str, device: str = None):
+    """Build an ExpansionNetV2 operator; device=None lets the operator choose cuda or cpu."""
+    return ExpansionNetV2(model_name, device)
diff --git a/expansionnet_v2.py b/expansionnet_v2.py
index da5e9c9..d7c6cf6 100644
--- a/expansionnet_v2.py
+++ b/expansionnet_v2.py
@@ -33,7 +33,7 @@ class ExpansionNetV2(NNOperator):
     """
     ExpansionNet V2 image captioning operator
     """
-    def __init__(self, model_name: str):
+    def __init__(self, model_name: str, device: str = None):
         super().__init__()
         path = str(pathlib.Path(__file__).parent)
         sys.path.append(path)
@@ -50,7 +50,9 @@ class ExpansionNetV2(NNOperator):
         self.coco_tokens = coco_tokens
 
         img_size = 384
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
+        if device is None:  # no explicit choice: use CUDA when available, else CPU
+            device = "cuda" if torch.cuda.is_available() else "cpu"
+        self.device = device
         drop_args = Namespace(enc=0.0,
                               dec=0.0,
                               enc_input=0.0,
@@ -84,7 +86,7 @@ class ExpansionNetV2(NNOperator):
                                         output_word2idx=coco_tokens['word2idx_dict'],
                                         output_idx2word=coco_tokens['idx2word_list'],
                                         max_seq_len=max_seq_len, drop_args=model_args.drop_args,
-                                        rank='cpu')
+                                        rank=self.device)
         checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu'))
         self.model.load_state_dict(checkpoint['model_state_dict'])
 