|
楼主 |
发表于 2024-9-17 11:59:02
|
显示全部楼层
本帖最后由 奇奇怪怪的ID 于 2024-9-17 12:03 编辑
是因为切片索引有问题导致的,我刚开始也遇到过。我的源文件有些东西不方便发,我再给你说明白一下:
这里是ModelBase文件需要改的
def repeated_training(self, high_iter_time):
    """Log one retraining iteration to the console.

    Advances the iteration counter and prints a single status line of the
    form ``[HH:MM:SS][#iter][time][loss]...[loss]``.

    Args:
        high_iter_time: wall-clock duration (in seconds) of the iteration
            being reported.
    """
    time_str = time.strftime("[%H:%M:%S]")
    loss_history = self.get_loss_history()
    # NOTE(review): assumes get_iter() returns the current self.iter, so the
    # slice below covers exactly the iterations recorded since then — confirm.
    save_iter = self.get_iter()
    self.iter += 1
    # Slow iterations are shown in seconds with 4 decimals, fast ones in ms.
    if high_iter_time >= 10:
        loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(high_iter_time) )
    else:
        loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(high_iter_time*1000) )
    # Mean loss over the history entries added since save_iter.
    mean_loss = np.mean ( loss_history[save_iter:self.iter], axis=0)
    for loss_value in mean_loss:
        loss_string += "[%.4f]" % (loss_value)
    # '\r' makes successive status lines overwrite each other on the console.
    io.log_info(loss_string, end='\r')
def train_one_iter(self):
# Run a single training iteration and collect the loss values and the
# iteration time reported by onTrainOneIter.
# NOTE(review): this definition is truncated in the paste — the remainder
# of the method body is not shown here.
losses,iter_time = self.onTrainOneIter()
---------------------------------------------------------------------------------------------------
这里是Model文件需要改的,我直接一整个贴给你
def onTrainOneIter(self):
    """Run one training iteration, optionally followed by high-loss retraining.

    Performs the normal src/dst training step (plus true-face / GAN
    discriminator steps when enabled). When ``retraining_samples`` is on,
    the recent samples are ranked by loss and the worst batch is retrained
    ``num_retrain_cycles`` times.

    Returns:
        A 2-tuple of (losses, elapsed_time) where losses is
        ``(('src_loss', mean), ('dst_loss', mean))`` and elapsed_time is the
        duration of the normal pass, or of the last retrain iteration when
        retraining ran.
    """
    if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
        io.log_info('您正在从头开始训练模型。强烈建议使用预训练模型来加快训练速度并提高质量。\n')

    # Start of the normal (non-retrain) pass timing.
    normal_retrain_end_time = time.time()

    # Generate the next batch of samples.
    ((warped_src, target_src, target_srcm, target_srcm_em),
     (warped_dst, target_dst, target_dstm, target_dstm_em)), filenames = self.generate_next_samples()

    # Train on src and dst samples and compute the losses.
    src_loss, dst_loss = self.src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

    # True-face discriminator step (skipped while pretraining).
    if self.options['true_face_power'] != 0 and not self.pretrain:
        self.D_train(warped_src, warped_dst)

    # GAN discriminator step.
    if self.gan_power != 0:
        self.D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

    # Elapsed time of the normal/true-face/GAN pass; returned unchanged to
    # train_one_iter when retraining does not run this iteration.
    normal_retrain_end_time = time.time() - normal_retrain_end_time

    if self.options['retraining_samples']:
        bs = self.get_batch_size()                      # batch size
        lossi = self.options['high_loss_power']         # batches to accumulate before retraining
        high_loss1 = self.options['num_retrain_cycles'] # retrain cycles per trigger

        # Record the latest batch of samples with their losses.
        # NOTE(review): this appends the full-batch tensors (and the whole
        # loss array) bs times, rather than per-sample slices such as
        # target_src[i] / src_loss[i] — confirm this is intentional; a
        # whole-array sort key can make sorted() ambiguous for ndarrays.
        for i in range(bs):
            self.last_src_samples_loss.append((target_src, target_srcm, target_srcm_em, src_loss, filenames[0]))
            self.last_dst_samples_loss.append((target_dst, target_dstm, target_dstm_em, dst_loss, filenames[1]))

        # Once enough samples are recorded, retrain on the worst of them.
        if len(self.last_src_samples_loss) >= bs * lossi:
            # Sort recorded samples by loss, highest first.
            src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
            dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)

            # Take the highest-loss batch for retraining.
            target_src = np.stack([x[0] for x in src_samples_loss[:bs]])
            target_srcm = np.stack([x[1] for x in src_samples_loss[:bs]])
            target_srcm_em = np.stack([x[2] for x in src_samples_loss[:bs]])
            target_dst = np.stack([x[0] for x in dst_samples_loss[:bs]])
            target_dstm = np.stack([x[1] for x in dst_samples_loss[:bs]])
            target_dstm_em = np.stack([x[2] for x in dst_samples_loss[:bs]])

            # Filenames of the high-loss samples.
            # NOTE(review): currently unused in this method — kept in case a
            # later log/debug step consumes them; confirm before removing.
            high_loss_filenames_src = [x[4] for x in src_samples_loss[:bs]]
            high_loss_filenames_dst = [x[4] for x in dst_samples_loss[:bs]]

            for i in range(high_loss1):
                high_time = time.time()
                # Retrain using the targets as both warped and target inputs.
                src_loss, dst_loss = self.src_dst_train(target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
                # Append to loss_history for every cycle except the last one
                # (the caller records the last one itself).
                if i != high_loss1-1:
                    losses = ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
                    self.loss_history.append([float(loss[1]) for loss in losses])
                high_time = time.time() - high_time
                self.repeated_training(high_time)
                # BUGFIX: was `time.time() - high_time`. At this point
                # high_time is already an elapsed duration, so the old
                # expression produced an epoch-sized timestamp. Return the
                # duration of the (last) retrain iteration instead.
                normal_retrain_end_time = high_time

            # Reset the recorded sample-loss buffers.
            self.last_src_samples_loss = []
            self.last_dst_samples_loss = []

    # Return the mean src/dst losses and the elapsed time for this call.
    return (('src_loss', np.mean(src_loss)), ('dst_loss', np.mean(dst_loss)), ), normal_retrain_end_time
|
|