diff mbox series

[FFmpeg-devel,4/4] libavfilter/vf_dnn_detect: Add yolov4 support

Message ID 20231204053633.1743228-4-wenbin.chen@intel.com
State New
Headers show
Series [FFmpeg-devel,1/4] libavfiter/dnn/dnn_backend_openvino: add multiple output support | expand

Checks

Context Check Description
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Chen, Wenbin Dec. 4, 2023, 5:36 a.m. UTC
From: Wenbin Chen <wenbin.chen@intel.com>

The difference of yolov4 is that sigmoid function needed to be applied
on x, y coordinates. Also make it compatiple with NHWC output as the
yolov4 model from openvino model zoo has NHWC output layout.

Model refer to: https://github.com/openvinotoolkit/open_model_zoo/tree/master/models/public/yolo-v4-tf

Signed-off-by: Wenbin Chen <wenbin.chen@intel.com>
---
 libavfilter/vf_dnn_detect.c | 71 ++++++++++++++++++++++++++++++-------
 1 file changed, 59 insertions(+), 12 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c
index 7a32b191c3..1b04a2cb98 100644
--- a/libavfilter/vf_dnn_detect.c
+++ b/libavfilter/vf_dnn_detect.c
@@ -35,7 +35,8 @@ 
 typedef enum {
     DDMT_SSD,
     DDMT_YOLOV1V2,
-    DDMT_YOLOV3
+    DDMT_YOLOV3,
+    DDMT_YOLOV4
 } DNNDetectionModelType;
 
 typedef struct DnnDetectContext {
@@ -75,6 +76,7 @@  static const AVOption dnn_detect_options[] = {
         { "ssd",     "output shape [1, 1, N, 7]",  0,                        AV_OPT_TYPE_CONST,       { .i64 = DDMT_SSD },    0, 0, FLAGS, "model_type" },
         { "yolo",    "output shape [1, N*Cx*Cy*DetectionBox]",  0,           AV_OPT_TYPE_CONST,       { .i64 = DDMT_YOLOV1V2 },    0, 0, FLAGS, "model_type" },
         { "yolov3",  "outputs shape [1, N*D, Cx, Cy]",  0,                   AV_OPT_TYPE_CONST,       { .i64 = DDMT_YOLOV3 },      0, 0, FLAGS, "model_type" },
+        { "yolov4",  "outputs shape [1, N*D, Cx, Cy]",  0,                   AV_OPT_TYPE_CONST,       { .i64 = DDMT_YOLOV4 },    0, 0, FLAGS, "model_type" },
     { "cell_w",      "cell width",                 OFFSET2(cell_w),          AV_OPT_TYPE_INT,       { .i64 = 0 },    0, INTMAX_MAX, FLAGS },
     { "cell_h",      "cell height",                OFFSET2(cell_h),          AV_OPT_TYPE_INT,       { .i64 = 0 },    0, INTMAX_MAX, FLAGS },
     { "nb_classes",  "The number of class",        OFFSET2(nb_classes),      AV_OPT_TYPE_INT,       { .i64 = 0 },    0, INTMAX_MAX, FLAGS },
@@ -84,6 +86,14 @@  static const AVOption dnn_detect_options[] = {
 
 AVFILTER_DEFINE_CLASS(dnn_detect);
 
+static inline float sigmoid(float x) {
+    return 1.f / (1.f + exp(-x));
+}
+
+static inline float linear(float x) {
+    return x;
+}
+
 static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data)
 {
     float max_prob = 0;
@@ -142,6 +152,8 @@  static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
     float *output_data = output[output_index].data;
     float *anchors = ctx->anchors;
     AVDetectionBBox *bbox;
+    float (*post_process_raw_data)(float x);
+    int is_NHWC = 0;
 
     if (ctx->model_type == DDMT_YOLOV1V2) {
         cell_w = ctx->cell_w;
@@ -149,13 +161,30 @@  static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
         scale_w = cell_w;
         scale_h = cell_h;
     } else {
-        cell_w = output[output_index].width;
-        cell_h = output[output_index].height;
+        if (output[output_index].height != output[output_index].width &&
+            output[output_index].height == output[output_index].channels) {
+            is_NHWC = 1;
+            cell_w = output[output_index].height;
+            cell_h = output[output_index].channels;
+        } else {
+            cell_w = output[output_index].width;
+            cell_h = output[output_index].height;
+        }
         scale_w = ctx->scale_width;
         scale_h = ctx->scale_height;
     }
     box_size = nb_classes + 5;
 
+    switch (ctx->model_type) {
+    case DDMT_YOLOV1V2:
+    case DDMT_YOLOV3:
+        post_process_raw_data = linear;
+        break;
+    case DDMT_YOLOV4:
+        post_process_raw_data = sigmoid;
+         break;
+    }
+
     if (!cell_h || !cell_w) {
         av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are detected\n");
         return AVERROR(EINVAL);
@@ -193,19 +222,36 @@  static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
                 float *detection_boxes_data;
                 int label_id;
 
-                detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
-                conf = detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h];
+                if (is_NHWC) {
+                    detection_boxes_data = output_data +
+                        ((cy * cell_w + cx) * detection_boxes + box_id) * box_size;
+                    conf = post_process_raw_data(detection_boxes_data[4]);
+                } else {
+                    detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
+                    conf = post_process_raw_data(
+                                detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h]);
+                }
                 if (conf < conf_threshold) {
                     continue;
                 }
 
-                x    = detection_boxes_data[cy * cell_w + cx];
-                y    = detection_boxes_data[cy * cell_w + cx + cell_w * cell_h];
-                w    = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
-                h    = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
-                label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
-                                    detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
-                conf = conf * detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h];
+                if (is_NHWC) {
+                    x = post_process_raw_data(detection_boxes_data[0]);
+                    y = post_process_raw_data(detection_boxes_data[1]);
+                    w = detection_boxes_data[2];
+                    h = detection_boxes_data[3];
+                    label_id = dnn_detect_get_label_id(ctx->nb_classes, 1, detection_boxes_data + 5);
+                    conf = conf * post_process_raw_data(detection_boxes_data[label_id + 5]);
+                } else {
+                    x = post_process_raw_data(detection_boxes_data[cy * cell_w + cx]);
+                    y = post_process_raw_data(detection_boxes_data[cy * cell_w + cx + cell_w * cell_h]);
+                    w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
+                    h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
+                    label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
+                        detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
+                    conf = conf * post_process_raw_data(
+                                detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h]);
+                }
 
                 bbox = av_mallocz(sizeof(*bbox));
                 if (!bbox)
@@ -404,6 +450,7 @@  static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, int nb_outpu
         if (ret < 0)
             return ret;
     case DDMT_YOLOV3:
+    case DDMT_YOLOV4:
         ret = dnn_detect_post_proc_yolov3(frame, output, filter_ctx, nb_outputs);
         if (ret < 0)
             return ret;