背景

  • 之前介绍过Pytoch的单机多卡训练,当时采用的是内置的nn.DataParallel()方法
  • 经过单机多卡训练的模型保存checkpoint之后再次加载需要同样使用DP()对模型进行封装之后才能正常加载,否则会报错状态词典的键值对不上

原因

  • 经过DP()封装之后的模型,其状态词典分别添加了module.的前缀,例如原参数为hr_branch.conv_hr.1.layers.0.weight,封装之后变成了module.hr_branch.conv_hr.1.layers.0.weight

解决方案

  • 在加载状态词典之前将该前缀进行替换,得到纯净的状态词典,则不再需要重新对模型进行DP()封装
ckp_state_dict = torch.load(args.ckp, map_location=torch.device("cpu"))["model"]
ckp_state = {}
for k, v in ckp_state_dict.items():
    k = k.replace("module.", "")
    ckp_state[k] = v

参考