logo
Browse Source

update the operator.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
81f242e57d
  1. 12
      expansionnet_v2.py

12
expansionnet_v2.py

@ -28,6 +28,7 @@ from towhee.types.image_utils import to_pil
from towhee.operator.base import NNOperator, OperatorFlag
from towhee import register
from towhee.models import clip
from towhee.command.s3 import S3Bucket
class ExpansionNetV2(NNOperator):
"""
@ -40,6 +41,12 @@ class ExpansionNetV2(NNOperator):
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
from utils.language_utils import convert_vector_idx2word
self.convert_vector_idx2word = convert_vector_idx2word
cfg = self._configs()[model_name]
s3_bucket = S3Bucket()
s3_bucket.download_file(cfg['weights'], path + '/weights/')
sys.path.pop()
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
coco_tokens = pickle.load(f)
@ -82,8 +89,7 @@ class ExpansionNetV2(NNOperator):
max_seq_len=max_seq_len, drop_args=model_args.drop_args,
rank='cpu')
cfg = self.model._configs()[model_name]
checkpoint = torch.load('{}/weights/{}'.format(cfg,cfg['weights']), map_location=torch.device('cpu'))
checkpoint = torch.load('{}/weights/{}'.format(path,os.path.basename(cfg['weights'])), map_location=torch.device('cpu'))
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
@ -137,5 +143,5 @@ class ExpansionNetV2(NNOperator):
def _configs(self):
config = {}
config['expansionnet_rf'] = {}
config['expansionnet_rf']['weights'] = 'rf_model.pth'
config['expansionnet_rf']['weights'] = 'image-captioning/expansionnet-v2/rf_model.pth'
return config

Loading…
Cancel
Save