[PyTorch] Loading only specific parts of a model's weights
2020. 12. 11. 16:46 · Uncategorized
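# Source: https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2
# ckpt_E is assumed to be a checkpoint dict already loaded into memory; encoder is the target nn.Module.
pretrained_dict = ckpt_E["model_state_dict"]
model_dict = encoder.state_dict()
# 1. Keep only the pretrained entries whose keys also exist in the target model.
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 2. Overwrite the matching entries in the target model's state dict.
model_dict.update(pretrained_dict)
# 3. Load the merged dict; keys absent from the checkpoint keep their current values.
encoder.load_state_dict(model_dict)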
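
The snippet above matches on key names only, so load_state_dict can still fail when a key exists in both models but the tensor shapes differ. Below is a minimal, self-contained sketch of a variant that also compares shapes. The Pretrained and Encoder classes and their layer sizes are hypothetical stand-ins made up for illustration; only ckpt_E, encoder, pretrained_dict, and model_dict come from the original snippet, and the shape check itself is an addition, not part of the linked forum answer.

```python
import torch
import torch.nn as nn


class Pretrained(nn.Module):  # hypothetical source network
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.head = nn.Linear(8, 2)


class Encoder(nn.Module):  # hypothetical target network
    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)  # same name and shape: gets loaded
        self.head = nn.Linear(8, 3)      # same name, different shape: skipped
        self.extra = nn.Linear(3, 3)     # not in the checkpoint: left untouched


# Stand-in for a checkpoint that would normally come from torch.load(...).
ckpt_E = {"model_state_dict": Pretrained().state_dict()}
encoder = Encoder()

pretrained_dict = ckpt_E["model_state_dict"]
model_dict = encoder.state_dict()

# Keep entries whose key exists in the target AND whose shape matches.
filtered = {
    k: v
    for k, v in pretrained_dict.items()
    if k in model_dict and v.shape == model_dict[k].shape
}

model_dict.update(filtered)
encoder.load_state_dict(model_dict)

loaded_keys = sorted(filtered)
print(loaded_keys)  # ['backbone.bias', 'backbone.weight']
```

Printing the keys that were actually loaded is a cheap sanity check: if a layer was renamed, a key-based filter drops it silently and the layer keeps its random initialization.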