diff --git a/expansionnet_v2.py b/expansionnet_v2.py index 4902073..62d52aa 100644 --- a/expansionnet_v2.py +++ b/expansionnet_v2.py @@ -28,6 +28,7 @@ from towhee.types.image_utils import to_pil from towhee.operator.base import NNOperator, OperatorFlag from towhee import register from towhee.models import clip +from towhee.command.s3 import S3Bucket class ExpansionNetV2(NNOperator): """ @@ -40,6 +41,12 @@ class ExpansionNetV2(NNOperator): from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 from utils.language_utils import convert_vector_idx2word self.convert_vector_idx2word = convert_vector_idx2word + + cfg = self._configs()[model_name] + + s3_bucket = S3Bucket() + s3_bucket.download_file(cfg['weights'], path + '/weights/') + sys.path.pop() with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f: coco_tokens = pickle.load(f) @@ -82,8 +89,7 @@ class ExpansionNetV2(NNOperator): max_seq_len=max_seq_len, drop_args=model_args.drop_args, rank='cpu') - cfg = self.model._configs()[model_name] - checkpoint = torch.load('{}/weights/{}'.format(cfg,cfg['weights']), map_location=torch.device('cpu')) + checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) self.model.load_state_dict(checkpoint['model_state_dict']) self.model.to(self.device) @@ -137,5 +143,5 @@ class ExpansionNetV2(NNOperator): def _configs(self): config = {} config['expansionnet_rf'] = {} - config['expansionnet_rf']['weights'] = 'rf_model.pth' + config['expansionnet_rf']['weights'] = 'image-captioning/expansionnet-v2/rf_model.pth' return config