diff --git a/expansionnet_v2.py b/expansionnet_v2.py
index 2020cf3..90a7e12 100644
--- a/expansionnet_v2.py
+++ b/expansionnet_v2.py
@@ -35,12 +35,12 @@ class ExpansionNetV2(NNOperator):
     """
     def __init__(self, model_name: str):
         super().__init__()
-        sys.path.append(str(pathlib.Path(__file__).parent))
+        path = str(pathlib.Path(__file__).parent)
+        sys.path.append(path)
         from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
         from utils.language_utils import convert_vector_idx2word
         self.convert_vector_idx2word = convert_vector_idx2word
         sys.path.pop()
-        path = pathlib.Path(__file__).parent
         with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
             coco_tokens = pickle.load(f)
         self.coco_tokens = coco_tokens
@@ -82,6 +82,10 @@ class ExpansionNetV2(NNOperator):
                                          max_seq_len=max_seq_len, drop_args=model_args.drop_args,
                                          rank='cpu')
+        checkpoint = torch.load('{}/weights/rf_model.pth'.format(path), map_location=torch.device('cpu'))
+        self.model.load_state_dict(checkpoint['model_state_dict'])
+        self.model.to(self.device)
+
         self.transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size)),
                                                         torchvision.transforms.ToTensor()])
         self.transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                                          std=[0.229, 0.224, 0.225])])
         self.beam_search_kwargs = {'beam_size': beam_size,
diff --git a/weights/rf_model.pth b/weights/rf_model.pth
new file mode 100644
index 0000000..3b9ac2f
--- /dev/null
+++ b/weights/rf_model.pth
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4abee9777d227400af36717de758c45bb2ac3e8b0a3935fe1c7af294c05d1ebe
+size 2819735367
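Reviewer note (not part of the diff): the new `__init__` lines load the Git LFS-tracked checkpoint `weights/rf_model.pth` onto CPU and then move the model to the operator's device. A minimal sketch of that flow is below; the `ckpt_path` variable and the pointer-file size guard are illustrative additions and not part of the operator code, which only performs the `torch.load`/`load_state_dict`/`to(device)` steps shown in the diff.

```python
import os
import torch

# Assumed checkpoint location, matching the path used in the diff.
ckpt_path = 'weights/rf_model.pth'

# Illustrative guard: an un-pulled LFS pointer file is only a few hundred
# bytes, while the real checkpoint is ~2.8 GB per the pointer's `size`
# field, so a cheap size check catches a missing `git lfs pull` early.
if os.path.getsize(ckpt_path) < 1024:
    raise RuntimeError('weights/rf_model.pth looks like an LFS pointer; run `git lfs pull` first.')

# Load on CPU first, as in the diff; the weights live under 'model_state_dict'.
checkpoint = torch.load(ckpt_path, map_location=torch.device('cpu'))
# `model` stands in for the End_ExpansionNet_v2 instance built earlier in __init__:
# model.load_state_dict(checkpoint['model_state_dict'])
# model.to(device)
```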