logo
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
Readme
Files and versions

36 lines
1.2 KiB

import torch
from collections import OrderedDict
def load_encoder_checkpoint(
    checkpoint_path: str, output_dim: int, feature_dim: int = 512
) -> OrderedDict:
    """Load encoder weights from a checkpoint and add a zero-filled ``fc`` layer.

    Two checkpoint layouts are handled:

    * PyTorch Lightning checkpoints (detected by a top-level
      ``"pytorch-lightning_version"`` key): entries of ``checkpoint["state_dict"]``
      whose key contains ``"model.encoder."`` are kept, with the first
      occurrence of that prefix removed.
    * Anything else is treated as a flat state dict: entries whose key contains
      ``"encoder."`` are kept, with the first occurrence of that prefix removed.

    In both cases ``"fc.weight"`` (shape ``(output_dim, feature_dim)``) and
    ``"fc.bias"`` (shape ``(output_dim,)``) are set to zero tensors, overwriting
    any entries of the same name.

    Args:
        checkpoint_path: Path of the file passed to ``torch.load``.
        output_dim: Number of rows of the new ``fc.weight`` / length of ``fc.bias``.
        feature_dim: Number of columns of the new ``fc.weight``. Defaults to 512.

    Returns:
        An ``OrderedDict`` of parameter name -> tensor, loaded onto the CPU.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    if "pytorch-lightning_version" in checkpoint:
        source, prefix = checkpoint["state_dict"], "model.encoder."
    else:
        source, prefix = checkpoint, "encoder."

    # Strip only the first occurrence of the prefix: a blanket ``replace`` would
    # also delete nested submodules that happen to be named "encoder"
    # (e.g. "encoder.encoder.layers.x" must become "encoder.layers.x").
    new_state_dict = OrderedDict(
        (k.replace(prefix, "", 1), v) for k, v in source.items() if prefix in k
    )
    # Zero-filled ``fc`` parameters; presumably the head is (re)trained by the
    # caller — NOTE(review): confirm against call sites.
    new_state_dict["fc.weight"] = torch.zeros(output_dim, feature_dim)
    new_state_dict["fc.bias"] = torch.zeros(output_dim)
    return new_state_dict
def load_finetuner_checkpoint(checkpoint_path: str) -> OrderedDict:
    """Load a fine-tuned model's state dict from a checkpoint file.

    For PyTorch Lightning checkpoints (detected by a top-level
    ``"pytorch-lightning_version"`` key), the entries of
    ``checkpoint["state_dict"]`` whose key contains ``"model."`` are returned,
    with the first occurrence of that prefix removed. Any other checkpoint is
    returned exactly as ``torch.load`` produced it.

    Args:
        checkpoint_path: Path of the file passed to ``torch.load``.

    Returns:
        The (possibly re-keyed) state dict, loaded onto the CPU.
    """
    # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))
    if "pytorch-lightning_version" not in checkpoint:
        return checkpoint

    prefix = "model."
    # Strip only the first occurrence: a blanket ``replace`` also mangles any
    # nested name ending in "model" (e.g. "model.encoder.submodel.w" would
    # become "encoder.subw" instead of "encoder.submodel.w").
    # NOTE(review): keys are selected by substring, so e.g. "ema_model.x" also
    # matches; consider ``k.startswith(prefix)`` if Lightning keys are always
    # rooted at "model.".
    return OrderedDict(
        (k.replace(prefix, "", 1), v)
        for k, v in checkpoint["state_dict"].items()
        if prefix in k
    )