logo
Browse Source

fix gpu related problem.

Signed-off-by: jinlingxu06 <jinling.xu@zilliz.com>
main
jinlingxu06 2 years ago
parent
commit
af149af7c6
  1. 4
      __init__.py
  2. 8
      expansionnet_v2.py

4
__init__.py

@ -14,5 +14,5 @@
from .expansionnet_v2 import ExpansionNetV2 from .expansionnet_v2 import ExpansionNetV2
def expansionnet_v2(model_name: str):
return ExpansionNetV2(model_name)
def expansionnet_v2(model_name, device = None):
return ExpansionNetV2(model_name, device)

8
expansionnet_v2.py

@ -33,7 +33,7 @@ class ExpansionNetV2(NNOperator):
""" """
ExpansionNet V2 image captioning operator ExpansionNet V2 image captioning operator
""" """
def __init__(self, model_name: str):
def __init__(self, model_name: str, device: str = None):
super().__init__() super().__init__()
path = str(pathlib.Path(__file__).parent) path = str(pathlib.Path(__file__).parent)
sys.path.append(path) sys.path.append(path)
@ -50,7 +50,9 @@ class ExpansionNetV2(NNOperator):
self.coco_tokens = coco_tokens self.coco_tokens = coco_tokens
img_size = 384 img_size = 384
self.device = "cuda" if torch.cuda.is_available() else "cpu"
if device == None:
device = "cuda" if torch.cuda.is_available() else "cpu"
self.device = device
drop_args = Namespace(enc=0.0, drop_args = Namespace(enc=0.0,
dec=0.0, dec=0.0,
enc_input=0.0, enc_input=0.0,
@ -84,7 +86,7 @@ class ExpansionNetV2(NNOperator):
output_word2idx=coco_tokens['word2idx_dict'], output_word2idx=coco_tokens['word2idx_dict'],
output_idx2word=coco_tokens['idx2word_list'], output_idx2word=coco_tokens['idx2word_list'],
max_seq_len=max_seq_len, drop_args=model_args.drop_args, max_seq_len=max_seq_len, drop_args=model_args.drop_args,
rank='cpu')
rank=self.device)
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu'))
self.model.load_state_dict(checkpoint['model_state_dict']) self.model.load_state_dict(checkpoint['model_state_dict'])

Loading…
Cancel
Save