|
|
@ -28,6 +28,7 @@ from towhee.types.image_utils import to_pil |
|
|
|
from towhee.operator.base import NNOperator, OperatorFlag |
|
|
|
from towhee import register |
|
|
|
from towhee.models import clip |
|
|
|
from towhee.command.s3 import S3Bucket |
|
|
|
|
|
|
|
class ExpansionNetV2(NNOperator): |
|
|
|
""" |
|
|
@ -40,6 +41,12 @@ class ExpansionNetV2(NNOperator): |
|
|
|
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 |
|
|
|
from utils.language_utils import convert_vector_idx2word |
|
|
|
self.convert_vector_idx2word = convert_vector_idx2word |
|
|
|
|
|
|
|
cfg = self._configs()[model_name] |
|
|
|
|
|
|
|
s3_bucket = S3Bucket() |
|
|
|
s3_bucket.download_file(cfg['weights'], path + '/weights/') |
|
|
|
|
|
|
|
sys.path.pop() |
|
|
|
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f: |
|
|
|
coco_tokens = pickle.load(f) |
|
|
@ -82,8 +89,7 @@ class ExpansionNetV2(NNOperator): |
|
|
|
max_seq_len=max_seq_len, drop_args=model_args.drop_args, |
|
|
|
rank='cpu') |
|
|
|
|
|
|
|
cfg = self.model._configs()[model_name] |
|
|
|
checkpoint = torch.load('{}/weights/{}'.format(cfg,cfg['weights']), map_location=torch.device('cpu')) |
|
|
|
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu')) |
|
|
|
self.model.load_state_dict(checkpoint['model_state_dict']) |
|
|
|
self.model.to(self.device) |
|
|
|
|
|
|
@ -137,5 +143,5 @@ class ExpansionNetV2(NNOperator): |
|
|
|
def _configs(self): |
|
|
|
config = {} |
|
|
|
config['expansionnet_rf'] = {} |
|
|
|
config['expansionnet_rf']['weights'] = 'rf_model.pth' |
|
|
|
config['expansionnet_rf']['weights'] = 'image-captioning/expansionnet-v2/rf_model.pth' |
|
|
|
return config |
|
|
|