|
@ -33,7 +33,7 @@ class ExpansionNetV2(NNOperator): |
|
|
""" |
|
|
""" |
|
|
ExpansionNet V2 image captioning operator |
|
|
ExpansionNet V2 image captioning operator |
|
|
""" |
|
|
""" |
|
|
def __init__(self, model_name: str): |
|
|
|
|
|
|
|
|
def __init__(self, model_name: str, device: str = None): |
|
|
super().__init__() |
|
|
super().__init__() |
|
|
path = str(pathlib.Path(__file__).parent) |
|
|
path = str(pathlib.Path(__file__).parent) |
|
|
sys.path.append(path) |
|
|
sys.path.append(path) |
|
@ -50,7 +50,9 @@ class ExpansionNetV2(NNOperator): |
|
|
self.coco_tokens = coco_tokens |
|
|
self.coco_tokens = coco_tokens |
|
|
|
|
|
|
|
|
img_size = 384 |
|
|
img_size = 384 |
|
|
self.device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
|
|
|
if device == None: |
|
|
|
|
|
device = "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
|
self.device = device |
|
|
drop_args = Namespace(enc=0.0, |
|
|
drop_args = Namespace(enc=0.0, |
|
|
dec=0.0, |
|
|
dec=0.0, |
|
|
enc_input=0.0, |
|
|
enc_input=0.0, |
|
@ -84,7 +86,7 @@ class ExpansionNetV2(NNOperator): |
|
|
output_word2idx=coco_tokens['word2idx_dict'], |
|
|
output_word2idx=coco_tokens['word2idx_dict'], |
|
|
output_idx2word=coco_tokens['idx2word_list'], |
|
|
output_idx2word=coco_tokens['idx2word_list'], |
|
|
max_seq_len=max_seq_len, drop_args=model_args.drop_args, |
|
|
max_seq_len=max_seq_len, drop_args=model_args.drop_args, |
|
|
rank='cpu') |
|
|
|
|
|
|
|
|
rank=self.device) |
|
|
|
|
|
|
|
|
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) |
|
|
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) |
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|