logo
Browse Source

upload the weight.

Signed-off-by: wxywb <xy.wang@zilliz.com>
main
wxywb 2 years ago
parent
commit
11063de701
  1. 8
      expansionnet_v2.py
  2. BIN
      weights/rf_model.pth

8
expansionnet_v2.py

@ -35,12 +35,12 @@ class ExpansionNetV2(NNOperator):
"""
def __init__(self, model_name: str):
super().__init__()
sys.path.append(str(pathlib.Path(__file__).parent))
path = str(pathlib.Path(__file__).parent)
sys.path.append(path)
from models.End_ExpansionNet_v2 import End_ExpansionNet_v2
from utils.language_utils import convert_vector_idx2word
self.convert_vector_idx2word = convert_vector_idx2word
sys.path.pop()
path = pathlib.Path(__file__).parent
with open('{}/demo_coco_tokens.pickle'.format(path), 'rb') as f:
coco_tokens = pickle.load(f)
self.coco_tokens = coco_tokens
@ -82,6 +82,10 @@ class ExpansionNetV2(NNOperator):
max_seq_len=max_seq_len, drop_args=model_args.drop_args,
rank='cpu')
checkpoint = torch.load('{}/weights/rf_model.pth'.format(path), map_location=torch.device('cpu'))
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.to(self.device)
self.transf_1 = torchvision.transforms.Compose([torchvision.transforms.Resize((img_size, img_size)), torchvision.transforms.ToTensor()])
self.transf_2 = torchvision.transforms.Compose([torchvision.transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])
self.beam_search_kwargs = {'beam_size': beam_size,

BIN
weights/rf_model.pth (Stored with Git LFS)

Binary file not shown.
Loading…
Cancel
Save