logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
81f242e57d
  1. 12
      expansionnet_v2.py

12
expansionnet_v2.py

@ -28,6 +28,7 @@ from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register from towhee import register
from towhee.models import clip from towhee.models import clip
from towhee.command.s3 import S3Bucket
class ExpansionNetV2(NNOperator): class ExpansionNetV2(NNOperator):
""" """
@ -40,6 +41,12 @@ class ExpansionNetV2(NNOperator):
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2 from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
from utils.language_utils import convert_vector_idx2word from utils.language_utils import convert_vector_idx2word
self.convert_vector_idx2word = convert_vector_idx2word self.convert_vector_idx2word = convert_vector_idx2word
cfg = self._configs()[model_name]
s3_bucket = S3Bucket()
s3_bucket.download_file(cfg['weights'], path + '/weights/')
sys.path.pop() sys.path.pop()
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f: with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
coco_tokens = pickle.load(f) coco_tokens = pickle.load(f)
@ -82,8 +89,7 @@ class ExpansionNetV2(NNOperator):
max_seq_len=max_seq_len, drop_args=model_args.drop_args, max_seq_len=max_seq_len, drop_args=model_args.drop_args,
rank='cpu') rank='cpu')
cfg = self.model._configs()[model_name]
checkpoint = torch.load('{}/weights/{}'.format(cfg,cfg['weights']), map_location=torch.device('cpu'))
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu'))
self.model.load_state_dict(checkpoint['model_state_dict']) self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device) self.model.to(self.device)
@ -137,5 +143,5 @@ class ExpansionNetV2(NNOperator):
def _configs(self): def _configs(self):
config = {} config = {}
config['expansionnet_rf'] = {} config['expansionnet_rf'] = {}
config['expansionnet_rf']['weights'] = 'rf_model.pth'
config['expansionnet_rf']['weights'] = 'image-captioning/expansionnet-v2/rf_model.pth'
return config return config

Loading…
Cancel
Save