星级打分
平均分: NAN 参与人数: 0 我的评分: 未评
本帖最后由 lispmox 于 2021-8-22 21:54 编辑
新手上路,手里有一块2060用来玩游戏。用DFL玩过几次 256WF 的模型,受限于 6G 显存,无法追求更高的画质和参数。最近显卡价格一直很高,估计到年底都不会降价了。于此开始尝试在 6G 显存上训练 SAEHD 模型的想法。
先说思路,将 DFL 模型转化为 pytorch 实现,然后借助混合精度训练、checkpoint等方法减少训练期间的显存使用。
DFL 的作者一开始采用 leras 作为深度学习框架,这是原作者基于 tensorflow 做的一个封装,据说是简化了使用(个人感觉并没有)。随着近几年各大深度学习框架的发展,leras则显得比较简陋,不够灵活。实际上作者在 DFLive 里已经不再使用 leras。我个人比较喜欢 pytorch 这个框架,因此第一步就是将 DFL 的原模型导出为 pytorch 版本。从 leras 模型导出为 pytorch 模型主要关注四个点:
1. leras Conv2D 导出为 pytorch Conv2d 时,卷积核的维度应该由 (H, W, I, O) 转置为 (O, I, H, W),和 tensorflow 转 pytorch 类似。
2. leras Conv2D 支持非对称 padding,而 pytorch 只支持对称的padding。解决方法就是在 pytorch 里基于 F.pad 和 F.conv2d 两个函数重新实现一个 leras 版本。
3. leras Dense 全连接层的参数维度时 (I, O),pytorch 对应的 Linear 全连接层参数维度是 (O, I),需要转置一下。
4. leras 自己实现了 depth_to_space 函数,pytorch 对应的是 F.pixel_shuffle 函数,然而两者运行细节有差异,要仿照 leras 重现实现一下。
训练用的损失函数方面,我暂时只实现了基本的 DSSIM(结构损失)和 true face power 两种,剩下功能等有时间了再做。
混合精度训练的含义是用16位的半精度浮点进行前向传播和梯度计算,然后用32位浮点数进行梯度更新。而一般的训练都是全32位浮点数计算的。混合精度的好处是可以减少显存占用,实测大概能减少1/3。现在只能运行在Nvidia的显卡上,其中20系列的显卡性价比最高。N卡20细节半精度和全精度的计算比是2:1,理论上有两倍的计算效率,30系列显卡只有1:1。在使用混合精度训练后,CUDA利用率从持续95%+降到了在70~90%之间波动,每一轮的迭代时间大概是原来的3/4。
PS:在使用混合精度训练的时候,最后的 loss 函数还是需要用全精度计算。推测是DSSIM损失函数里面包含一个非常大的卷积核(>20),用半精度计算会导致数值溢出。
另一个减少显存占用的方法是checkpoint,可以感性理解为对模型分段求导。这个功能在 pytorch 里有官方支持,tensorflow 也有民间实现。据说 mxnet 对此支持更好,不过生态和稳定性是一个坑。模型训练时,显存占用主要分为4类:1) 模型参数占用,大概静态文件的2倍;2) 梯度占用,和模型参数占用差不多大小;3) 模型前向传播保存的中间计算结果,用于自动微分的求解;4) 临时计算缓存。这里面可以优化的时2) 和 3) 。
先说3),checkpoint主要功能就是将模型分段,然后只保存前向传播处于段头和段尾的中间结果,剩下的内存就可以释放了,自动微分时再重新计算缺失的内容。假设模型有N层,大概分为sqrt(N)个段,3) 的显存占用开个根号,性能损失也不大。
然后是4), 可以尝试在更新完梯度之后将参数绑定的梯度张量直接释放。pytorch 的默认行为是数值清零但不释放空间,这种方式会提高计算效率(少了N次内存申请),但会和计算缓存抢空间。操作也比较简单,当明确不需要某个模块的梯度时,运行model.zero_grad(set_to_none=True)就行。当然,这仅限于简单情况,当你的网络包含多个分支,例如优化 true face 的判别器时,还是要小心选择释放时机的。
经过一系列的优化,成功在 6G 显存上运行了一个512WF DF-UD的模型,具体参数为:
resolution = 512
ae_dims = 320
e_dims = 64
d_dims = 72
d_mask_dims = 22
batch_size = 8
跑一个epoch时间在2.6-3.1s之间(还要啥自行车)
在 pytorch 框架里训练好模型之后再反向转回 leras 版本,然后就可以借助 DFL 的一些工具换脸了。详细部分可以参考 DFL 的代码,需要注意部分切脸、对齐等参数要统一,否则会影响效果。
代码主体部分已经完成,还需要调整一些细节,现在的训练效果还比不上DFL内置的实现。附上一个从零开始训练的图,跑了3W次迭代,没开 true face。
PS:问我感受?DFL哪有写代码好玩[doge]
评分
查看全部评分