[PyTorch] Loading only specific parts of model weights

2020. 12. 11. 16:46

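    # https://discuss.pytorch.org/t/how-to-load-part-of-pre-trained-model/1113/2
    # ckpt_E: checkpoint dict loaded earlier with torch.load(); encoder: the target nn.Module
    pretrained_dict = ckpt_E["model_state_dict"]
    model_dict = encoder.state_dict()

    # 1. Keep only the pretrained entries whose keys also exist in the encoder
    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
    # 2. Overwrite the matching entries in the encoder's own state_dict
    model_dict.update(pretrained_dict)
    # 3. Load the merged dict: every encoder key is present, so the strict key check passes
    #    (the shapes of matched tensors must still agree)
    encoder.load_state_dict(model_dict)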
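The snippet above matches pretrained entries by key only, so if a layer keeps its name but changes size (for example a classifier head with a different number of classes), `load_state_dict` still raises a size-mismatch error. Below is a minimal sketch, assuming you also want to skip such layers, that adds a shape check to the same filter. `load_matching_weights` and the toy `Net` model are hypothetical names for illustration, and the checkpoint is simulated with another model's `state_dict()` instead of `torch.load`.

```python
import torch
import torch.nn as nn


def load_matching_weights(model: nn.Module, pretrained_state: dict) -> list:
    """Copy into `model` only the pretrained tensors whose key AND shape match.

    Hypothetical helper; returns the sorted list of keys that were loaded.
    """
    model_state = model.state_dict()
    # Keep entries that exist in the model with an identical tensor shape.
    # Key-only matching would make load_state_dict raise a size-mismatch
    # RuntimeError when a same-named layer was resized.
    matched = {
        k: v
        for k, v in pretrained_state.items()
        if k in model_state and v.shape == model_state[k].shape
    }
    model_state.update(matched)       # overwrite only the compatible entries
    model.load_state_dict(model_state)  # all keys present, all shapes agree
    return sorted(matched)


class Net(nn.Module):
    """Hypothetical toy model: shared backbone, configurable head size."""

    def __init__(self, num_classes: int):
        super().__init__()
        self.backbone = nn.Linear(8, 4)
        self.head = nn.Linear(4, num_classes)


if __name__ == "__main__":
    torch.manual_seed(0)
    source = Net(num_classes=10)  # stands in for the pretrained checkpoint
    target = Net(num_classes=3)   # head shape differs, backbone matches
    # In practice: torch.load(path)["model_state_dict"] instead of source.state_dict()
    loaded = load_matching_weights(target, source.state_dict())
    print(loaded)  # → ['backbone.bias', 'backbone.weight']
```

An alternative is `encoder.load_state_dict(pretrained_dict, strict=False)`, which ignores missing and unexpected keys and reports them in its return value. It still raises on shape mismatches for keys present in both dicts, though, so a shape filter like the one above is needed when layer sizes differ.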