deepfacelab中文网

 找回密码
 立即注册(仅限QQ邮箱)

【统计】搞了个高loss重复次数的选项

[复制链接]

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-16 12:49:49 | 显示全部楼层
wtxx8888 发表于 2024-9-16 12:47
对了。那个GAN,该在重训练的前面。不然有BUG

你是说gan和真脸的训练数据会受到重训代码块的影响是吧
回复

使用道具 举报

14

主题

2939

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15926

真我风采勋章万事如意节日勋章

发表于 2024-9-16 12:52:11 | 显示全部楼层
奇奇怪怪的ID 发表于 2024-9-16 12:49
你是说gan和真脸的训练数据会受到重训代码块的影响是吧

嗯,重训练有个新赋值的动作。那几个变量会变得,不跟W那俩是一套了。
回复

使用道具 举报

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-16 12:53:25 | 显示全部楼层
wtxx8888 发表于 2024-9-16 12:52
嗯,重训练有个新赋值的动作。那几个变量会变得,不跟W那俩是一套了。

明白,晚上再说,再不睡要猝死
回复

使用道具 举报

14

主题

2939

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15926

真我风采勋章万事如意节日勋章

发表于 2024-9-16 12:53:49 | 显示全部楼层
奇奇怪怪的ID 发表于 2024-9-16 12:53
明白,晚上再说,再不睡要猝死

睡吧
回复

使用道具 举报

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-16 23:00:05 | 显示全部楼层
本帖最后由 奇奇怪怪的ID 于 2024-9-16 23:06 编辑

我在犹豫,重复训练批次的loss均值是不是每次都添加到模型历史里,比如重训50次,只添加最后一次的数据到损失历史记录列表,目前运行还没出现什么问题


image.png
回复

使用道具 举报

14

主题

2939

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15926

真我风采勋章万事如意节日勋章

发表于 2024-9-16 23:08:24 | 显示全部楼层
本帖最后由 wtxx8888 于 2024-9-16 23:10 编辑
奇奇怪怪的ID 发表于 2024-9-16 23:00
我在犹豫,重复训练批次的loss均值是不是每次都添加到模型历史里

能让这句在重训练里有效吗?
loss_history = model.get_loss_history()

这个是调LOSS值的,重训练里用,会报错。是Trainer.py里的句子。
解决它,重训练就跟正常训练一样了,不是表现为卡顿。
我现在只能显示除LOSS之外的,LOSS不动。
回复

使用道具 举报

14

主题

2939

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15926

真我风采勋章万事如意节日勋章

发表于 2024-9-16 23:19:47 | 显示全部楼层
奇奇怪怪的ID 发表于 2024-9-16 23:00
我在犹豫,重复训练批次的loss均值是不是每次都添加到模型历史里,比如重训50次,只添加最后一次的数据到 ...

为了美观,倒是可以这样,只写一次LOSS记录。
回复

使用道具 举报

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-16 23:28:26 | 显示全部楼层
wtxx8888 发表于 2024-9-16 23:08
能让这句在重训练里有效吗?
loss_history = model.get_loss_history()


都是链式调用


改肯定能改,我现在没那么迫切的想去改这个,毕竟只是看似卡顿而已,不影响效果,只是观感不佳


我2点闲下来看看



回复

使用道具 举报

17

主题

148

帖子

1798

积分

初级丹圣

Rank: 8Rank: 8

积分
1798
 楼主| 发表于 2024-9-17 04:17:34 | 显示全部楼层
本帖最后由 奇奇怪怪的ID 于 2024-9-17 07:27 编辑
wtxx8888 发表于 2024-9-16 23:08
能让这句在重训练里有效吗?
loss_history = model.get_loss_history()

两个钟搞定,太恶心了,我把控制台更新的代码做成功能函数了,期间各种奇怪报错,


好在成功了,循环训练的迭代时间会比正常的更少,是因为计时跑的代码更短,主要是不用准备素材


基础还是太差了,没贴上来的就是没变动


    def repeated_training(self,high_iter_time):
        time_str = time.strftime("[%H:%M:%S]")
        loss_history = self.get_loss_history()
        save_iter = self.get_iter()
        self.iter += 1


        if high_iter_time >= 10:
            loss_string = "{0}[#{1:06d}][{2:.5s}s]".format ( time_str, self.iter, '{:0.4f}'.format(high_iter_time) )
        else:
            loss_string = "{0}[#{1:06d}][{2:04d}ms]".format ( time_str, self.iter, int(high_iter_time*1000) )


        mean_loss = np.mean ( loss_history[save_iter:self.iter], axis=0)
        # 计算从上次保存以来的平均损失mean_loss。


        for loss_value in mean_loss:
            loss_string += "[%.4f]" % (loss_value)
        # 将平均损失的每个值添加到损失字符串loss_string中。


        io.log_info(loss_string, end='\r')
        # 输出损失字符串loss_string。

------------------------------------------------------------------------------------------------------------------   
    def train_one_iter(self):
        

        # 执行单次训练迭代,并获取损失值列表
        losses,iter_time = self.onTrainOneIter()
        
---------------------------------------------------------------------------------------------------



        # 这里赋值的是正常、真脸、gan训练的迭代的结束时间,如果不进入循环训练则直接返回给train_one_iter
        normal_retrain_end_time = time.time() - normal_retrain_end_time

        # 如果启用了重训练样本选项
        if self.options['retraining_samples']:
            bs = self.get_batch_size()  # 获取批量大小
            lossi = self.options['high_loss_power']  # 获取高损失幂次参数
            high_loss1 = self.options['num_retrain_cycles']

            # 记录最近一批次的样本及其损失
            for i in range(bs):
                self.last_src_samples_loss.append((target_src, target_srcm, target_srcm_em, src_loss, filenames[0]))
                self.last_dst_samples_loss.append((target_dst, target_dstm, target_dstm_em, dst_loss, filenames[1]))

            # 当记录的样本数量达到设定的阈值时
            if len(self.last_src_samples_loss) >= bs * lossi:
                # 根据损失对样本进行排序
                src_samples_loss = sorted(self.last_src_samples_loss, key=operator.itemgetter(3), reverse=True)
                dst_samples_loss = sorted(self.last_dst_samples_loss, key=operator.itemgetter(3), reverse=True)

                # 选择损失最高的一批样本用于重新训练
                target_src = np.stack([x[0] for x in src_samples_loss[:bs]])
                target_srcm = np.stack([x[1] for x in src_samples_loss[:bs]])
                target_srcm_em = np.stack([x[2] for x in src_samples_loss[:bs]])

                target_dst = np.stack([x[0] for x in dst_samples_loss[:bs]])
                target_dstm = np.stack([x[1] for x in dst_samples_loss[:bs]])
                target_dstm_em = np.stack([x[2] for x in dst_samples_loss[:bs]])

                # 获取高损失样本的文件名
                high_loss_filenames_src = [x[4] for x in src_samples_loss[:bs]]
                high_loss_filenames_dst = [x[4] for x in dst_samples_loss[:bs]]

                for i in range(high_loss1):
                    high_time = time.time()
                    src_loss, dst_loss = self.src_dst_train(target_src, target_src, target_srcm, target_srcm_em, target_dst, target_dst, target_dstm, target_dstm_em)
                    if i != high_loss1-1:
                    # 更新 loss_history
                        losses = ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
                        self.loss_history.append([float(loss[1]) for loss in losses])

                        high_time = time.time() - high_time

                        self.repeated_training(high_time)

                # 这里赋值的是循环训练的最后一次迭代的结束时间,返回给train_one_iter
                normal_retrain_end_time = time.time() - high_time


                # 清空记录的样本损失列表
                self.last_src_samples_loss = []
                self.last_dst_samples_loss = []

           # 返回平均源损失和平均目标损失及正常迭代时间 or 复训的最后一次迭代时间
        return (('src_loss', np.mean(src_loss)), ('dst_loss', np.mean(dst_loss)), ),normal_retrain_end_time


回复

使用道具 举报

14

主题

2939

帖子

1万

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
15926

真我风采勋章万事如意节日勋章

发表于 2024-9-17 11:14:04 | 显示全部楼层
奇奇怪怪的ID 发表于 2024-9-17 04:17
两个钟搞定,太恶心了,我把控制台更新的代码做成功能函数了,期间各种奇怪报错,

我研究研究吧,暂时没成功
回复

使用道具 举报

QQ|Archiver|手机版|deepfacelab中文网 |网站地图

GMT+8, 2024-11-21 21:39 , Processed in 0.137770 second(s), 33 queries .

Powered by Discuz! X3.4

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表