@@ -35,7 +35,8 @@
typedef enum {
DDMT_SSD,
DDMT_YOLOV1V2,
- DDMT_YOLOV3
+ DDMT_YOLOV3,
+ DDMT_YOLOV4
} DNNDetectionModelType;
typedef struct DnnDetectContext {
@@ -75,6 +76,7 @@ static const AVOption dnn_detect_options[] = {
{ "ssd", "output shape [1, 1, N, 7]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_SSD }, 0, 0, FLAGS, "model_type" },
{ "yolo", "output shape [1, N*Cx*Cy*DetectionBox]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV1V2 }, 0, 0, FLAGS, "model_type" },
{ "yolov3", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV3 }, 0, 0, FLAGS, "model_type" },
+ { "yolov4", "outputs shape [1, N*D, Cx, Cy]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV4 }, 0, 0, FLAGS, "model_type" },
{ "cell_w", "cell width", OFFSET2(cell_w), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
{ "cell_h", "cell height", OFFSET2(cell_h), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
{ "nb_classes", "The number of class", OFFSET2(nb_classes), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
@@ -84,6 +86,14 @@ static const AVOption dnn_detect_options[] = {
AVFILTER_DEFINE_CLASS(dnn_detect);
+static inline float sigmoid(float x) {
+ return 1.f / (1.f + exp(-x));
+}
+
+static inline float linear(float x) {
+ return x;
+}
+
static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data)
{
float max_prob = 0;
@@ -142,6 +152,8 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
float *output_data = output[output_index].data;
float *anchors = ctx->anchors;
AVDetectionBBox *bbox;
+ float (*post_process_raw_data)(float x);
+ int is_NHWC = 0;
if (ctx->model_type == DDMT_YOLOV1V2) {
cell_w = ctx->cell_w;
@@ -149,13 +161,30 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
scale_w = cell_w;
scale_h = cell_h;
} else {
- cell_w = output[output_index].width;
- cell_h = output[output_index].height;
+ if (output[output_index].height != output[output_index].width &&
+ output[output_index].height == output[output_index].channels) {
+ is_NHWC = 1;
+ cell_w = output[output_index].height;
+ cell_h = output[output_index].channels;
+ } else {
+ cell_w = output[output_index].width;
+ cell_h = output[output_index].height;
+ }
scale_w = ctx->scale_width;
scale_h = ctx->scale_height;
}
box_size = nb_classes + 5;
+ switch (ctx->model_type) {
+ case DDMT_YOLOV1V2:
+ case DDMT_YOLOV3:
+ post_process_raw_data = linear;
+ break;
+ case DDMT_YOLOV4:
+ post_process_raw_data = sigmoid;
+ break;
+ }
+
if (!cell_h || !cell_w) {
av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are detected\n");
return AVERROR(EINVAL);
@@ -193,19 +222,36 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out
float *detection_boxes_data;
int label_id;
- detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
- conf = detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h];
+ if (is_NHWC) {
+ detection_boxes_data = output_data +
+ ((cy * cell_w + cx) * detection_boxes + box_id) * box_size;
+ conf = post_process_raw_data(detection_boxes_data[4]);
+ } else {
+ detection_boxes_data = output_data + box_id * box_size * cell_w * cell_h;
+ conf = post_process_raw_data(
+ detection_boxes_data[cy * cell_w + cx + 4 * cell_w * cell_h]);
+ }
if (conf < conf_threshold) {
continue;
}
- x = detection_boxes_data[cy * cell_w + cx];
- y = detection_boxes_data[cy * cell_w + cx + cell_w * cell_h];
- w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
- h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
- label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
- detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
- conf = conf * detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h];
+ if (is_NHWC) {
+ x = post_process_raw_data(detection_boxes_data[0]);
+ y = post_process_raw_data(detection_boxes_data[1]);
+ w = detection_boxes_data[2];
+ h = detection_boxes_data[3];
+ label_id = dnn_detect_get_label_id(ctx->nb_classes, 1, detection_boxes_data + 5);
+ conf = conf * post_process_raw_data(detection_boxes_data[label_id + 5]);
+ } else {
+ x = post_process_raw_data(detection_boxes_data[cy * cell_w + cx]);
+ y = post_process_raw_data(detection_boxes_data[cy * cell_w + cx + cell_w * cell_h]);
+ w = detection_boxes_data[cy * cell_w + cx + 2 * cell_w * cell_h];
+ h = detection_boxes_data[cy * cell_w + cx + 3 * cell_w * cell_h];
+ label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_w * cell_h,
+ detection_boxes_data + cy * cell_w + cx + 5 * cell_w * cell_h);
+ conf = conf * post_process_raw_data(
+ detection_boxes_data[cy * cell_w + cx + (label_id + 5) * cell_w * cell_h]);
+ }
bbox = av_mallocz(sizeof(*bbox));
if (!bbox)
@@ -404,6 +450,7 @@ static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, int nb_outpu
if (ret < 0)
return ret;
case DDMT_YOLOV3:
+ case DDMT_YOLOV4:
ret = dnn_detect_post_proc_yolov3(frame, output, filter_ctx, nb_outputs);
if (ret < 0)
return ret;