deepfacelab中文网

 找回密码
 立即注册(仅限QQ邮箱)
查看: 231|回复: 12

关于提取脸部速度问题的解决

[复制链接]

19

主题

151

帖子

964

积分

高级丹师

Rank: 5Rank: 5

积分
964
 楼主| 发表于 5 天前 | 显示全部楼层 |阅读模式
星级打分
  • 1
  • 2
  • 3
  • 4
  • 5
平均分:NAN  参与人数:0  我的评分:未评
我发个我自己修改的用的版本吧。


因为我的是python3.10版本以及相关库的环境,所以如果出错,你们自己解决。。

首先是FANExtractor.py

  1. import os
  2. import traceback
  3. from pathlib import Path

  4. import cv2
  5. import numpy as np
  6. from numpy import linalg as npla

  7. from facelib import FaceType, LandmarksProcessor
  8. from core.leras import nn

  9. """
  10. ported from https://github.com/1adrianb/face-alignment
  11. """


  12. class FANExtractor(object):
  13.     def __init__(self, landmarks_3D=False, place_model_on_cpu=False):
  14.         model_path = Path(__file__).parent / ("2DFAN.npy" if not landmarks_3D else "3DFAN.npy")
  15.         if not model_path.exists():
  16.             raise Exception("Unable to load FANExtractor model")

  17.         nn.initialize(data_format="NHWC")
  18.         tf = nn.tf

  19.         class ConvBlock(nn.ModelBase):
  20.             def on_build(self, in_planes, out_planes):
  21.                 self.in_planes = in_planes
  22.                 self.out_planes = out_planes

  23.                 self.bn1 = nn.BatchNorm2D(in_planes)
  24.                 self.conv1 = nn.Conv2D(in_planes, out_planes // 2, kernel_size=3, strides=1, padding='SAME', use_bias=False)

  25.                 self.bn2 = nn.BatchNorm2D(out_planes // 2)
  26.                 self.conv2 = nn.Conv2D(out_planes // 2, out_planes // 4, kernel_size=3, strides=1, padding='SAME', use_bias=False)

  27.                 self.bn3 = nn.BatchNorm2D(out_planes // 4)
  28.                 self.conv3 = nn.Conv2D(out_planes // 4, out_planes // 4, kernel_size=3, strides=1, padding='SAME', use_bias=False)

  29.                 if self.in_planes != self.out_planes:
  30.                     self.down_bn1 = nn.BatchNorm2D(in_planes)
  31.                     self.down_conv1 = nn.Conv2D(in_planes, out_planes, kernel_size=1, strides=1, padding='VALID', use_bias=False)
  32.                 else:
  33.                     self.down_bn1 = None
  34.                     self.down_conv1 = None

  35.             def forward(self, input):
  36.                 x = input
  37.                 x = self.bn1(x)
  38.                 x = tf.nn.relu(x)
  39.                 x = out1 = self.conv1(x)

  40.                 x = self.bn2(x)
  41.                 x = tf.nn.relu(x)
  42.                 x = out2 = self.conv2(x)

  43.                 x = self.bn3(x)
  44.                 x = tf.nn.relu(x)
  45.                 x = out3 = self.conv3(x)

  46.                 x = tf.concat([out1, out2, out3], axis=-1)

  47.                 if self.in_planes != self.out_planes:
  48.                     downsample = self.down_bn1(input)
  49.                     downsample = tf.nn.relu(downsample)
  50.                     downsample = self.down_conv1(downsample)
  51.                     x = x + downsample
  52.                 else:
  53.                     x = x + input

  54.                 return x

  55.         class HourGlass(nn.ModelBase):
  56.             def on_build(self, in_planes, depth):
  57.                 self.b1 = ConvBlock(in_planes, 256)
  58.                 self.b2 = ConvBlock(in_planes, 256)

  59.                 if depth > 1:
  60.                     self.b2_plus = HourGlass(256, depth - 1)
  61.                 else:
  62.                     self.b2_plus = ConvBlock(256, 256)

  63.                 self.b3 = ConvBlock(256, 256)

  64.             def forward(self, input):
  65.                 up1 = self.b1(input)

  66.                 low1 = tf.nn.avg_pool(input, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
  67.                 low1 = self.b2(low1)

  68.                 low2 = self.b2_plus(low1)
  69.                 low3 = self.b3(low2)

  70.                 up2 = nn.upsample2d(low3)

  71.                 return up1 + up2

  72.         class FAN(nn.ModelBase):
  73.             def __init__(self):
  74.                 super().__init__(name='FAN')

  75.             def on_build(self):
  76.                 self.conv1 = nn.Conv2D(3, 64, kernel_size=7, strides=2, padding='SAME')
  77.                 self.bn1 = nn.BatchNorm2D(64)

  78.                 self.conv2 = ConvBlock(64, 128)
  79.                 self.conv3 = ConvBlock(128, 128)
  80.                 self.conv4 = ConvBlock(128, 256)

  81.                 self.m = []
  82.                 self.top_m = []
  83.                 self.conv_last = []
  84.                 self.bn_end = []
  85.                 self.l = []
  86.                 self.bl = []
  87.                 self.al = []
  88.                 for i in range(4):
  89.                     self.m += [HourGlass(256, 4)]
  90.                     self.top_m += [ConvBlock(256, 256)]

  91.                     self.conv_last += [nn.Conv2D(256, 256, kernel_size=1, strides=1, padding='VALID')]
  92.                     self.bn_end += [nn.BatchNorm2D(256)]

  93.                     self.l += [nn.Conv2D(256, 68, kernel_size=1, strides=1, padding='VALID')]

  94.                     if i < 4 - 1:
  95.                         self.bl += [nn.Conv2D(256, 256, kernel_size=1, strides=1, padding='VALID')]
  96.                         self.al += [nn.Conv2D(68, 256, kernel_size=1, strides=1, padding='VALID')]

  97.             def forward(self, inp):
  98.                 x, = inp
  99.                 x = self.conv1(x)
  100.                 x = self.bn1(x)
  101.                 x = tf.nn.relu(x)

  102.                 x = self.conv2(x)
  103.                 x = tf.nn.avg_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], 'VALID')
  104.                 x = self.conv3(x)
  105.                 x = self.conv4(x)

  106.                 outputs = []
  107.                 previous = x
  108.                 for i in range(4):
  109.                     ll = self.m[i](previous)
  110.                     ll = self.top_m[i](ll)
  111.                     ll = self.conv_last[i](ll)
  112.                     ll = self.bn_end[i](ll)
  113.                     ll = tf.nn.relu(ll)
  114.                     tmp_out = self.l[i](ll)
  115.                     outputs.append(tmp_out)
  116.                     if i < 4 - 1:
  117.                         ll = self.bl[i](ll)
  118.                         previous = previous + ll + self.al[i](tmp_out)
  119.                 x = outputs[-1]
  120.                 x = tf.transpose(x, (0, 3, 1, 2))
  121.                 return x

  122.         e = None
  123.         if place_model_on_cpu:
  124.             e = tf.device("/CPU:0")

  125.         if e is not None:
  126.             e.__enter__()
  127.         try:
  128.             self.model = FAN()
  129.             self.model.load_weights(str(model_path))
  130.         finally:
  131.             if e is not None:
  132.                 e.__exit__(None, None, None)

  133.         self.model.build_for_run([(tf.float32, (None, 256, 256, 3))])

  134.         self.gpu_memory_limit = self._get_gpu_memory_limit()
  135.         self.current_memory_usage = 0

  136.     def _get_gpu_memory_limit(self):
  137.         import tensorflow as tf
  138.         try:
  139.             gpus = tf.config.experimental.list_physical_devices('GPU')
  140.             if gpus:
  141.                 details = tf.config.experimental.get_device_details(gpus[0])
  142.                 if 'device_limitations' in details and 'total_memory' in details['device_limitations']:
  143.                     return details['device_limitations']['total_memory'] // (1024 ** 2)
  144.                 else:
  145.                     return 8192
  146.         except:
  147.             return 8192
  148.         return 8192

  149.     def _estimate_memory_usage(self, batch_size, avg_image_size):
  150.         memory_per_image = (avg_image_size * 3 * 4) / (1024 ** 2)  # 图像数据 (float32)
  151.         model_overhead = 150  # FAN模型激活值和中间层的更大开销
  152.         return batch_size * (memory_per_image + model_overhead)

  153.     def _get_current_gpu_memory_usage(self):
  154.         try:
  155.             import tensorflow as tf
  156.             gpus = tf.config.experimental.list_physical_devices('GPU')
  157.             if gpus:
  158.                 try:
  159.                     import pynvml
  160.                     pynvml.nvmlInit()
  161.                     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
  162.                     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
  163.                     return info.used // (1024 ** 2)  # 转换为MB
  164.                 except ImportError:
  165.                     return self.current_memory_usage
  166.             return 0
  167.         except:
  168.             return self.current_memory_usage

  169.     def extract_batch(self, input_images, all_rects, second_pass_extractor=None,
  170.                       is_bgr=True, multi_sample=False, initial_batch_size=4, memory_safety_margin=0.7):
  171.         if not input_images or not all_rects or len(input_images) != len(all_rects):
  172.             return [[] for _ in range(len(input_images))]

  173.         all_results = []

  174.         avg_image_size = np.mean([img.shape[0] * img.shape[1] for img in input_images])

  175.         current_batch_size = initial_batch_size
  176.         max_retries = 3
  177.         retry_count = 0

  178.         i = 0
  179.         while i < len(input_images):
  180.             current_batch_size = self._adjust_batch_size(
  181.                 current_batch_size, avg_image_size, memory_safety_margin
  182.             )

  183.             end_idx = min(i + current_batch_size, len(input_images))
  184.             batch_images = input_images[i:end_idx]
  185.             batch_rects = all_rects[i:end_idx]

  186.             batch_landmarks = []
  187.             success = True

  188.             try:
  189.                 for img, rects in zip(batch_images, batch_rects):
  190.                     landmarks = self._extract_single_optimized(img, rects, second_pass_extractor, is_bgr, multi_sample)
  191.                     batch_landmarks.append(landmarks)

  192.                     self.current_memory_usage += 100  # 估算每次处理约100MB

  193.                 all_results.extend(batch_landmarks)
  194.                 i = end_idx
  195.                 retry_count = 0  # 重置重试计数

  196.             except Exception as e:
  197.                 if "out of memory" in str(e).lower() or "memory" in str(e).lower():
  198.                     if retry_count < max_retries:
  199.                         current_batch_size = max(1, current_batch_size // 2)
  200.                         retry_count += 1
  201.                         print(f"FANExtractor: 内存不足,减少批大小至 {current_batch_size}")
  202.                         continue
  203.                     else:
  204.                         print(f"FANExtractor: 无法处理图像,跳过。错误: {e}")
  205.                         for _ in range(len(batch_images)):
  206.                             all_results.append(None)
  207.                         i = end_idx
  208.                         retry_count = 0
  209.                 else:
  210.                     print(f"FANExtractor: 处理图像时发生错误: {e}")
  211.                     for _ in range(len(batch_images)):
  212.                         all_results.append(None)
  213.                     i = end_idx
  214.                     retry_count = 0

  215.         return all_results

  216.     def _adjust_batch_size(self, current_batch_size, avg_image_size, safety_margin):
  217.         if self.gpu_memory_limit == 0:  # CPU模式
  218.             return min(current_batch_size, 2)

  219.         current_usage = self._get_current_gpu_memory_usage()

  220.         available_memory = self.gpu_memory_limit * safety_margin - current_usage

  221.         required_memory = self._estimate_memory_usage(current_batch_size, avg_image_size)

  222.         if required_memory > available_memory and current_batch_size > 1:
  223.             adjusted_batch_size = int(current_batch_size * (available_memory / required_memory))
  224.             adjusted_batch_size = max(1, adjusted_batch_size)
  225.             return min(current_batch_size, adjusted_batch_size)

  226.         return current_batch_size

  227.     def _process_batch(self, input_images, all_rects, second_pass_extractor=None, is_bgr=True, multi_sample=False):
  228.         batch_results = []
  229.         for img, rects in zip(input_images, all_rects):
  230.             landmarks = self._extract_single_optimized(img, rects, second_pass_extractor, is_bgr, multi_sample)
  231.             batch_results.append(landmarks)
  232.         return batch_results

  233.     def _extract_single_optimized(self, input_image, rects, second_pass_extractor=None, is_bgr=True, multi_sample=False):
  234.         if len(rects) == 0:
  235.             return []

  236.         if is_bgr:
  237.             input_image = input_image[:, :, ::-1]
  238.             is_bgr = False

  239.         (h, w, ch) = input_image.shape

  240.         landmarks = []
  241.         for (left, top, right, bottom) in rects:
  242.             scale = (right - left + bottom - top) / 195.0

  243.             center = np.array([(left + right) / 2.0, (top + bottom) / 2.0])
  244.             centers = [center]

  245.             if multi_sample:
  246.                 centers += [center + [-1, -1],
  247.                             center + [1, -1],
  248.                             center + [1, 1],
  249.                             center + [-1, 1],
  250.                             ]

  251.             images = []
  252.             ptss = []

  253.             try:
  254.                 for c in centers:
  255.                     images += [self.crop(input_image, c, scale)]

  256.                 images = np.stack(images)
  257.                 images = images.astype(np.float32) / 255.0

  258.                 predicted = []
  259.                 for i in range(len(images)):
  260.                     pred_result = self.model.run([images[i][None, ...]])
  261.                     predicted.append(pred_result[0])

  262.                 for i, pred in enumerate(predicted):
  263.                     ptss += [self.get_pts_from_predict(pred, centers[i], scale)]
  264.                 pts_img = np.mean(np.array(ptss), 0)

  265.                 landmarks.append(pts_img)
  266.             except Exception as e:
  267.                 print(f"Error in FANExtractor.extract: {str(e)}")
  268.                 landmarks.append(None)

  269.         if second_pass_extractor is not None and any(lmrks is not None for lmrks in landmarks):
  270.             for i, lmrks in enumerate(landmarks):
  271.                 if lmrks is not None:
  272.                     try:
  273.                         image_to_face_mat = LandmarksProcessor.get_transform_mat(lmrks, 256, FaceType.FULL)
  274.                         face_image = cv2.warpAffine(input_image, image_to_face_mat, (256, 256), cv2.INTER_CUBIC)

  275.                         rects2 = second_pass_extractor.extract(face_image, is_bgr=is_bgr)
  276.                         if len(rects2) == 1:  # 只有在检测到1个面部时才进行二次提取
  277.                             lmrks2 = self.extract(face_image, [rects2[0]], is_bgr=is_bgr, multi_sample=True)[0]
  278.                             landmarks[i] = LandmarksProcessor.transform_points(lmrks2, image_to_face_mat, True)
  279.                     except Exception as e:
  280.                         print(f"Error in second pass extraction: {str(e)}")
  281.                         pass

  282.         return landmarks

  283.     def extract(self, input_image, rects, second_pass_extractor=None, is_bgr=True, multi_sample=False):
  284.         return self._extract_single_optimized(input_image, rects, second_pass_extractor, is_bgr, multi_sample)

  285.     def transform(self, point, center, scale, resolution):
  286.         pt = np.array([point[0], point[1], 1.0])
  287.         h = 200.0 * scale
  288.         m = np.eye(3)
  289.         m[0, 0] = resolution / h
  290.         m[1, 1] = resolution / h
  291.         m[0, 2] = resolution * (-center[0] / h + 0.5)
  292.         m[1, 2] = resolution * (-center[1] / h + 0.5)
  293.         m = np.linalg.inv(m)
  294.         return np.matmul(m, pt)[0:2]

  295.     def crop(self, image, center, scale, resolution=256.0):
  296.         ul = self.transform([1, 1], center, scale, resolution).astype(np.int32)
  297.         br = self.transform([resolution, resolution], center, scale, resolution).astype(np.int32)

  298.         ht, wd = image.shape[0], image.shape[1]
  299.         newX = np.array([max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
  300.         newY = np.array([max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
  301.         oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
  302.         oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)

  303.         if image.ndim > 2:
  304.             newDim = np.array([br[1] - ul[1], br[0] - ul[0], image.shape[2]], dtype=np.int32)
  305.             newImg = np.zeros(newDim, dtype=np.uint8)
  306.         else:
  307.             newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int32)
  308.             newImg = np.zeros(newDim, dtype=np.uint8)

  309.         newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]

  310.         newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)), interpolation=cv2.INTER_LINEAR)
  311.         return newImg

  312.     def get_pts_from_predict(self, a, center, scale):
  313.         a_ch, a_h, a_w = a.shape

  314.         b = a.reshape((a_ch, a_h * a_w))
  315.         c = b.argmax(1).reshape((a_ch, 1)).repeat(2, axis=1).astype(np.float32)
  316.         c[:, 0] %= a_w
  317.         c[:, 1] = np.floor(c[:, 1] / a_w)

  318.         for i in range(a_ch):
  319.             pX, pY = int(c[i, 0]), int(c[i, 1])
  320.             if 0 < pX < a_w - 1 and 0 < pY < a_h - 1:
  321.                 diff = np.array([a[i, pY, pX + 1] - a[i, pY, pX - 1],
  322.                                  a[i, pY + 1, pX] - a[i, pY - 1, pX]])
  323.                 c[i] += np.sign(diff) * 0.25

  324.         c += 0.5

  325.         return np.array([self.transform(c[i], center, scale, a_w) for i in range(a_ch)])
复制代码


第二个文件 S3FDExtractor.py

  1. import operator
  2. from pathlib import Path

  3. import cv2
  4. import numpy as np

  5. from core.leras import nn


  6. class S3FDExtractor(object):
  7.     def __init__(self, place_model_on_cpu=False):
  8.         nn.initialize(data_format="NHWC")
  9.         tf = nn.tf

  10.         model_path = Path(__file__).parent / "S3FD.npy"
  11.         if not model_path.exists():
  12.             raise Exception("Unable to load S3FD.npy")

  13.         class L2Norm(nn.LayerBase):
  14.             def __init__(self, n_channels, **kwargs):
  15.                 self.n_channels = n_channels
  16.                 super().__init__(**kwargs)

  17.             def build_weights(self):
  18.                 self.weight = tf.get_variable("weight", (1, 1, 1, self.n_channels), dtype=nn.floatx, initializer=tf.initializers.ones)

  19.             def get_weights(self):
  20.                 return [self.weight]

  21.             def __call__(self, inputs):
  22.                 x = inputs
  23.                 x = x / (tf.sqrt(tf.reduce_sum(tf.pow(x, 2), axis=-1, keepdims=True)) + 1e-10) * self.weight
  24.                 return x

  25.         class S3FD(nn.ModelBase):
  26.             def __init__(self):
  27.                 super().__init__(name='S3FD')

  28.             def on_build(self):
  29.                 self.minus = tf.constant([104, 117, 123], dtype=nn.floatx)
  30.                 self.conv1_1 = nn.Conv2D(3, 64, kernel_size=3, strides=1, padding='SAME')
  31.                 self.conv1_2 = nn.Conv2D(64, 64, kernel_size=3, strides=1, padding='SAME')

  32.                 self.conv2_1 = nn.Conv2D(64, 128, kernel_size=3, strides=1, padding='SAME')
  33.                 self.conv2_2 = nn.Conv2D(128, 128, kernel_size=3, strides=1, padding='SAME')

  34.                 self.conv3_1 = nn.Conv2D(128, 256, kernel_size=3, strides=1, padding='SAME')
  35.                 self.conv3_2 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME')
  36.                 self.conv3_3 = nn.Conv2D(256, 256, kernel_size=3, strides=1, padding='SAME')

  37.                 self.conv4_1 = nn.Conv2D(256, 512, kernel_size=3, strides=1, padding='SAME')
  38.                 self.conv4_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')
  39.                 self.conv4_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')

  40.                 self.conv5_1 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')
  41.                 self.conv5_2 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')
  42.                 self.conv5_3 = nn.Conv2D(512, 512, kernel_size=3, strides=1, padding='SAME')

  43.                 self.fc6 = nn.Conv2D(512, 1024, kernel_size=3, strides=1, padding=3)
  44.                 self.fc7 = nn.Conv2D(1024, 1024, kernel_size=1, strides=1, padding='SAME')

  45.                 self.conv6_1 = nn.Conv2D(1024, 256, kernel_size=1, strides=1, padding='SAME')
  46.                 self.conv6_2 = nn.Conv2D(256, 512, kernel_size=3, strides=2, padding='SAME')

  47.                 self.conv7_1 = nn.Conv2D(512, 128, kernel_size=1, strides=1, padding='SAME')
  48.                 self.conv7_2 = nn.Conv2D(128, 256, kernel_size=3, strides=2, padding='SAME')

  49.                 self.conv3_3_norm = L2Norm(256)
  50.                 self.conv4_3_norm = L2Norm(512)
  51.                 self.conv5_3_norm = L2Norm(512)

  52.                 self.conv3_3_norm_mbox_conf = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME')
  53.                 self.conv3_3_norm_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME')

  54.                 self.conv4_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME')
  55.                 self.conv4_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME')

  56.                 self.conv5_3_norm_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME')
  57.                 self.conv5_3_norm_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME')

  58.                 self.fc7_mbox_conf = nn.Conv2D(1024, 2, kernel_size=3, strides=1, padding='SAME')
  59.                 self.fc7_mbox_loc = nn.Conv2D(1024, 4, kernel_size=3, strides=1, padding='SAME')

  60.                 self.conv6_2_mbox_conf = nn.Conv2D(512, 2, kernel_size=3, strides=1, padding='SAME')
  61.                 self.conv6_2_mbox_loc = nn.Conv2D(512, 4, kernel_size=3, strides=1, padding='SAME')

  62.                 self.conv7_2_mbox_conf = nn.Conv2D(256, 2, kernel_size=3, strides=1, padding='SAME')
  63.                 self.conv7_2_mbox_loc = nn.Conv2D(256, 4, kernel_size=3, strides=1, padding='SAME')

  64.             def forward(self, inp):
  65.                 x, = inp
  66.                 x = x - self.minus
  67.                 x = tf.nn.relu(self.conv1_1(x))
  68.                 x = tf.nn.relu(self.conv1_2(x))
  69.                 x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")

  70.                 x = tf.nn.relu(self.conv2_1(x))
  71.                 x = tf.nn.relu(self.conv2_2(x))
  72.                 x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")

  73.                 x = tf.nn.relu(self.conv3_1(x))
  74.                 x = tf.nn.relu(self.conv3_2(x))
  75.                 x = tf.nn.relu(self.conv3_3(x))
  76.                 f3_3 = x
  77.                 x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")

  78.                 x = tf.nn.relu(self.conv4_1(x))
  79.                 x = tf.nn.relu(self.conv4_2(x))
  80.                 x = tf.nn.relu(self.conv4_3(x))
  81.                 f4_3 = x
  82.                 x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")

  83.                 x = tf.nn.relu(self.conv5_1(x))
  84.                 x = tf.nn.relu(self.conv5_2(x))
  85.                 x = tf.nn.relu(self.conv5_3(x))
  86.                 f5_3 = x
  87.                 x = tf.nn.max_pool(x, [1, 2, 2, 1], [1, 2, 2, 1], "VALID")

  88.                 x = tf.nn.relu(self.fc6(x))
  89.                 x = tf.nn.relu(self.fc7(x))
  90.                 ffc7 = x

  91.                 x = tf.nn.relu(self.conv6_1(x))
  92.                 x = tf.nn.relu(self.conv6_2(x))
  93.                 f6_2 = x

  94.                 x = tf.nn.relu(self.conv7_1(x))
  95.                 x = tf.nn.relu(self.conv7_2(x))
  96.                 f7_2 = x

  97.                 f3_3 = self.conv3_3_norm(f3_3)
  98.                 f4_3 = self.conv4_3_norm(f4_3)
  99.                 f5_3 = self.conv5_3_norm(f5_3)

  100.                 cls1 = self.conv3_3_norm_mbox_conf(f3_3)
  101.                 reg1 = self.conv3_3_norm_mbox_loc(f3_3)

  102.                 cls2 = tf.nn.softmax(self.conv4_3_norm_mbox_conf(f4_3))
  103.                 reg2 = self.conv4_3_norm_mbox_loc(f4_3)

  104.                 cls3 = tf.nn.softmax(self.conv5_3_norm_mbox_conf(f5_3))
  105.                 reg3 = self.conv5_3_norm_mbox_loc(f5_3)

  106.                 cls4 = tf.nn.softmax(self.fc7_mbox_conf(ffc7))
  107.                 reg4 = self.fc7_mbox_loc(ffc7)

  108.                 cls5 = tf.nn.softmax(self.conv6_2_mbox_conf(f6_2))
  109.                 reg5 = self.conv6_2_mbox_loc(f6_2)

  110.                 cls6 = tf.nn.softmax(self.conv7_2_mbox_conf(f7_2))
  111.                 reg6 = self.conv7_2_mbox_loc(f7_2)

  112.                 # max-out background label
  113.                 bmax = tf.maximum(tf.maximum(cls1[:, :, :, 0:1], cls1[:, :, :, 1:2]), cls1[:, :, :, 2:3])

  114.                 cls1 = tf.concat([bmax, cls1[:, :, :, 3:4]], axis=-1)
  115.                 cls1 = tf.nn.softmax(cls1)

  116.                 return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]

  117.         e = None
  118.         if place_model_on_cpu:
  119.             e = tf.device("/CPU:0")

  120.         if e is not None:
  121.             e.__enter__()
  122.         try:
  123.             self.model = S3FD()
  124.             self.model.load_weights(model_path)
  125.         finally:
  126.             if e is not None:
  127.                 e.__exit__(None, None, None)

  128.         self.model.build_for_run([(tf.float32, nn.get4Dshape(None, None, 3))])

  129.         self.gpu_memory_limit = self._get_gpu_memory_limit()
  130.         self.current_memory_usage = 0

  131.     def _get_gpu_memory_limit(self):
  132.         import tensorflow as tf
  133.         try:
  134.             gpus = tf.config.experimental.list_physical_devices('GPU')
  135.             if gpus:
  136.                 details = tf.config.experimental.get_device_details(gpus[0])
  137.                 if 'device_limitations' in details and 'total_memory' in details['device_limitations']:
  138.                     return details['device_limitations']['total_memory'] // (1024 ** 2)
  139.                 else:
  140.                     return 8192
  141.         except:
  142.             return 8192
  143.         return 8192

  144.     def _estimate_memory_usage(self, batch_size, avg_image_size):
  145.         memory_per_image = (avg_image_size * 3 * 4) / (1024 ** 2)  # 图像数据 (float32)
  146.         model_overhead = 50  # 模型激活值和中间层的开销
  147.         return batch_size * (memory_per_image + model_overhead)

  148.     def _get_current_gpu_memory_usage(self):
  149.         try:
  150.             import tensorflow as tf
  151.             gpus = tf.config.experimental.list_physical_devices('GPU')
  152.             if gpus:
  153.                 try:
  154.                     import pynvml
  155.                     pynvml.nvmlInit()
  156.                     handle = pynvml.nvmlDeviceGetHandleByIndex(0)
  157.                     info = pynvml.nvmlDeviceGetMemoryInfo(handle)
  158.                     return info.used // (1024 ** 2)  # 转换为MB
  159.                 except ImportError:
  160.                     return self.current_memory_usage
  161.             return 0
  162.         except:
  163.             return self.current_memory_usage

  164.     def extract_batch(self, input_images, is_bgr=True, is_remove_intersects=False,
  165.                       initial_batch_size=8, memory_safety_margin=0.8):
  166.         if not input_images:
  167.             return [[] for _ in range(len(input_images))]

  168.         all_results = []

  169.         avg_image_size = np.mean([img.shape[0] * img.shape[1] for img in input_images])

  170.         current_batch_size = initial_batch_size
  171.         max_retries = 3
  172.         retry_count = 0

  173.         i = 0
  174.         while i < len(input_images):
  175.             current_batch_size = self._adjust_batch_size(
  176.                 current_batch_size, avg_image_size, memory_safety_margin
  177.             )

  178.             end_idx = min(i + current_batch_size, len(input_images))
  179.             batch = input_images[i:end_idx]

  180.             try:
  181.                 processed_batch = []
  182.                 scales = []
  183.                 for img in batch:
  184.                     if is_bgr:
  185.                         img = img[:, :, ::-1]

  186.                     h, w, ch = img.shape
  187.                     d = max(w, h)
  188.                     scale_to = 640 if d >= 1280 else d / 2
  189.                     scale_to = max(64, scale_to)
  190.                     input_scale = d / scale_to

  191.                     resized_img = cv2.resize(img, (int(w / input_scale), int(h / input_scale)),
  192.                                              interpolation=cv2.INTER_LINEAR)
  193.                     processed_batch.append(resized_img)
  194.                     scales.append(input_scale)

  195.                 batch_tensor = np.stack(processed_batch).astype(np.float32)

  196.                 self.current_memory_usage += len(batch) * 50  # 估算每张图约50MB

  197.                 olist_batch = self.model.run([batch_tensor])

  198.                 for idx in range(len(batch)):
  199.                     single_olist = [tensor[idx:idx + 1] for tensor in olist_batch]
  200.                     detected_faces = []
  201.                     for ltrb in self.refine(single_olist):
  202.                         l, t, r, b = [x * scales[idx] for x in ltrb]
  203.                         bt = b - t
  204.                         if min(r - l, bt) < 40:
  205.                             continue
  206.                         b += bt * 0.1
  207.                         detected_faces.append([int(x) for x in (l, t, r, b)])

  208.                     detected_faces = [[(l, t, r, b), (r - l) * (b - t)] for (l, t, r, b) in detected_faces]
  209.                     detected_faces = sorted(detected_faces, key=operator.itemgetter(1), reverse=True)
  210.                     detected_faces = [x[0] for x in detected_faces]

  211.                     if is_remove_intersects:
  212.                         detected_faces = self._remove_intersecting_faces_optimized(detected_faces)

  213.                     all_results.append(detected_faces)

  214.                 i = end_idx
  215.                 retry_count = 0  # 重置重试计数

  216.             except Exception as e:
  217.                 if "out of memory" in str(e).lower() or "memory" in str(e).lower():
  218.                     if retry_count < max_retries:
  219.                         current_batch_size = max(1, current_batch_size // 2)
  220.                         retry_count += 1
  221.                         print(f"S3FDExtractor: 内存不足,减少批大小至 {current_batch_size}")
  222.                         continue
  223.                     else:
  224.                         print(f"S3FDExtractor: 无法处理图像,跳过。错误: {e}")
  225.                         for _ in range(len(batch)):
  226.                             all_results.append([])
  227.                         i = end_idx
  228.                         retry_count = 0
  229.                 else:
  230.                     print(f"S3FDExtractor: 处理图像时发生错误: {e}")
  231.                     for _ in range(len(batch)):
  232.                         all_results.append([])
  233.                     i = end_idx
  234.                     retry_count = 0

  235.         return all_results

  236.     def _adjust_batch_size(self, current_batch_size, avg_image_size, safety_margin):
  237.         if self.gpu_memory_limit == 0:  # CPU模式
  238.             return min(current_batch_size, 4)

  239.         current_usage = self._get_current_gpu_memory_usage()

  240.         available_memory = self.gpu_memory_limit * safety_margin - current_usage

  241.         required_memory = self._estimate_memory_usage(current_batch_size, avg_image_size)

  242.         if required_memory > available_memory and current_batch_size > 1:
  243.             adjusted_batch_size = int(current_batch_size * (available_memory / required_memory))
  244.             adjusted_batch_size = max(1, adjusted_batch_size)
  245.             return min(current_batch_size, adjusted_batch_size)

  246.         return current_batch_size

  247.     def _process_batch(self, input_images, is_bgr=True, is_remove_intersects=False):
  248.         batch_results = []
  249.         for img in input_images:
  250.             result = self.extract_single(img, is_bgr, is_remove_intersects)
  251.             batch_results.append(result)
  252.         return batch_results

  253.     def extract_single(self, input_image, is_bgr=True, is_remove_intersects=False):
  254.         if is_bgr:
  255.             input_image = input_image[:, :, ::-1]
  256.             is_bgr = False

  257.         (h, w, ch) = input_image.shape

  258.         d = max(w, h)
  259.         scale_to = 640 if d >= 1280 else d / 2
  260.         scale_to = max(64, scale_to)

  261.         input_scale = d / scale_to
  262.         input_image = cv2.resize(input_image, (int(w / input_scale), int(h / input_scale)), interpolation=cv2.INTER_LINEAR)

  263.         olist = self.model.run([input_image[None, ...]])

  264.         detected_faces = []
  265.         for ltrb in self.refine(olist):
  266.             l, t, r, b = [x * input_scale for x in ltrb]
  267.             bt = b - t
  268.             if min(r - l, bt) < 40:  # filtering faces < 40pix by any side
  269.                 continue
  270.             b += bt * 0.1  # enlarging bottom line a bit for 2DFAN-4
  271.             detected_faces.append([int(x) for x in (l, t, r, b)])

  272.         # sort by largest area first
  273.         detected_faces = [[(l, t, r, b), (r - l) * (b - t)] for (l, t, r, b) in detected_faces]
  274.         detected_faces = sorted(detected_faces, key=operator.itemgetter(1), reverse=True)
  275.         detected_faces = [x[0] for x in detected_faces]

  276.         if is_remove_intersects:
  277.             detected_faces = self._remove_intersecting_faces_optimized(detected_faces)

  278.         return detected_faces

  279.     def _remove_intersecting_faces_optimized(self, detected_faces):
  280.         if len(detected_faces) <= 1:
  281.             return detected_faces

  282.         bboxes = np.array(detected_faces)
  283.         areas = (bboxes[:, 2] - bboxes[:, 0]) * (bboxes[:, 3] - bboxes[:, 1])

  284.         x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]

  285.         xx1 = np.maximum(x1[:, None], x1[None, :])
  286.         yy1 = np.maximum(y1[:, None], y1[None, :])
  287.         xx2 = np.minimum(x2[:, None], x2[None, :])
  288.         yy2 = np.minimum(y2[:, None], y2[None, :])

  289.         w = np.maximum(0, xx2 - xx1)
  290.         h = np.maximum(0, yy2 - yy1)
  291.         inter = w * h

  292.         union = areas[:, None] + areas[None, :] - inter
  293.         iou = inter / union

  294.         to_remove = set()
  295.         for i in range(len(bboxes)):
  296.             if i in to_remove:
  297.                 continue
  298.             overlap_indices = np.where((iou[i, :] > 0.3) & (np.arange(len(bboxes)) != i))[0]
  299.             for j in overlap_indices:
  300.                 if j in to_remove:
  301.                     continue
  302.                 if areas[i] < areas[j]:
  303.                     to_remove.add(i)
  304.                     break
  305.                 else:
  306.                     to_remove.add(j)

  307.         for idx in sorted(to_remove, reverse=True):
  308.             if idx < len(detected_faces):
  309.                 detected_faces.pop(idx)

  310.         return detected_faces

  311.     def extract(self, input_image, is_bgr=True, is_remove_intersects=False):
  312.         return self.extract_single(input_image, is_bgr, is_remove_intersects)

  313.     def refine(self, olist):
  314.         bboxlist = []
  315.         for i, ((ocls,), (oreg,)) in enumerate(zip(olist[::2], olist[1::2])):
  316.             stride = 2 ** (i + 2)  # 4,8,16,32,64,128
  317.             s_d2 = stride / 2
  318.             s_m4 = stride * 4

  319.             scores = ocls[..., 1]
  320.             high_score_indices = np.where(scores > 0.05)

  321.             for hindex, windex in zip(*high_score_indices):
  322.                 score = scores[hindex, windex]
  323.                 loc = oreg[hindex, windex, :]
  324.                 priors = np.array([windex * stride + s_d2, hindex * stride + s_d2, s_m4, s_m4])
  325.                 priors_2p = priors[2:]
  326.                 box = np.concatenate((priors[:2] + loc[:2] * 0.1 * priors_2p,
  327.                                       priors_2p * np.exp(loc[2:] * 0.2)))
  328.                 box[:2] -= box[2:] / 2
  329.                 box[2:] += box[:2]

  330.                 bboxlist.append([*box, score])

  331.         bboxlist = np.array(bboxlist)
  332.         if len(bboxlist) == 0:
  333.             bboxlist = np.zeros((1, 5))

  334.         bboxlist = bboxlist[self.refine_nms(bboxlist, 0.3), :]
  335.         bboxlist = [x[:-1].astype(np.int32) for x in bboxlist if x[-1] >= 0.5]
  336.         return bboxlist

  337.     def refine_nms(self, dets, thresh):
  338.         if len(dets) == 0:
  339.             return []

  340.         dets = np.asarray(dets)
  341.         x1 = dets[:, 0]
  342.         y1 = dets[:, 1]
  343.         x2 = dets[:, 2]
  344.         y2 = dets[:, 3]
  345.         scores = dets[:, 4]

  346.         areas = (x2 - x1 + 1) * (y2 - y1 + 1)
  347.         order = scores.argsort()[::-1]

  348.         keep = []
  349.         while order.size > 0:
  350.             i = order[0]
  351.             keep.append(i)
  352.             xx1 = np.maximum(x1[i], x1[order[1:]])
  353.             yy1 = np.maximum(y1[i], y1[order[1:]])
  354.             xx2 = np.minimum(x2[i], x2[order[1:]])
  355.             yy2 = np.minimum(y2[i], y2[order[1:]])

  356.             w = np.maximum(0.0, xx2 - xx1 + 1)
  357.             h = np.maximum(0.0, yy2 - yy1 + 1)
  358.             inter = w * h
  359.             ovr = inter / (areas[i] + areas[order[1:]] - inter)

  360.             inds = np.where(ovr <= thresh)[0]
  361.             order = order[inds + 1]
  362.         return keep
复制代码


评分

参与人数 1贡献 +1 收起 理由
看花尽于秋 + 1 感谢楼主分享

查看全部评分

回复

使用道具 举报

19

主题

151

帖子

964

积分

高级丹师

Rank: 5Rank: 5

积分
964
 楼主| 发表于 5 天前 | 显示全部楼层
最后就是主调用文件Extractor.py

  1. import traceback
  2. import math
  3. import multiprocessing
  4. import operator
  5. import os
  6. import shutil
  7. import sys
  8. import time
  9. from pathlib import Path

  10. import cv2
  11. import numpy as np
  12. from numpy import linalg as npla

  13. import facelib
  14. from core import imagelib
  15. from core import mathlib
  16. from facelib import FaceType, LandmarksProcessor
  17. from core.interact import interact as io
  18. from core.joblib import Subprocessor
  19. from core.leras import nn
  20. from core import pathex
  21. from core.cv2ex import *
  22. from DFLIMG import *

  23. DEBUG = False

  24. class ExtractSubprocessor(Subprocessor):
  25.     class Data(object):
  26.         def __init__(self, filepath=None, rects=None, landmarks=None, landmarks_accurate=True, manual=False, force_output_path=None, final_output_files=None):
  27.             self.filepath = filepath
  28.             self.rects = rects or []
  29.             self.rects_rotation = 0
  30.             self.landmarks_accurate = landmarks_accurate
  31.             self.manual = manual
  32.             self.landmarks = landmarks or []
  33.             self.force_output_path = force_output_path
  34.             self.final_output_files = final_output_files or []
  35.             self.faces_detected = 0

  36.     class Cli(Subprocessor.Cli):

  37.         # override
  38.         def on_initialize(self, client_dict):
  39.             self.type = client_dict['type']
  40.             self.image_size = client_dict['image_size']
  41.             self.jpeg_quality = client_dict['jpeg_quality']
  42.             self.face_type = client_dict['face_type']
  43.             self.max_faces_from_image = client_dict['max_faces_from_image']
  44.             self.device_idx = client_dict['device_idx']
  45.             self.cpu_only = client_dict['device_type'] == 'CPU'
  46.             self.final_output_path = client_dict['final_output_path']
  47.             self.output_debug_path = client_dict['output_debug_path']

  48.             # transfer and set stdin in order to work code.interact in debug subprocess
  49.             stdin_fd = client_dict['stdin_fd']
  50.             if stdin_fd is not None and DEBUG:
  51.                 sys.stdin = os.fdopen(stdin_fd)

  52.             if self.cpu_only:
  53.                 device_config = nn.DeviceConfig.CPU()
  54.                 place_model_on_cpu = True
  55.             else:
  56.                 device_config = nn.DeviceConfig.GPUIndexes([self.device_idx])
  57.                 place_model_on_cpu = device_config.devices[0].total_mem_gb < 4

  58.             if self.type == 'all' or 'rects' in self.type or 'landmarks' in self.type:
  59.                 nn.initialize(device_config)

  60.             self.log_info(f"Running on {client_dict['device_name']}")

  61.             if self.type == 'all' or self.type == 'rects-s3fd' or 'landmarks' in self.type:
  62.                 self.rects_extractor = facelib.S3FDExtractor(place_model_on_cpu=place_model_on_cpu)

  63.             if self.type == 'all' or 'landmarks' in self.type:
  64.                 # for head type, extract "3D landmarks"
  65.                 self.landmarks_extractor = facelib.FANExtractor(landmarks_3D=self.face_type >= FaceType.HEAD,
  66.                                                                 place_model_on_cpu=place_model_on_cpu)

  67.             # 缓存图像,减少重复读取
  68.             self.cached_image = (None, None)

  69.         # override
  70.         def process_data(self, data):
  71.             if 'landmarks' in self.type and len(data.rects) == 0:
  72.                 return data

  73.             filepath = data.filepath
  74.             cached_filepath, image = self.cached_image
  75.             if cached_filepath != filepath:
  76.                 image = cv2_imread(filepath)
  77.                 if image is None:
  78.                     self.log_err(f'Failed to open {filepath}, reason: cv2_imread() fail.')
  79.                     return data
  80.                 image = imagelib.normalize_channels(image, 3)
  81.                 image = imagelib.cut_odd_image(image)
  82.                 self.cached_image = (filepath, image)

  83.             h, w, c = image.shape

  84.             if 'rects' in self.type or self.type == 'all':
  85.                 data = ExtractSubprocessor.Cli.rects_stage(data=data,
  86.                                                            image=image,
  87.                                                            max_faces_from_image=self.max_faces_from_image,
  88.                                                            rects_extractor=self.rects_extractor,
  89.                                                            )

  90.             if 'landmarks' in self.type or self.type == 'all':
  91.                 data = ExtractSubprocessor.Cli.landmarks_stage(data=data,
  92.                                                                image=image,
  93.                                                                landmarks_extractor=self.landmarks_extractor,
  94.                                                                rects_extractor=self.rects_extractor,
  95.                                                                )

  96.             if self.type == 'final' or self.type == 'all':
  97.                 data = ExtractSubprocessor.Cli.final_stage(data=data,
  98.                                                            image=image,
  99.                                                            face_type=self.face_type,
  100.                                                            image_size=self.image_size,
  101.                                                            jpeg_quality=self.jpeg_quality,
  102.                                                            output_debug_path=self.output_debug_path,
  103.                                                            final_output_path=self.final_output_path,
  104.                                                            )
  105.             return data

  106.         @staticmethod
  107.         def rects_stage(data,
  108.                         image,
  109.                         max_faces_from_image,
  110.                         rects_extractor,
  111.                         ):
  112.             """
  113.             优化的人脸矩形检测阶段
  114.             """
  115.             h, w, c = image.shape
  116.             if min(h, w) < 128:
  117.                 # Image is too small
  118.                 data.rects = []
  119.             else:
  120.                 # 使用旋转优化的检测
  121.                 rotation_angles = [0, 90, 270, 180]
  122.                 for rot in rotation_angles:
  123.                     if rot == 0:
  124.                         rotated_image = image
  125.                     elif rot == 90:
  126.                         rotated_image = image.swapaxes(0, 1)[:, ::-1, :]
  127.                     elif rot == 180:
  128.                         rotated_image = image[::-1, ::-1, :]
  129.                     elif rot == 270:
  130.                         rotated_image = image.swapaxes(0, 1)[::-1, :, :]

  131.                     rects = data.rects = rects_extractor.extract(rotated_image, is_bgr=True)
  132.                     if len(rects) != 0:
  133.                         data.rects_rotation = rot
  134.                         break

  135.                 if max_faces_from_image is not None and \
  136.                         max_faces_from_image > 0 and \
  137.                         len(data.rects) > 0:
  138.                     data.rects = data.rects[0:max_faces_from_image]
  139.             return data

  140.         @staticmethod
  141.         def landmarks_stage(data,
  142.                             image,
  143.                             landmarks_extractor,
  144.                             rects_extractor,
  145.                             ):
  146.             """
  147.             优化的关键点提取阶段
  148.             """
  149.             h, w, ch = image.shape

  150.             if data.rects_rotation == 0:
  151.                 rotated_image = image
  152.             elif data.rects_rotation == 90:
  153.                 rotated_image = image.swapaxes(0, 1)[:, ::-1, :]
  154.             elif data.rects_rotation == 180:
  155.                 rotated_image = image[::-1, ::-1, :]
  156.             elif data.rects_rotation == 270:
  157.                 rotated_image = image.swapaxes(0, 1)[::-1, :, :]

  158.             data.landmarks = landmarks_extractor.extract(rotated_image, data.rects,
  159.                                                          rects_extractor if (data.landmarks_accurate) else None,
  160.                                                          is_bgr=True)

  161.             # 优化旋转后的关键点调整
  162.             if data.rects_rotation != 0:
  163.                 for i, (rect, lmrks) in enumerate(zip(data.rects, data.landmarks)):
  164.                     new_rect, new_lmrks = rect, lmrks
  165.                     (l, t, r, b) = rect
  166.                     if data.rects_rotation == 90:
  167.                         new_rect = (t, h - l, b, h - r)
  168.                         if lmrks is not None:
  169.                             new_lmrks = lmrks[:, ::-1].copy()
  170.                             new_lmrks[:, 1] = h - new_lmrks[:, 1]
  171.                     elif data.rects_rotation == 180:
  172.                         if lmrks is not None:
  173.                             new_rect = (w - l, h - t, w - r, h - b)
  174.                             new_lmrks = lmrks.copy()
  175.                             new_lmrks[:, 0] = w - new_lmrks[:, 0]
  176.                             new_lmrks[:, 1] = h - new_lmrks[:, 1]
  177.                     elif data.rects_rotation == 270:
  178.                         new_rect = (w - b, l, w - t, r)
  179.                         if lmrks is not None:
  180.                             new_lmrks = lmrks[:, ::-1].copy()
  181.                             new_lmrks[:, 0] = w - new_lmrks[:, 0]
  182.                     data.rects[i], data.landmarks[i] = new_rect, new_lmrks

  183.             return data

  184.         @staticmethod
  185.         def final_stage(data,
  186.                         image,
  187.                         face_type,
  188.                         image_size,
  189.                         jpeg_quality,
  190.                         output_debug_path=None,
  191.                         final_output_path=None,
  192.                         ):
  193.             """
  194.             优化的最终处理阶段
  195.             """
  196.             data.final_output_files = []
  197.             filepath = data.filepath
  198.             rects = data.rects
  199.             landmarks = data.landmarks

  200.             if output_debug_path is not None:
  201.                 debug_image = image.copy()

  202.             face_idx = 0
  203.             for rect, image_landmarks in zip(rects, landmarks):
  204.                 if image_landmarks is None:
  205.                     continue

  206.                 rect = np.array(rect)

  207.                 if face_type == FaceType.MARK_ONLY:
  208.                     image_to_face_mat = None
  209.                     face_image = image
  210.                     face_image_landmarks = image_landmarks
  211.                 else:
  212.                     image_to_face_mat = LandmarksProcessor.get_transform_mat(image_landmarks, image_size, face_type)

  213.                     face_image = cv2.warpAffine(image, image_to_face_mat, (image_size, image_size), cv2.INTER_LANCZOS4)
  214.                     face_image_landmarks = LandmarksProcessor.transform_points(image_landmarks, image_to_face_mat)

  215.                     landmarks_bbox = LandmarksProcessor.transform_points([(0, 0), (0, image_size - 1), (image_size - 1, image_size - 1), (image_size - 1, 0)],
  216.                                                                          image_to_face_mat, True)

  217.                     rect_area = mathlib.polygon_area(np.array(rect[[0, 2, 2, 0]]).astype(np.float32), np.array(rect[[1, 1, 3, 3]]).astype(np.float32))
  218.                     landmarks_area = mathlib.polygon_area(landmarks_bbox[:, 0].astype(np.float32), landmarks_bbox[:, 1].astype(np.float32))

  219.                     if not data.manual and face_type <= FaceType.FULL_NO_ALIGN and landmarks_area > 4 * rect_area:  # get rid of faces which umeyama-landmark-area > 4*detector-rect-area
  220.                         continue

  221.                     if output_debug_path is not None:
  222.                         LandmarksProcessor.draw_rect_landmarks(debug_image, rect, image_landmarks, face_type, image_size, transparent_mask=True)

  223.                 output_path = final_output_path
  224.                 if data.force_output_path is not None:
  225.                     output_path = data.force_output_path

  226.                 output_filepath = output_path / f"{filepath.stem}_{face_idx}.jpg"
  227.                 cv2_imwrite(output_filepath, face_image, [int(cv2.IMWRITE_JPEG_QUALITY), jpeg_quality])

  228.                 dflimg = DFLJPG.load(output_filepath)
  229.                 dflimg.set_face_type(FaceType.toString(face_type))
  230.                 dflimg.set_landmarks(face_image_landmarks.tolist())
  231.                 dflimg.set_source_filename(filepath.name)
  232.                 dflimg.set_source_rect(rect)
  233.                 dflimg.set_source_landmarks(image_landmarks.tolist())
  234.                 dflimg.set_image_to_face_mat(image_to_face_mat)
  235.                 dflimg.save()

  236.                 data.final_output_files.append(output_filepath)
  237.                 face_idx += 1
  238.             data.faces_detected = face_idx

  239.             if output_debug_path is not None:
  240.                 cv2_imwrite(output_debug_path / (filepath.stem + '.jpg'), debug_image, [int(cv2.IMWRITE_JPEG_QUALITY), 50])

  241.             return data

  242.         # overridable
  243.         def get_data_name(self, data):
  244.             # return string identificator of your data
  245.             return data.filepath

  246.     @staticmethod
  247.     def get_devices_for_config(type, device_config):
  248.         devices = device_config.devices
  249.         cpu_only = len(devices) == 0

  250.         if 'rects' in type or \
  251.                 'landmarks' in type or \
  252.                 'all' in type:

  253.             if not cpu_only:
  254.                 if type == 'landmarks-manual':
  255.                     devices = [devices.get_best_device()]

  256.                 result = []

  257.                 for device in devices:
  258.                     count = 1

  259.                     if count == 1:
  260.                         result += [(device.index, 'GPU', device.name, device.total_mem_gb)]
  261.                     else:
  262.                         for i in range(count):
  263.                             result += [(device.index, 'GPU', f"{device.name} #{i}", device.total_mem_gb)]

  264.                 return result
  265.             else:
  266.                 if type == 'landmarks-manual':
  267.                     return [(0, 'CPU', 'CPU', 0)]
  268.                 else:
  269.                     return [(i, 'CPU', 'CPU%d' % (i), 0) for i in range(min(8, multiprocessing.cpu_count() // 2))]

  270.         elif type == 'final':
  271.             return [(i, 'CPU', 'CPU%d' % (i), 0) for i in (range(min(8, multiprocessing.cpu_count())) if not DEBUG else [0])]

  272.     def __init__(self, input_data, type, image_size=None, jpeg_quality=None, face_type=None, output_debug_path=None, manual_window_size=0, max_faces_from_image=0,
  273.                  final_output_path=None, device_config=None):
  274.         if type == 'landmarks-manual':
  275.             for x in input_data:
  276.                 x.manual = True

  277.         self.input_data = input_data

  278.         self.type = type
  279.         self.image_size = image_size
  280.         self.jpeg_quality = jpeg_quality
  281.         self.face_type = face_type
  282.         self.output_debug_path = output_debug_path
  283.         self.final_output_path = final_output_path
  284.         self.manual_window_size = manual_window_size
  285.         self.max_faces_from_image = max_faces_from_image
  286.         self.result = []

  287.         self.devices = ExtractSubprocessor.get_devices_for_config(self.type, device_config)

  288.         super().__init__('Extractor', ExtractSubprocessor.Cli,
  289.                          999999 if type == 'landmarks-manual' or DEBUG else 120)

  290.     # override
  291.     def on_clients_initialized(self):
  292.         if self.type == 'landmarks-manual':
  293.             self.wnd_name = 'Manual pass'
  294.             io.named_window(self.wnd_name)
  295.             io.capture_mouse(self.wnd_name)
  296.             io.capture_keys(self.wnd_name)

  297.             self.cache_original_image = (None, None)
  298.             self.cache_image = (None, None)
  299.             self.cache_text_lines_img = (None, None)
  300.             self.hide_help = False
  301.             self.landmarks_accurate = True
  302.             self.force_landmarks = False

  303.             self.landmarks = None
  304.             self.x = 0
  305.             self.y = 0
  306.             self.rect_size = 100
  307.             self.rect_locked = False
  308.             self.extract_needed = True

  309.             self.image = None
  310.             self.image_filepath = None

  311.         io.progress_bar(None, len(self.input_data))

  312.     # override
  313.     def on_clients_finalized(self):
  314.         if self.type == 'landmarks-manual':
  315.             io.destroy_all_windows()

  316.         io.progress_bar_close()

  317.     # override
  318.     def process_info_generator(self):
  319.         base_dict = {'type': self.type,
  320.                      'image_size': self.image_size,
  321.                      'jpeg_quality': self.jpeg_quality,
  322.                      'face_type': self.face_type,
  323.                      'max_faces_from_image': self.max_faces_from_image,
  324.                      'output_debug_path': self.output_debug_path,
  325.                      'final_output_path': self.final_output_path,
  326.                      'stdin_fd': sys.stdin.fileno()}

  327.         for (device_idx, device_type, device_name, device_total_vram_gb) in self.devices:
  328.             client_dict = base_dict.copy()
  329.             client_dict['device_idx'] = device_idx
  330.             client_dict['device_name'] = device_name
  331.             client_dict['device_type'] = device_type
  332.             yield client_dict['device_name'], {}, client_dict

  333.     # override
  334.     def get_data(self, host_dict):
  335.         if self.type == 'landmarks-manual':
  336.             need_remark_face = False
  337.             while len(self.input_data) > 0:
  338.                 data = self.input_data[0]
  339.                 filepath, data_rects, data_landmarks = data.filepath, data.rects, data.landmarks
  340.                 is_frame_done = False

  341.                 if self.image_filepath != filepath:
  342.                     self.image_filepath = filepath
  343.                     if self.cache_original_image[0] == filepath:
  344.                         self.original_image = self.cache_original_image[1]
  345.                     else:
  346.                         self.original_image = imagelib.normalize_channels(cv2_imread(filepath), 3)

  347.                         self.cache_original_image = (filepath, self.original_image)

  348.                     (h, w, c) = self.original_image.shape
  349.                     self.view_scale = 1.0 if self.manual_window_size == 0 else self.manual_window_size / (h * (16.0 / 9.0))

  350.                     if self.cache_image[0] == (h, w, c) + (self.view_scale, filepath):
  351.                         self.image = self.cache_image[1]
  352.                     else:
  353.                         self.image = cv2.resize(self.original_image, (int(w * self.view_scale), int(h * self.view_scale)), interpolation=cv2.INTER_LINEAR)
  354.                         self.cache_image = ((h, w, c) + (self.view_scale, filepath), self.image)

  355.                     (h, w, c) = self.image.shape

  356.                     sh = (0, 0, w, min(100, h))
  357.                     if self.cache_text_lines_img[0] == sh:
  358.                         self.text_lines_img = self.cache_text_lines_img[1]
  359.                     else:
  360.                         self.text_lines_img = (imagelib.get_draw_text_lines(self.image, sh,
  361.                                                                             ['[L Mouse click] - lock/unlock selection. [Mouse wheel] - change rect',
  362.                                                                              '[R Mouse Click] - manual face rectangle',
  363.                                                                              '[Enter] / [Space] - confirm / skip frame',
  364.                                                                              '[,] [.]- prev frame, next frame. [Q] - skip remaining frames',
  365.                                                                              '[a] - accuracy on/off (more fps)',
  366.                                                                              '[h] - hide this help'
  367.                                                                              ], (1, 1, 1)) * 255).astype(np.uint8)

  368.                         self.cache_text_lines_img = (sh, self.text_lines_img)

  369.                 if need_remark_face:  # need remark image from input data that already has a marked face?
  370.                     need_remark_face = False
  371.                     if len(data_rects) != 0:  # If there was already a face then lock the rectangle to it until the mouse is clicked
  372.                         self.rect = data_rects.pop()
  373.                         self.landmarks = data_landmarks.pop()
  374.                         data_rects.clear()
  375.                         data_landmarks.clear()

  376.                         self.rect_locked = True
  377.                         self.rect_size = (self.rect[2] - self.rect[0]) / 2
  378.                         self.x = (self.rect[0] + self.rect[2]) / 2
  379.                         self.y = (self.rect[1] + self.rect[3]) / 2
  380.                         self.redraw()

  381.                 if len(data_rects) == 0:
  382.                     (h, w, c) = self.image.shape
  383.                     while True:
  384.                         io.process_messages(0.0001)

  385.                         if not self.force_landmarks:
  386.                             new_x = self.x
  387.                             new_y = self.y

  388.                         new_rect_size = self.rect_size

  389.                         mouse_events = io.get_mouse_events(self.wnd_name)
  390.                         for ev in mouse_events:
  391.                             (x, y, ev, flags) = ev
  392.                             if ev == io.EVENT_MOUSEWHEEL and not self.rect_locked:
  393.                                 mod = 1 if flags > 0 else -1
  394.                                 diff = 1 if new_rect_size <= 40 else np.clip(new_rect_size / 10, 1, 10)
  395.                                 new_rect_size = max(5, new_rect_size + diff * mod)
  396.                             elif ev == io.EVENT_LBUTTONDOWN:
  397.                                 if self.force_landmarks:
  398.                                     self.x = new_x
  399.                                     self.y = new_y
  400.                                     self.force_landmarks = False
  401.                                     self.rect_locked = True
  402.                                     self.redraw()
  403.                                 else:
  404.                                     self.rect_locked = not self.rect_locked
  405.                                     self.extract_needed = True
  406.                             elif ev == io.EVENT_RBUTTONDOWN:
  407.                                 self.force_landmarks = not self.force_landmarks
  408.                                 if self.force_landmarks:
  409.                                     self.rect_locked = False
  410.                             elif not self.rect_locked:
  411.                                 new_x = np.clip(x, 0, w - 1) / self.view_scale
  412.                                 new_y = np.clip(y, 0, h - 1) / self.view_scale

  413.                         key_events = io.get_key_events(self.wnd_name)
  414.                         key, chr_key, ctrl_pressed, alt_pressed, shift_pressed = key_events[-1] if len(key_events) > 0 else (0, 0, False, False, False)

  415.                         if key == ord('\r') or key == ord('\n'):
  416.                             # confirm frame
  417.                             is_frame_done = True
  418.                             data_rects.append(self.rect)
  419.                             data_landmarks.append(self.landmarks)
  420.                             break
  421.                         elif key == ord(' '):
  422.                             # confirm skip frame
  423.                             is_frame_done = True
  424.                             break
  425.                         elif key == ord(',') and len(self.result) > 0:
  426.                             # go prev frame

  427.                             if self.rect_locked:
  428.                                 self.rect_locked = False
  429.                                 # Only save the face if the rect is still locked
  430.                                 data_rects.append(self.rect)
  431.                                 data_landmarks.append(self.landmarks)

  432.                             self.input_data.insert(0, self.result.pop())
  433.                             io.progress_bar_inc(-1)
  434.                             need_remark_face = True

  435.                             break
  436.                         elif key == ord('.'):
  437.                             # go next frame

  438.                             if self.rect_locked:
  439.                                 self.rect_locked = False
  440.                                 # Only save the face if the rect is still locked
  441.                                 data_rects.append(self.rect)
  442.                                 data_landmarks.append(self.landmarks)

  443.                             need_remark_face = True
  444.                             is_frame_done = True
  445.                             break
  446.                         elif key == ord('q'):
  447.                             # skip remaining

  448.                             if self.rect_locked:
  449.                                 self.rect_locked = False
  450.                                 data.rects.append(self.rect)
  451.                                 data.landmarks.append(self.landmarks)

  452.                             while len(self.input_data) > 0:
  453.                                 self.result.append(self.input_data.pop(0))
  454.                                 io.progress_bar_inc(1)

  455.                             break

  456.                         elif key == ord('h'):
  457.                             self.hide_help = not self.hide_help
  458.                             break
  459.                         elif key == ord('a'):
  460.                             self.landmarks_accurate = not self.landmarks_accurate
  461.                             break

  462.                         if self.force_landmarks:
  463.                             pt2 = np.float32([new_x, new_y])
  464.                             pt1 = np.float32([self.x, self.y])

  465.                             pt_vec_len = npla.norm(pt2 - pt1)
  466.                             pt_vec = pt2 - pt1
  467.                             if pt_vec_len != 0:
  468.                                 pt_vec /= pt_vec_len

  469.                             self.rect_size = pt_vec_len
  470.                             self.rect = (int(self.x - self.rect_size),
  471.                                          int(self.y - self.rect_size),
  472.                                          int(self.x + self.rect_size),
  473.                                          int(self.y + self.rect_size))

  474.                             if pt_vec_len > 0:
  475.                                 lmrks = np.concatenate((np.zeros((17, 2), np.float32), LandmarksProcessor.landmarks_2D), axis=0)
  476.                                 lmrks -= lmrks[30:31, :]
  477.                                 mat = cv2.getRotationMatrix2D((0, 0), -np.arctan2(pt_vec[1], pt_vec[0]) * 180 / math.pi, pt_vec_len)
  478.                                 mat[:, 2] += (self.x, self.y)
  479.                                 self.landmarks = LandmarksProcessor.transform_points(lmrks, mat)

  480.                             self.redraw()

  481.                         elif self.x != new_x or \
  482.                                 self.y != new_y or \
  483.                                 self.rect_size != new_rect_size or \
  484.                                 self.extract_needed:
  485.                             self.x = new_x
  486.                             self.y = new_y
  487.                             self.rect_size = new_rect_size
  488.                             self.rect = (int(self.x - self.rect_size),
  489.                                          int(self.y - self.rect_size),
  490.                                          int(self.x + self.rect_size),
  491.                                          int(self.y + self.rect_size))

  492.                             return ExtractSubprocessor.Data(filepath, rects=[self.rect], landmarks_accurate=self.landmarks_accurate)

  493.                 else:
  494.                     is_frame_done = True

  495.                 if is_frame_done:
  496.                     self.result.append(data)
  497.                     self.input_data.pop(0)
  498.                     io.progress_bar_inc(1)
  499.                     self.extract_needed = True
  500.                     self.rect_locked = False
  501.         else:
  502.             if len(self.input_data) > 0:
  503.                 return self.input_data.pop(0)

  504.         return None

  505.     # override
  506.     def on_data_return(self, host_dict, data):
  507.         if not self.type != 'landmarks-manual':
  508.             self.input_data.insert(0, data)

  509.     def redraw(self):
  510.         (h, w, c) = self.image.shape

  511.         if not self.hide_help:
  512.             image = cv2.addWeighted(self.image, 1.0, self.text_lines_img, 1.0, 0)
  513.         else:
  514.             image = self.image.copy()

  515.         view_rect = (np.array(self.rect) * self.view_scale).astype(np.int32).tolist()
  516.         view_landmarks = (np.array(self.landmarks) * self.view_scale).astype(np.int32).tolist()

  517.         if self.rect_size <= 40:
  518.             scaled_rect_size = h // 3 if h < w else w // 3

  519.             p1 = (self.x - self.rect_size, self.y - self.rect_size)
  520.             p2 = (self.x + self.rect_size, self.y - self.rect_size)
  521.             p3 = (self.x - self.rect_size, self.y + self.rect_size)

  522.             wh = h if h < w else w
  523.             np1 = (w / 2 - wh / 4, h / 2 - wh / 4)
  524.             np2 = (w / 2 + wh / 4, h / 2 - wh / 4)
  525.             np3 = (w / 2 - wh / 4, h / 2 + wh / 4)

  526.             mat = cv2.getAffineTransform(np.float32([p1, p2, p3]) * self.view_scale, np.float32([np1, np2, np3]))
  527.             image = cv2.warpAffine(image, mat, (w, h))
  528.             view_landmarks = LandmarksProcessor.transform_points(view_landmarks, mat)

  529.         landmarks_color = (255, 255, 0) if self.rect_locked else (0, 255, 0)
  530.         LandmarksProcessor.draw_rect_landmarks(image, view_rect, view_landmarks, self.face_type, self.image_size, landmarks_color=landmarks_color)
  531.         self.extract_needed = False

  532.         io.show_image(self.wnd_name, image)

  533.     # override
  534.     def on_result(self, host_dict, data, result):
  535.         if self.type == 'landmarks-manual':
  536.             filepath, landmarks = result.filepath, result.landmarks

  537.             if len(landmarks) != 0 and landmarks[0] is not None:
  538.                 self.landmarks = landmarks[0]

  539.             self.redraw()
  540.         else:
  541.             self.result.append(result)
  542.             io.progress_bar_inc(1)

  543.     # override
  544.     def get_result(self):
  545.         return self.result


  546. class DeletedFilesSearcherSubprocessor(Subprocessor):
  547.     class Cli(Subprocessor.Cli):
  548.         # override
  549.         def on_initialize(self, client_dict):
  550.             self.debug_paths_stems = client_dict['debug_paths_stems']
  551.             return None

  552.         # override
  553.         def process_data(self, data):
  554.             input_path_stem = Path(data[0]).stem
  555.             return any([input_path_stem == d_stem for d_stem in self.debug_paths_stems])

  556.         # override
  557.         def get_data_name(self, data):
  558.             # return string identificator of your data
  559.             return data[0]

  560.     # override
  561.     def __init__(self, input_paths, debug_paths):
  562.         self.input_paths = input_paths
  563.         self.debug_paths_stems = [Path(d).stem for d in debug_paths]
  564.         self.result = []
  565.         super().__init__('DeletedFilesSearcherSubprocessor', DeletedFilesSearcherSubprocessor.Cli, 60)

  566.     # override
  567.     def process_info_generator(self):
  568.         for i in range(min(multiprocessing.cpu_count(), 8)):
  569.             yield 'CPU%d' % (i), {}, {'debug_paths_stems': self.debug_paths_stems}

  570.     # override
  571.     def on_clients_initialized(self):
  572.         io.progress_bar("Searching deleted files", len(self.input_paths))

  573.     # override
  574.     def on_clients_finalized(self):
  575.         io.progress_bar_close()

  576.     # override
  577.     def get_data(self, host_dict):
  578.         if len(self.input_paths) > 0:
  579.             return [self.input_paths.pop(0)]
  580.         return None

  581.     # override
  582.     def on_data_return(self, host_dict, data):
  583.         self.input_paths.insert(0, data[0])

  584.     # override
  585.     def on_result(self, host_dict, data, result):
  586.         if result == False:
  587.             self.result.append(data[0])
  588.         io.progress_bar_inc(1)

  589.     # override
  590.     def get_result(self):
  591.         return self.result


  592. def main(detector=None,
  593.          input_path=None,
  594.          output_path=None,
  595.          output_debug=None,
  596.          manual_fix=False,
  597.          manual_output_debug_fix=False,
  598.          manual_window_size=1368,
  599.          face_type='full_face',
  600.          max_faces_from_image=None,
  601.          image_size=None,
  602.          jpeg_quality=None,
  603.          cpu_only=False,
  604.          force_gpu_idxs=None,
  605.          ):
  606.     if not input_path.exists():
  607.         io.log_err('Input directory not found. Please ensure it exists.')
  608.         return

  609.     if not output_path.exists():
  610.         output_path.mkdir(parents=True, exist_ok=True)

  611.     if face_type is not None:
  612.         face_type = FaceType.fromString(face_type)

  613.     if face_type is None:
  614.         if manual_output_debug_fix:
  615.             files = pathex.get_image_paths(output_path)
  616.             if len(files) != 0:
  617.                 dflimg = DFLIMG.load(Path(files[0]))
  618.                 if dflimg is not None and dflimg.has_data():
  619.                     face_type = FaceType.fromString(dflimg.get_face_type())

  620.     input_image_paths = pathex.get_image_unique_filestem_paths(input_path, verbose_print_func=io.log_info)
  621.     output_images_paths = pathex.get_image_paths(output_path)
  622.     output_debug_path = output_path.parent / (output_path.name + '_debug')

  623.     continue_extraction = False
  624.     if not manual_output_debug_fix and len(output_images_paths) > 0:
  625.         if len(output_images_paths) > 128:
  626.             continue_extraction = io.input_bool("Continue extraction?", True, help_message="Extraction can be continued, but you must specify the same options again.")

  627.         if len(output_images_paths) > 128 and continue_extraction:
  628.             try:
  629.                 input_image_paths = input_image_paths[[Path(x).stem for x in input_image_paths].index(Path(output_images_paths[-128]).stem.split('_')[0]):]
  630.             except:
  631.                 io.log_err("Error in fetching the last index. Extraction cannot be continued.")
  632.                 return
  633.         elif input_path != output_path:
  634.             io.input(f"\n WARNING !!! \n {output_path} contains files! \n They will be deleted. \n Press enter to continue.\n")
  635.             for filename in output_images_paths:
  636.                 Path(filename).unlink()

  637.     device_config = nn.DeviceConfig.GPUIndexes(force_gpu_idxs or nn.ask_choose_device_idxs(choose_only_one=detector == 'manual', suggest_all_gpu=True)) \
  638.         if not cpu_only else nn.DeviceConfig.CPU()

  639.     if face_type is None:
  640.         face_type = io.input_str("Face type", 'wf', ['f', 'wf', 'head'],
  641.                                  help_message="Full face / whole face / head. 'Whole face' covers full area of face include forehead. 'head' covers full head, but requires XSeg for src and dst faceset.").lower()
  642.         face_type = {'f': FaceType.FULL,
  643.                      'wf': FaceType.WHOLE_FACE,
  644.                      'head': FaceType.HEAD}[face_type]

  645.     if max_faces_from_image is None:
  646.         max_faces_from_image = io.input_int(f"Max number of faces from image", 0,
  647.                                             help_message="If you extract a src faceset that has frames with a large number of faces, it is advisable to set max faces to 3 to speed up extraction. 0 - unlimited")

  648.     if image_size is None:
  649.         image_size = io.input_int(f"Image size", 512 if face_type < FaceType.HEAD else 768, valid_range=[256, 2048],
  650.                                   help_message="Output image size. The higher image size, the worse face-enhancer works. Use higher than 512 value only if the source image is sharp enough and the face does not need to be enhanced.")

  651.     if jpeg_quality is None:
  652.         jpeg_quality = io.input_int(f"Jpeg quality", 90, valid_range=[1, 100], help_message="Jpeg quality. The higher jpeg quality the larger the output file size.")

  653.     if detector is None:
  654.         io.log_info("Choose detector type.")
  655.         io.log_info("[0] S3FD")
  656.         io.log_info("[1] manual")
  657.         detector = {0: 's3fd', 1: 'manual'}[io.input_int("", 0, [0, 1])]

  658.     if output_debug is None:
  659.         output_debug = io.input_bool(f"Write debug images to {output_debug_path.name}?", False)

  660.     if output_debug:
  661.         output_debug_path.mkdir(parents=True, exist_ok=True)

  662.     if manual_output_debug_fix:
  663.         if not output_debug_path.exists():
  664.             io.log_err(f'{output_debug_path} not found. Re-extract faces with "Write debug images" option.')
  665.             return
  666.         else:
  667.             detector = 'manual'
  668.             io.log_info('Performing re-extract frames which were deleted from _debug directory.')

  669.             input_image_paths = DeletedFilesSearcherSubprocessor(input_image_paths, pathex.get_image_paths(output_debug_path)).run()
  670.             input_image_paths = sorted(input_image_paths)
  671.             io.log_info('Found %d images.' % (len(input_image_paths)))
  672.     else:
  673.         if not continue_extraction and output_debug_path.exists():
  674.             for filename in pathex.get_image_paths(output_debug_path):
  675.                 Path(filename).unlink()

  676.     images_found = len(input_image_paths)
  677.     faces_detected = 0
  678.     if images_found != 0:
  679.         if detector == 'manual':
  680.             io.log_info('Performing manual extract...')
  681.             data = ExtractSubprocessor([ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths], 'landmarks-manual', image_size, jpeg_quality,
  682.                                        face_type, output_debug_path if output_debug else None, manual_window_size=manual_window_size, device_config=device_config).run()

  683.             io.log_info('Performing 3rd pass...')
  684.             data = ExtractSubprocessor(data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None, final_output_path=output_path,
  685.                                        device_config=device_config).run()

  686.         else:
  687.             io.log_info('Extracting faces...')
  688.             data = ExtractSubprocessor([ExtractSubprocessor.Data(Path(filename)) for filename in input_image_paths],
  689.                                        'all',
  690.                                        image_size,
  691.                                        jpeg_quality,
  692.                                        face_type,
  693.                                        output_debug_path if output_debug else None,
  694.                                        max_faces_from_image=max_faces_from_image,
  695.                                        final_output_path=output_path,
  696.                                        device_config=device_config).run()

  697.         faces_detected += sum([d.faces_detected for d in data])

  698.         if manual_fix:
  699.             if all(np.array([d.faces_detected > 0 for d in data]) == True):
  700.                 io.log_info('All faces are detected, manual fix not needed.')
  701.             else:
  702.                 fix_data = [ExtractSubprocessor.Data(d.filepath) for d in data if d.faces_detected == 0]
  703.                 io.log_info('Performing manual fix for %d images...' % (len(fix_data)))
  704.                 fix_data = ExtractSubprocessor(fix_data, 'landmarks-manual', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None,
  705.                                                manual_window_size=manual_window_size, device_config=device_config).run()
  706.                 fix_data = ExtractSubprocessor(fix_data, 'final', image_size, jpeg_quality, face_type, output_debug_path if output_debug else None,
  707.                                                final_output_path=output_path, device_config=device_config).run()
  708.                 faces_detected += sum([d.faces_detected for d in fix_data])

  709.     io.log_info('-------------------------')
  710.     io.log_info('Images found:        %d' % (images_found))
  711.     io.log_info('Faces detected:      %d' % (faces_detected))
  712.     io.log_info('-------------------------')
复制代码
回复 支持 2 反对 0

使用道具 举报

19

主题

151

帖子

964

积分

高级丹师

Rank: 5Rank: 5

积分
964
 楼主| 发表于 5 天前 | 显示全部楼层
优化了gpu并行能力,可以同时处理的图片不是一张,而是批次的多图 ,还有优化了些计算,等等,忘记了都写了什么了。   速度随着显卡的大小和gpu数,呈几何倍数递增。

请大家看清楚,是优化了代码,不是他们那些所谓的靠丢失脸部的提速。

至于具体怎么样,看你们自己的硬件了。

想起来 实时监控gpu去动态加载和调整批处理
回复 支持 反对

使用道具 举报

34

主题

293

帖子

9284

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
9284
发表于 5 天前 | 显示全部楼层
day270010678 发表于 2026-1-7 20:11
优化了gpu并行能力,可以同时处理的图片不是一张,而是批次的多图 ,还有优化了些计算,等等,忘记了都写了 ...

大佬请问提脸主要是这三个py文件吗?如果我想修改提脸的过滤逻辑(我记得原版的过滤好像是过滤10像素一下的人脸),如果我想把提脸的过滤改成用户交互的动态选择。是不是也是修改这三个py文件就行了
回复 支持 反对

使用道具 举报

19

主题

151

帖子

964

积分

高级丹师

Rank: 5Rank: 5

积分
964
 楼主| 发表于 5 天前 | 显示全部楼层
本帖最后由 day270010678 于 2026-1-7 21:00 编辑
fghfdg 发表于 2026-1-7 20:53
大佬请问提脸主要是这三个py文件吗?如果我想修改提脸的过滤逻辑(我记得原版的过滤好像是过滤10像素一下 ...

肯定是这里修改,不是你写交互干嘛?自己设定就行了。你过虑什么?是像他们一样靠丢失提速?在S3FDExtractor.py里面不是有个阀值吗?默认0.5,你随便提高,你切脸速度马上就提升了,只是丢失严重罢了。
回复 支持 反对

使用道具 举报

49

主题

358

帖子

4685

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
4685

万事如意节日勋章

发表于 5 天前 | 显示全部楼层
你好,我看你早期还发过贴讨论src切脸俯仰角度 五官对称问题,现在你有答案了吗?按对称标会出问题吗
回复 支持 反对

使用道具 举报

19

主题

151

帖子

964

积分

高级丹师

Rank: 5Rank: 5

积分
964
 楼主| 发表于 5 天前 | 显示全部楼层
DFL小白02 发表于 2026-1-7 21:05
你好,我看你早期还发过贴讨论src切脸俯仰角度 五官对称问题,现在你有答案了吗?按对称标会出问题吗 ...

那个时候刚开始接触,没研究过他代码,现在就这种对称的问题,解决法子多了去了,比方说降低阀值,最佳法子肯定是修改旋转逻辑
回复 支持 反对

使用道具 举报

19

主题

151

帖子

964

积分

高级丹师

Rank: 5Rank: 5

积分
964
 楼主| 发表于 5 天前 | 显示全部楼层
在模型推理内部实现真正的多线程并行(修改模型架构以支持动态批次大小,利用TensorFlow/CUDA的内置并行机制等等),每个子进程可以一次处理多个图像。有人研究过这个方向不? 我试验了下TensorFlow的内置并行机制,效果老是不理想,感觉是鸡肋。


回复 支持 反对

使用道具 举报

9

主题

39

帖子

4996

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
4996
发表于 5 天前 | 显示全部楼层
楼主能放出文件直接下载吗?小白不会编辑
回复 支持 反对

使用道具 举报

44

主题

1043

帖子

5954

积分

高级丹圣

Rank: 13Rank: 13Rank: 13Rank: 13

积分
5954

万事如意节日勋章开心娱乐节日勋章

发表于 4 天前 | 显示全部楼层
A卡有效果吗
回复 支持 反对

使用道具 举报

QQ|Archiver|手机版|deepfacelab中文网 |网站地图

GMT+8, 2026-1-12 09:39 , Processed in 0.120162 second(s), 37 queries .

Powered by Discuz! X3.4

Copyright © 2001-2020, Tencent Cloud.

快速回复 返回顶部 返回列表