|
|
@ -35,12 +35,12 @@ class ExpansionNetV2(NNOperator): |
|
|
|
""" |
|
|
|
def __init__(self, model_name: str): |
|
|
|
super().__init__() |
|
|
|
sys.path.append(str(pathlib.Path(__file__).parent)) |
|
|
|
path = str(pathlib.Path(__file__).parent) |
|
|
|
sys.path.append(path) |
|
|
|
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 |
|
|
|
from utils.language_utils import convert_vector_idx2word |
|
|
|
self.convert_vector_idx2word = convert_vector_idx2word |
|
|
|
sys.path.pop() |
|
|
|
path = pathlib.Path(__file__).parent |
|
|
|
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f: |
|
|
|
coco_tokens = pickle.load(f) |
|
|
|
self.coco_tokens = coco_tokens |
|
|
@ -82,6 +82,10 @@ class ExpansionNetV2(NNOperator): |
|
|
|
max_seq_len=max_seq_len, drop_args=model_args.drop_args, |
|
|
|
rank='cpu') |
|
|
|
|
|
|
|
checkpoint = torch.load('{}/weights/rf_model.pth'.format(path), map_location=torch.device('cpu')) |
|
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
|
self.transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size)), torchvision.transforms.ToTensor()]) |
|
|
|
self.transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]) |
|
|
|
self.beam_search_kwargs = {'beam_size': beam_size, |
|
|
|