Took two hours to sort this out, what a pain. I pulled the console-update code out into its own function; along the way I hit all kinds of weird errors, but it works now. The iteration time shown during loop retraining reads lower than a normal iteration because the timed code path is shorter: mainly, there is no batch preparation.
My fundamentals are still too weak. Anything not pasted below is unchanged.
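To make that timing claim concrete: the normal iteration timer covers batch preparation plus the training step, while the retraining timer below wraps only the training step. A minimal standalone sketch (not DFL code; prepare_batch and train_step are hypothetical stand-ins):

    import time

    def prepare_batch():   # hypothetical stand-in for batch preparation
        time.sleep(0.05)   # pretend data prep costs ~50 ms

    def train_step():      # hypothetical stand-in for src_dst_train
        time.sleep(0.10)   # pretend the training step costs ~100 ms

    # Normal iteration: the timer covers data prep + training
    t0 = time.time()
    prepare_batch()
    train_step()
    print("normal  iter: %4dms" % int((time.time() - t0) * 1000))  # ~150 ms

    # Retraining cycle: the batch is already in memory,
    # so the timer covers only the training step
    t0 = time.time()
    train_step()
    print("retrain iter: %4dms" % int((time.time() - t0) * 1000))  # ~100 ms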
def repeated_training(self, high_iter_time):
    # Log one retraining cycle to the console, mirroring the normal
    # per-iteration output (timestamp, iteration number, time, losses)
    time_str = time.strftime("[%H:%M:%S]")
    loss_history = self.get_loss_history()
    save_iter = self.get_iter()
    self.iter += 1
    if high_iter_time >= 10:
        loss_string = "{0}[#{1:06d}][{2:.5s}s]".format(time_str, self.iter, '{:0.4f}'.format(high_iter_time))
    else:
        loss_string = "{0}[#{1:06d}][{2:04d}ms]".format(time_str, self.iter, int(high_iter_time * 1000))
    # Compute mean_loss, the average loss over loss_history since save_iter
    mean_loss = np.mean(loss_history[save_iter:self.iter], axis=0)
    # Append each averaged loss value to loss_string
    for loss_value in mean_loss:
        loss_string += "[%.4f]" % (loss_value)
    # Print loss_string, overwriting the current console line
    io.log_info(loss_string, end='\r')
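As a quick sanity check of the two format branches (seconds at or above 10 s, milliseconds below), here is a minimal standalone sketch of just the string formatting with made-up values (iter_num is a placeholder, not DFL code):

    import time

    iter_num = 1234
    for high_iter_time in (0.234, 12.3456):
        time_str = time.strftime("[%H:%M:%S]")
        if high_iter_time >= 10:
            s = "{0}[#{1:06d}][{2:.5s}s]".format(time_str, iter_num, '{:0.4f}'.format(high_iter_time))
        else:
            s = "{0}[#{1:06d}][{2:04d}ms]".format(time_str, iter_num, int(high_iter_time * 1000))
        print(s)
    # -> [HH:MM:SS][#001234][0234ms]
    # -> [HH:MM:SS][#001234][12.34s]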
------------------------------------------------------------------------------------------------------------------
def train_one_iter(self):
    # Run a single training iteration; get the loss list plus its duration
    losses, iter_time = self.onTrainOneIter()
---------------------------------------------------------------------------------------------------
# Elapsed time of the normal / true-face / GAN training iteration (the start
# timestamp is set earlier, in unchanged code); if loop retraining is not
# entered, this is returned to train_one_iter as-is
normal_retrain_end_time = time.time() - normal_retrain_end_time
# If the retraining-samples option is enabled
if self.options['retraining_samples']:
    bs = self.get_batch_size()                       # batch size
    lossi = self.options['high_loss_power']          # high-loss factor: how many batches to accumulate
    high_loss1 = self.options['num_retrain_cycles']  # number of retraining cycles
    # Record the most recent batch of samples along with their losses
    # (assumption: the per-sample index [i] was dropped when pasting; without
    # it, the whole batch would be appended bs times)
    for i in range(bs):
        self.last_src_samples_loss.append((target_src[i], target_srcm[i], target_srcm_em[i], src_loss[i], filenames[0][i]))
        self.last_dst_samples_loss.append((target_dst[i], target_dstm[i], target_dstm_em[i], dst_loss[i], filenames[1][i]))
    # Once the number of recorded samples reaches the threshold
    if len(self.last_src_samples_loss) >= bs * lossi:
        # Sort the recorded samples by loss, highest first
        src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
        dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)
        # Select the highest-loss batch for retraining
        target_src     = np.stack([x[0] for x in src_samples_loss[:bs]])
        target_srcm    = np.stack([x[1] for x in src_samples_loss[:bs]])
        target_srcm_em = np.stack([x[2] for x in src_samples_loss[:bs]])
        target_dst     = np.stack([x[0] for x in dst_samples_loss[:bs]])
        target_dstm    = np.stack([x[1] for x in dst_samples_loss[:bs]])
        target_dstm_em = np.stack([x[2] for x in dst_samples_loss[:bs]])
        # File names of the high-loss samples
        high_loss_filenames_src = [x[4] for x in src_samples_loss[:bs]]
        high_loss_filenames_dst = [x[4] for x in dst_samples_loss[:bs]]
        for i in range(high_loss1):
            high_time = time.time()
            src_loss, dst_loss = self.src_dst_train(target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
            if i != high_loss1 - 1:
                # Update loss_history for every cycle except the last;
                # the last cycle's losses are returned below
                losses = (('src_loss', np.mean(src_loss)), ('dst_loss', np.mean(dst_loss)),)
                self.loss_history.append([float(loss[1]) for loss in losses])
            high_time = time.time() - high_time
            self.repeated_training(high_time)
        # Elapsed time of the last retraining cycle, returned to train_one_iter
        # (assumption: the original `time.time() - high_time` subtracted a
        # duration from a timestamp; high_time already holds the elapsed time)
        normal_retrain_end_time = high_time
        # Clear the recorded sample-loss lists
        self.last_src_samples_loss = []
        self.last_dst_samples_loss = []
# Return the mean src/dst losses plus the normal iteration time,
# or the last retraining cycle's time if retraining ran
return (('src_loss', np.mean(src_loss)), ('dst_loss', np.mean(dst_loss)),), normal_retrain_end_time
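For reference, the high-loss selection step can be exercised on its own with synthetic data. This sketch (shapes and file names are made up, not DFL code) sorts recorded (image, mask, em-mask, loss, filename) tuples by loss and stacks the worst bs into a retraining batch, the same way onTrainOneIter does above:

    import operator
    import numpy as np

    bs, lossi = 4, 2
    rng = np.random.default_rng(0)

    # Synthetic records: (image, mask, em-mask, per-sample loss, filename)
    records = [(rng.random((8, 8, 3)), rng.random((8, 8, 1)), rng.random((8, 8, 1)),
                float(rng.random()), "sample_%02d.jpg" % i)
               for i in range(bs * lossi)]

    # Sort by loss, highest first
    records = sorted(records, key=operator.itemgetter(3), reverse=True)

    # Stack the worst bs samples into one retraining batch
    target_src = np.stack([x[0] for x in records[:bs]])
    high_loss_filenames = [x[4] for x in records[:bs]]

    print(target_src.shape)     # (4, 8, 8, 3)
    print(high_loss_filenames)  # the 4 highest-loss file names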