需求

  • 对基于pytorch的深度学习模型进行多卡训练以加速训练过程
  • 由于显卡版本过于老旧,安装配置NCCL工程量过于庞大,希望使用简单的pytorch代码实现单机多卡训练,不考虑多机多卡的显卡通信
  • 训练完成后保存的checkpoint需要能够在任何设备上进行加载、推理

实现

训练

  • pytorch提供了简单的单机多卡训练api,只需要在初始化模型之后执行下列语句将模型复制到多卡上
# initiate multi-gpu training
model = nn.DataParallel(model, device_ids=<ids of the gpus you want to use>)
  • 其他操作与单卡训练完全一致

加载checkpoint

  • 上述操作后保存的checkpoint如果按照常规方法直接进行加载会报错
RuntimeError: Error(s) in loading state_dict for <ModelName>:
	Missing key(s) in state_dict:...
  • debug遍历后发现其实其状态字典是完全一致的,只是因为我们在训练过程中将模型定义为了多卡并行模型。这里只需要按照训练过程中转换为多卡模型的代码初始化当前模型结构即可,即执行:
# initiate multi-gpu training
model = nn.DataParallel(model, device_ids=<ids of the gpus you want to use>)
  • 其他操作与征程推理完全一致,若不想使用多卡/只想使用cpu,只需要按照常规将device = torch.device("<cpu/cuda:id>")即可

Note:查阅资料过程中发现有解答建议使用参数强行忽略模型加载的错误torch.load(<checkpoint>, strict=False),经测试,这样加载的模型啥也不是…不知道为什么pytorch官方要提供这个接口