|
发表于 2024-9-16 11:31:35
|
显示全部楼层
本帖最后由 wtxx8888 于 2024-9-16 12:27 编辑
额,我是照扒ModelBase.py的代码。你这跟我开始的计数方式差不多,但我仔细想了想,这样简单的计数,是有BUG的。
意味着,最后两次的记录,是重训练的两次重复(循环时记录了,返回后又记录一遍),而不是一次正常,一次重训练。
另,真脸倒无所谓,只用warped_src,warped_dst。
但GAN训练,应该在重训练之前完成。
不然重训练会对部分变量,进行新的赋值,从而刷掉了GAN用的,除warped_src,warped_dst之外的变量。
这会造成GAN的BUG,不是同一套的数据,在训练GAN。。。
def onTrainOneIter(self):
if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
io.log_info('您正在从头开始训练模型。强烈建议使用预训练模型来加快训练速度并提高质量。.\n')
# 生成下一组样本数据
( (warped_src, target_src, target_srcm, target_srcm_em), \
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = self.generate_next_samples()
# 使用SRC和DST样本进行训练,并计算LOSS
src_loss, dst_loss = self.src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
# 如果启用了真实人脸权重,且不是预训练模式
if self.options['true_face_power'] != 0 and not self.pretrain:
# 训练判别器
self.D_train (warped_src, warped_dst)
# 如果启用了GAN权重
if self.gan_power != 0:
# 训练带有SRC和DST的判别器
self.D_src_dst_train (warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)
# 如果启用了重训练样本选项
if self.options['retraining_samples']:
bs = self.get_batch_size() # 获取批量大小
lossi = self.options['high_loss_power'] # 获取高LOSS权重参数
cycles = self.options['number_of_cycles'] # 获取高LOSS训练次数参数
# 记录最近一批次的样本,及其LOSS
for i in range(bs):
self.last_src_samples_loss.append ( (target_src, target_srcm, target_srcm_em, src_loss ) )
self.last_dst_samples_loss.append ( (target_dst, target_dstm, target_dstm_em, dst_loss ) )
# 当记录的样本数量,达到设定的权重时
if len(self.last_src_samples_loss) >= bs*lossi:
# 获取正常训练样本的平均LOSS值
losses = ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
# 将正常训练样本的LOSS值,写入历史记录
self.loss_history.append ( [float(loss[1]) for loss in losses] )
# 迭代器计数,记录正常的这次
self.iter += 1
# 根据LOSS,对样本进行排序
src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
# 选择LOSS最高的一批样本,用于重新训练
target_src = np.stack( [ x[0] for x in src_samples_loss[:bs] ] )
target_srcm = np.stack( [ x[1] for x in src_samples_loss[:bs] ] )
target_srcm_em = np.stack( [ x[2] for x in src_samples_loss[:bs] ] )
target_dst = np.stack( [ x[0] for x in dst_samples_loss[:bs] ] )
target_dstm = np.stack( [ x[1] for x in dst_samples_loss[:bs] ] )
target_dstm_em = np.stack( [ x[2] for x in dst_samples_loss[:bs] ] )
# 重新训练的次数
for i in range(cycles):
# 使用已选定的SRC和DST样本,重新进行训练,并计算LOSS
src_loss, dst_loss = self.src_dst_train (target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
# 判定是否为 最后一次 是0则交由返回进行记录 不是0则在循环内记录
if i != 0:
# 获取重训练样本的平均LOSS值
losses = ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
# 将重训练样本的LOSS值,写入历史记录
self.loss_history.append ( [float(loss[1]) for loss in losses] )
# 迭代器计数,记录重训练的这次
self.iter += 1
# 清空记录样本的LOSS列表
self.last_src_samples_loss = []
self.last_dst_samples_loss = []
# 返回SRC的平均LOSS,和DST的平均LOSS
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) # 返回ModelBase.py 那边会写最后一次的记录
|
|