Pytorch单机多卡(DP)训练之后的模型“货不对板”
目录
背景⌗
- 之前介绍过Pytoch的单机多卡训练,当时采用的是内置的
nn.DataParallel()
方法 - 经过单机多卡训练的模型保存checkpoint之后再次加载需要同样使用
DP()
对模型进行封装之后才能正常加载,否则会报错状态词典的键值对不上
原因⌗
- 经过
DP()
封装之后的模型,其状态词典分别添加了module.
的前缀,例如原参数为hr_branch.conv_hr.1.layers.0.weight
,封装之后变成了module.hr_branch.conv_hr.1.layers.0.weight
。
解决方案⌗
- 在加载状态词典之前将该前缀进行替换,得到纯净的状态词典,则不再需要重新对模型进行
DP()
封装
ckp_state_dict = torch.load(args.ckp, map_location=torch.device("cpu"))["model"]
ckp_state = {}
for k, v in ckp_state_dict.items():
k = k.replace("module.", "")
ckp_state[k] = v
参考⌗
Read other posts