deepfacelab中文网

 找回密码
 立即注册(仅限QQ邮箱)

【统计】搞了个高loss重复次数的选项

[复制链接]

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-17 11:20:54 | 显示全部楼层
wtxx8888 发表于 2024-9-17 11:14
我研究研究吧,暂时没成功

你做好了也给我看看呗,我这个现在跑了6个钟,还没发现什么问题
回复

使用道具 举报

14

主题

2849

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15365

真我风采勋章万事如意节日勋章

发表于 2024-9-17 11:49:19 | 显示全部楼层
奇奇怪怪的ID 发表于 2024-9-17 11:20
你做好了也给我看看呗,我这个现在跑了6个钟,还没发现什么问题

你把ModelBase.py跟Model.py发给我看看吧。
就改这俩吧?

老报给我这个   File "E:\LV\DeepFaceLab\_internal\DeepFaceLab\models\ModelBase.py", line 542, in repeated_training
    for loss_value in mean_loss:
TypeError: 'numpy.float64' object is not iterable

不可迭代
回复

使用道具 举报

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-17 11:59:02 | 显示全部楼层
本帖最后由 奇奇怪怪的ID 于 2024-9-17 12:03 编辑
wtxx8888 发表于 2024-9-17 11:49
你把ModelBase.py跟Model.py发给我看看吧。
就改这俩吧?

是因为切片索引有问题的导致的,我刚开始也遇到过,我的源文件有些东西不方便发,我再给你说明白一下

image.png
这里是ModelBase文件需要改的
    def repeated_training(self,high_iter_time):
        time_str = time.strftime("[%H:%M:%S]")
        loss_history = self.get_loss_history()
        save_iter = self.get_iter()
        self.iter += 1


        if high_iter_time >= 10:
            loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(high_iter_time) )
        else:
            loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(high_iter_time*1000) )


        mean_loss = np.mean ( loss_history[save_iter:self.iter], axis=0)
        # 计算从上次保存以来的平均损失mean_loss。


        for loss_value in mean_loss:
            loss_string += "[%.4f]" % (loss_value)
        # 将平均损失的每个值添加到损失字符串loss_string中。


        io.log_info(loss_string, end='\r')
        # 输出损失字符串loss_string。

    def train_one_iter(self):
        

        # 执行单次训练迭代,并获取损失值列表
        losses,iter_time = self.onTrainOneIter()
        
---------------------------------------------------------------------------------------------------

这里是Model文件需要改的,我直接一整个贴给你




def onTrainOneIter(self):
        if self.get_iter() == 0 and not self.pretrain and not self.pretrain_just_disabled:
            io.log_info('您正在从头开始训练模型。强烈建议使用预训练模型来加快训练速度并提高质量。\n')
        normal_retrain_end_time = time.time()
        # 生成下一组样本数据
        ((warped_src, target_src, target_srcm, target_srcm_em),
        (warped_dst, target_dst, target_dstm, target_dstm_em)), filenames = self.generate_next_samples()

        # 使用源和目标样本进行训练,并计算损失
        src_loss, dst_loss = self.src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

        # 如果启用了真实人脸权重且不是预训练模式
        if self.options['true_face_power'] != 0 and not self.pretrain:
            # 训练判别器
            self.D_train(warped_src, warped_dst)

        # 如果启用了GAN权重
        if self.gan_power != 0:
            # 训练带有源和目标的判别器
            self.D_src_dst_train(warped_src, target_src, target_srcm, target_srcm_em, warped_dst, target_dst, target_dstm, target_dstm_em)

        # 这里赋值的是正常、真脸、gan训练的迭代的结束时间,如果不进入循环训练则直接返回给train_one_iter
        normal_retrain_end_time = time.time() - normal_retrain_end_time

        # 如果启用了重训练样本选项
        if self.options['retraining_samples']:
            bs = self.get_batch_size()  # 获取批量大小
            lossi = self.options['high_loss_power']  # 获取高损失幂次参数
            high_loss1 = self.options['num_retrain_cycles']# 获取重训次数
               
            # 记录最近一批次的样本及其损失
            for i in range(bs):
                self.last_src_samples_loss.append((target_src, target_srcm, target_srcm_em, src_loss, filenames[0]))
                self.last_dst_samples_loss.append((target_dst, target_dstm, target_dstm_em, dst_loss, filenames[1]))

            # 当记录的样本数量达到设定的阈值时
            if len(self.last_src_samples_loss) >= bs * lossi:
                # 根据损失对样本进行排序
                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)

                # 选择损失最高的一批样本用于重新训练
                target_src = np.stack([x[0] for x in src_samples_loss[:bs]])
                target_srcm = np.stack([x[1] for x in src_samples_loss[:bs]])
                target_srcm_em = np.stack([x[2] for x in src_samples_loss[:bs]])

                target_dst = np.stack([x[0] for x in dst_samples_loss[:bs]])
                target_dstm = np.stack([x[1] for x in dst_samples_loss[:bs]])
                target_dstm_em = np.stack([x[2] for x in dst_samples_loss[:bs]])

                # 获取高损失样本的文件名
                high_loss_filenames_src = [x[4] for x in src_samples_loss[:bs]]
                high_loss_filenames_dst = [x[4] for x in dst_samples_loss[:bs]]

                for i in range(high_loss1):
                    high_time = time.time()

                    src_loss, dst_loss = self.src_dst_train(target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)

                    if i != high_loss1-1:
                        # 更新 loss_history
                        losses = ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )

                        self.loss_history.append([float(loss[1]) for loss in losses])
                        
                        high_time = time.time() - high_time

                        self.repeated_training(high_time)

                # 这里赋值的是循环训练的最后一次迭代的结束时间,返回给train_one_iter
                normal_retrain_end_time = time.time() - high_time

                # 清空记录的样本损失列表
                self.last_src_samples_loss = []
                self.last_dst_samples_loss = []
               
        # 返回平均源损失和平均目标损失
        return (('src_loss', np.mean(src_loss)), ('dst_loss', np.mean(dst_loss)), ),normal_retrain_end_time


回复

使用道具 举报

14

主题

2849

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15365

真我风采勋章万事如意节日勋章

发表于 2024-9-17 12:43:35 | 显示全部楼层
本帖最后由 wtxx8888 于 2024-9-17 14:06 编辑
奇奇怪怪的ID 发表于 2024-9-17 11:59
是因为切片索引有问题的导致的,我刚开始也遇到过,我的源文件有些东西不方便发,我再给你说明白一下

就把你的,略微调整了下变量名,及去掉文件名的相关,还是报这个错。

正常训练没问题,一进入高LOSS的运算,就报。

暂时屏蔽了读LOSS的代码,无关紧要。
不卡顿是为了,可以分辨,代码当掉了没。
回复

使用道具 举报

QQ|Archiver|手机版|deepfacelab中文网 |网站地图

GMT+8, 2024-10-25 03:28 , Processed in 0.093920 second(s), 12 queries , Redis On.

Powered by Discuz! X3.4

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表