clmr
copied
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
Readme
Files and versions
37 lines
1.2 KiB
37 lines
1.2 KiB
3 years ago
|
import torch
|
||
|
from collections import OrderedDict
|
||
|
|
||
|
|
||
|
def _select_and_strip(state_dict, marker: str) -> OrderedDict:
    """Keep entries whose key contains ``marker``, removing every occurrence of it from the key.

    Order follows ``state_dict`` iteration order. If two keys collapse to the same
    stripped key, the later value wins (at the position of the first occurrence).
    """
    return OrderedDict(
        (key.replace(marker, ""), value)
        for key, value in state_dict.items()
        if marker in key
    )


def load_encoder_checkpoint(
    checkpoint_path: str, output_dim: int, fc_in_features: int = 512
) -> OrderedDict:
    """Load encoder weights from a checkpoint and attach a zero-filled ``fc`` head.

    Two checkpoint layouts are handled:

    * PyTorch Lightning checkpoints (detected by the ``"pytorch-lightning_version"``
      key): tensors are read from ``checkpoint["state_dict"]``; only keys containing
      ``"model.encoder."`` are kept, with that substring removed.
    * Any other checkpoint is treated as a flat state dict; only keys containing
      ``"encoder."`` are kept, with that substring removed.

    Args:
        checkpoint_path: Path of the file handed to ``torch.load`` (mapped to CPU).
        output_dim: Number of output features of the ``fc`` head.
        fc_in_features: Number of input features of the ``fc`` head. Defaults to
            512, the value this function always used; override it for encoders
            whose final feature size differs.

    Returns:
        An ``OrderedDict`` with the stripped encoder tensors plus
        ``"fc.weight"`` (zeros, shape ``(output_dim, fc_in_features)``) and
        ``"fc.bias"`` (zeros, shape ``(output_dim,)``). Existing ``"fc.weight"`` /
        ``"fc.bias"`` entries, if any, are overwritten.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    if "pytorch-lightning_version" in checkpoint:
        encoder_state = _select_and_strip(checkpoint["state_dict"], "model.encoder.")
    else:
        encoder_state = _select_and_strip(checkpoint, "encoder.")

    # Placeholder linear head in nn.Linear layout (weight is (out, in)). Presumably the
    # caller loads this into a model whose head is then trained; confirm with callers.
    encoder_state["fc.weight"] = torch.zeros(output_dim, fc_in_features)
    encoder_state["fc.bias"] = torch.zeros(output_dim)
    return encoder_state
|
||
|
|
||
|
|
||
|
def load_finetuner_checkpoint(checkpoint_path: str) -> OrderedDict:
    """Load a fine-tuning checkpoint as a state dict on the CPU.

    If the loaded object carries a ``"pytorch-lightning_version"`` key, it is
    treated as a Lightning checkpoint. Only entries of its ``"state_dict"`` whose
    key contains ``"model."`` are kept, with every occurrence of that substring
    removed from the key. Any other loaded object is returned unchanged.

    Args:
        checkpoint_path: Path of the file handed to ``torch.load``.

    Returns:
        The (possibly re-keyed) state dict.
    """
    # NOTE(review): torch.load unpickles the file, so only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device("cpu"))

    # Non-Lightning files are assumed to already be a plain state dict.
    if "pytorch-lightning_version" not in checkpoint.keys():
        return checkpoint

    prefix = "model."
    stripped = OrderedDict()
    for name, tensor in checkpoint["state_dict"].items():
        if prefix in name:
            stripped[name.replace(prefix, "")] = tensor
    return stripped
|