@@ -386,9 +386,9 @@ static void infer_completion_callback(void *args)
ov_shape_free(&output_shape);
return;
}
- output.channels = dims[1];
- output.height = dims[2];
- output.width = dims[3];
+ output.channels = output_shape.rank > 2 ? dims[output_shape.rank - 3] : 1;
+ output.height = output_shape.rank > 1 ? dims[output_shape.rank - 2] : 1;
+ output.width = output_shape.rank > 0 ? dims[output_shape.rank - 1] : 1;
av_assert0(request->lltask_count <= dims[0]);
ov_shape_free(&output_shape);
#else
@@ -30,9 +30,11 @@
#include "libavutil/time.h"
#include "libavutil/avstring.h"
#include "libavutil/detection_bbox.h"
+#include "libavutil/fifo.h"
+/* Values selectable through the "model_type" filter option. */
typedef enum {
- DDMT_SSD
+ DDMT_SSD, /* option constant "ssd" */
+ DDMT_YOLOV1V2, /* option constant "yolo" */
} DNNDetectionModelType;
typedef struct DnnDetectContext {
@@ -43,6 +45,15 @@ typedef struct DnnDetectContext {
char **labels;
int label_count;
DNNDetectionModelType model_type;
+ int cell_w;
+ int cell_h;
+ int nb_classes;
+ AVFifo *bboxes_fifo;
+ int scale_width;
+ int scale_height;
+ char *anchors_str;
+ float *anchors;
+ int nb_anchor;
} DnnDetectContext;
#define OFFSET(x) offsetof(DnnDetectContext, dnnctx.x)
@@ -61,11 +72,218 @@ static const AVOption dnn_detect_options[] = {
{ "labels", "path to labels file", OFFSET2(labels_filename), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
{ "model_type", "DNN detection model type", OFFSET2(model_type), AV_OPT_TYPE_INT, { .i64 = DDMT_SSD }, INT_MIN, INT_MAX, FLAGS, "model_type" },
{ "ssd", "output shape [1, 1, N, 7]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_SSD }, 0, 0, FLAGS, "model_type" },
+ { "yolo", "output shape [1, N*Cx*Cy*DetectionBox]", 0, AV_OPT_TYPE_CONST, { .i64 = DDMT_YOLOV1V2 }, 0, 0, FLAGS, "model_type" },
+ { "cell_w", "cell width", OFFSET2(cell_w), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
+ { "cell_h", "cell height", OFFSET2(cell_h), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
+ { "nb_classes", "The number of class", OFFSET2(nb_classes), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INTMAX_MAX, FLAGS },
+ { "anchors", "anchors, splited by '&'", OFFSET2(anchors_str), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },
{ NULL }
};
AVFILTER_DEFINE_CLASS(dnn_detect);
+/**
+ * Select the class with the largest score for a single cell.
+ *
+ * Scores are read from label_data at a stride of cell_size floats, one per
+ * class. Only scores strictly greater than the running maximum (which starts
+ * at 0) replace the current choice.
+ *
+ * @param nb_classes number of scores to inspect
+ * @param cell_size  distance, in floats, between consecutive class scores
+ * @param label_data pointer to the score of class 0
+ * @return index of the winning class, or 0 when no score is positive
+ */
+static int dnn_detect_get_label_id(int nb_classes, int cell_size, float *label_data)
+{
+    int best_class = 0;
+    float best_score = 0;
+    int class_idx = 0;
+
+    while (class_idx < nb_classes) {
+        const float score = label_data[class_idx * cell_size];
+
+        if (score > best_score) {
+            best_score = score;
+            best_class = class_idx;
+        }
+        class_idx++;
+    }
+    return best_class;
+}
+
+/**
+ * Parse an '&'-separated list of numbers into a newly allocated float array.
+ *
+ * @param anchors_str NUL-terminated list; it is tokenized in place with
+ *                    av_strtok(), so the string is modified
+ * @param anchors     receives the allocated array on success; the caller
+ *                    owns it. Left untouched on failure.
+ * @return number of values stored in *anchors, or 0 if the allocation fails
+ *         or the string holds fewer values than its separators imply
+ *         (e.g. "1&&2" or a trailing '&')
+ */
+static int dnn_detect_parse_anchors(char *anchors_str, float **anchors)
+{
+    char *saveptr = NULL, *token;
+    float *anchors_buf;
+    int nb_anchor = 1; /* N separators delimit N + 1 values */
+
+    for (int i = 0; anchors_str[i] != '\0'; i++) {
+        if (anchors_str[i] == '&')
+            nb_anchor++;
+    }
+
+    /* Size by the element type (float), not by the float * that *anchors is. */
+    anchors_buf = av_mallocz(nb_anchor * sizeof(*anchors_buf));
+    if (!anchors_buf) {
+        return 0;
+    }
+
+    for (int i = 0; i < nb_anchor; i++) {
+        token = av_strtok(anchors_str, "&", &saveptr);
+        /* av_strtok() skips empty fields, so it can run out of tokens early. */
+        if (!token) {
+            av_freep(&anchors_buf);
+            return 0;
+        }
+        anchors_buf[i] = strtof(token, NULL);
+        anchors_str = NULL;
+    }
+    *anchors = anchors_buf;
+    return nb_anchor;
+}
+
+/**
+ * Intersection over Union of two bounding boxes.
+ *
+ * The intersection is taken as 0 when the boxes do not overlap along either
+ * axis; the union is the sum of both box areas minus the intersection.
+ *
+ * @return intersection area divided by union area
+ */
+static float dnn_detect_IOU(AVDetectionBBox *bbox1, AVDetectionBBox *bbox2)
+{
+    float overlap_w = FFMIN(bbox1->x + bbox1->w, bbox2->x + bbox2->w) - FFMAX(bbox1->x, bbox2->x);
+    float overlap_h = FFMIN(bbox1->y + bbox1->h, bbox2->y + bbox2->h) - FFMAX(bbox1->y, bbox2->y);
+    float inter_area = 0;
+    float total_area;
+
+    if (overlap_w >= 0 && overlap_h >= 0)
+        inter_area = overlap_h * overlap_w;
+
+    total_area = bbox1->w * bbox1->h + bbox2->w * bbox2->h - inter_area;
+    return inter_area / total_area;
+}
+
+/**
+ * Decode one YOLO output tensor into candidate bounding boxes.
+ *
+ * The tensor is treated as detection_boxes groups of box_size planes, each
+ * plane holding cell_w * cell_h floats indexed as cy * cell_w + cx. Within a
+ * group the planes are read in the order x, y, w, h, box score, followed by
+ * one score per class. Every cell whose box score reaches ctx->confidence
+ * yields a heap-allocated AVDetectionBBox that is appended to
+ * ctx->bboxes_fifo.
+ *
+ * @param frame        frame whose width and height scale the box coordinates
+ * @param output       array of model outputs
+ * @param output_index index of the output to decode
+ * @param filter_ctx   filter context; its priv is used as a DnnDetectContext
+ * @return 0 on success, AVERROR(EINVAL) when a required option is missing or
+ *         inconsistent with the tensor size, AVERROR(ENOMEM) on allocation
+ *         failure
+ */
+static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int output_index,
+                                        AVFilterContext *filter_ctx)
+{
+    DnnDetectContext *ctx = filter_ctx->priv;
+    float conf_threshold = ctx->confidence;
+    int detection_boxes, box_size, cell_size;
+    /* Zero-initialized so an unhandled model type hits the check below. */
+    int cell_w = 0, cell_h = 0, scale_w = 0, scale_h = 0;
+    int nb_classes = ctx->nb_classes;
+    float *output_data = output[output_index].data;
+    float *anchors = ctx->anchors;
+    AVDetectionBBox *bbox;
+
+    if (ctx->model_type == DDMT_YOLOV1V2) {
+        cell_w = ctx->cell_w;
+        cell_h = ctx->cell_h;
+        scale_w = cell_w;
+        scale_h = cell_h;
+    }
+    box_size = nb_classes + 5;
+
+    if (!cell_h || !cell_w) {
+        av_log(filter_ctx, AV_LOG_ERROR, "cell_w and cell_h are not set\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (!nb_classes) {
+        av_log(filter_ctx, AV_LOG_ERROR, "nb_classes is not set\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (!anchors) {
+        av_log(filter_ctx, AV_LOG_ERROR, "anchors is not set\n");
+        return AVERROR(EINVAL);
+    }
+
+    if (output[output_index].channels * output[output_index].width *
+        output[output_index].height % (box_size * cell_w * cell_h)) {
+        av_log(filter_ctx, AV_LOG_ERROR, "wrong cell_w, cell_h or nb_classes\n");
+        return AVERROR(EINVAL);
+    }
+    detection_boxes = output[output_index].channels *
+                      output[output_index].height *
+                      output[output_index].width / box_size / cell_w / cell_h;
+
+    /* Each detection box reads two anchor values (width, height). */
+    if (ctx->nb_anchor < 2 * detection_boxes) {
+        av_log(filter_ctx, AV_LOG_ERROR,
+               "the number of anchors is less than required by %d detection boxes\n",
+               detection_boxes);
+        return AVERROR(EINVAL);
+    }
+
+    cell_size = cell_w * cell_h;
+
+    /* find all candidate bbox */
+    for (int box_id = 0; box_id < detection_boxes; box_id++) {
+        /* First plane of this detection box; invariant over the cell loops. */
+        float *detection_boxes_data = output_data + box_id * box_size * cell_size;
+
+        for (int cx = 0; cx < cell_w; cx++)
+            for (int cy = 0; cy < cell_h; cy++) {
+                const int cell = cy * cell_w + cx;
+                float x, y, w, h, conf;
+                int label_id;
+
+                conf = detection_boxes_data[cell + 4 * cell_size];
+                if (conf < conf_threshold) {
+                    continue;
+                }
+
+                x = detection_boxes_data[cell];
+                y = detection_boxes_data[cell + cell_size];
+                w = detection_boxes_data[cell + 2 * cell_size];
+                h = detection_boxes_data[cell + 3 * cell_size];
+                label_id = dnn_detect_get_label_id(ctx->nb_classes, cell_size,
+                                                   detection_boxes_data + cell + 5 * cell_size);
+                /* Final confidence = box score * score of the winning class. */
+                conf = conf * detection_boxes_data[cell + (label_id + 5) * cell_size];
+
+                bbox = av_mallocz(sizeof(*bbox));
+                if (!bbox)
+                    return AVERROR(ENOMEM);
+
+                bbox->w = exp(w) * anchors[box_id * 2] * frame->width / scale_w;
+                bbox->h = exp(h) * anchors[box_id * 2 + 1] * frame->height / scale_h;
+                /* (cx + x, cy + y) is the box centre in cell units. */
+                bbox->x = (cx + x) / cell_w * frame->width - bbox->w / 2;
+                bbox->y = (cy + y) / cell_h * frame->height - bbox->h / 2;
+                bbox->detect_confidence = av_make_q((int)(conf * 10000), 10000);
+                if (ctx->labels && label_id < ctx->label_count) {
+                    av_strlcpy(bbox->detect_label, ctx->labels[label_id], sizeof(bbox->detect_label));
+                } else {
+                    snprintf(bbox->detect_label, sizeof(bbox->detect_label), "%d", label_id);
+                }
+
+                if (av_fifo_write(ctx->bboxes_fifo, &bbox, 1) < 0) {
+                    av_freep(&bbox);
+                    return AVERROR(ENOMEM);
+                }
+            }
+    }
+    return 0;
+}
+
+/**
+ * Move the candidate boxes queued on ctx->bboxes_fifo into frame side data.
+ *
+ * A candidate is dropped when another candidate with the same label has a
+ * strictly higher confidence and their IOU reaches ctx->confidence. The
+ * survivors are copied, in FIFO order, into a detection-bbox side data
+ * attached to the frame. The FIFO is empty when this function returns,
+ * on success and on failure alike.
+ *
+ * @param frame      frame that receives the side data
+ * @param filter_ctx filter context; its priv is used as a DnnDetectContext
+ * @return 0 on success (including "nothing detected"), AVERROR(ENOMEM) if
+ *         the side data cannot be created
+ */
+static int dnn_detect_fill_side_data(AVFrame *frame, AVFilterContext *filter_ctx)
+{
+    DnnDetectContext *ctx = filter_ctx->priv;
+    float conf_threshold = ctx->confidence;
+    AVDetectionBBox *bbox;
+    /* Nothing is pushed or popped while peeking, so query the length once. */
+    int nb_candidates = av_fifo_can_read(ctx->bboxes_fifo);
+    int nb_bboxes = 0;
+    AVDetectionBBoxHeader *header;
+
+    if (nb_candidates == 0) {
+        av_log(filter_ctx, AV_LOG_VERBOSE, "nothing detected in this frame.\n");
+        return 0;
+    }
+
+    /* remove overlap bboxes */
+    for (int i = 0; i < nb_candidates; i++) {
+        av_fifo_peek(ctx->bboxes_fifo, &bbox, 1, i);
+        for (int j = 0; j < nb_candidates; j++) {
+            AVDetectionBBox *overlap_bbox;
+            av_fifo_peek(ctx->bboxes_fifo, &overlap_bbox, 1, j);
+            if (!strcmp(bbox->detect_label, overlap_bbox->detect_label) &&
+                av_cmp_q(bbox->detect_confidence, overlap_bbox->detect_confidence) < 0 &&
+                dnn_detect_IOU(bbox, overlap_bbox) >= conf_threshold) {
+                bbox->classify_count = -1; // bad result
+                nb_bboxes++;
+                break;
+            }
+        }
+    }
+    /* nb_bboxes counted the rejected boxes; turn it into the survivor count. */
+    nb_bboxes = nb_candidates - nb_bboxes;
+    header = av_detection_bbox_create_side_data(frame, nb_bboxes);
+    if (!header) {
+        av_log(filter_ctx, AV_LOG_ERROR, "failed to create side data with %d bounding boxes\n", nb_bboxes);
+        /* Do not let this frame's candidates leak into the next frame. */
+        while (av_fifo_can_read(ctx->bboxes_fifo)) {
+            av_fifo_read(ctx->bboxes_fifo, &bbox, 1);
+            av_freep(&bbox);
+        }
+        return AVERROR(ENOMEM);
+    }
+    av_strlcpy(header->source, ctx->dnnctx.model_filename, sizeof(header->source));
+
+    while (av_fifo_can_read(ctx->bboxes_fifo)) {
+        AVDetectionBBox *candidate_bbox;
+        av_fifo_read(ctx->bboxes_fifo, &candidate_bbox, 1);
+
+        if (nb_bboxes > 0 && candidate_bbox->classify_count != -1) {
+            /* Slots are filled front to back as nb_bboxes counts down. */
+            bbox = av_get_detection_bbox(header, header->nb_bboxes - nb_bboxes);
+            memcpy(bbox, candidate_bbox, sizeof(*bbox));
+            nb_bboxes--;
+        }
+        av_freep(&candidate_bbox);
+    }
+    return 0;
+}
+
+/**
+ * YOLO post-processing: decode output 0 into candidate boxes, then attach
+ * the surviving boxes to the frame as side data.
+ *
+ * @return 0 on success, or the negative error code of the failing step
+ */
+static int dnn_detect_post_proc_yolo(AVFrame *frame, DNNData *output, AVFilterContext *filter_ctx)
+{
+    int err = dnn_detect_parse_yolo_output(frame, output, 0, filter_ctx);
+
+    if (err >= 0)
+        err = dnn_detect_fill_side_data(frame, filter_ctx);
+
+    return err < 0 ? err : 0;
+}
+
static int dnn_detect_post_proc_ssd(AVFrame *frame, DNNData *output, AVFilterContext *filter_ctx)
{
DnnDetectContext *ctx = filter_ctx->priv;
@@ -158,6 +376,10 @@ static int dnn_detect_post_proc_ov(AVFrame *frame, DNNData *output, AVFilterCont
if (ret < 0)
return ret;
break;
+ case DDMT_YOLOV1V2:
+ ret = dnn_detect_post_proc_yolo(frame, output, filter_ctx);
+ if (ret < 0)
+ return ret;
}
return 0;
@@ -356,11 +578,22 @@ static av_cold int dnn_detect_init(AVFilterContext *context)
ret = check_output_nb(ctx, dnn_ctx->backend_type, dnn_ctx->nb_outputs);
if (ret < 0)
return ret;
+ ctx->bboxes_fifo = av_fifo_alloc2(1, sizeof(AVDetectionBBox *), AV_FIFO_FLAG_AUTO_GROW);
+ if (!ctx->bboxes_fifo)
+ return AVERROR(ENOMEM);
ff_dnn_set_detect_post_proc(&ctx->dnnctx, dnn_detect_post_proc);
if (ctx->labels_filename) {
return read_detect_label_file(context);
}
+ if (ctx->anchors_str) {
+ ret = dnn_detect_parse_anchors(ctx->anchors_str, &ctx->anchors);
+ if (!ctx->anchors) {
+ av_log(context, AV_LOG_ERROR, "failed to parse anchors_str\n");
+ return AVERROR(EINVAL);
+ }
+ ctx->nb_anchor = ret;
+ }
return 0;
}
@@ -460,7 +693,14 @@ static int dnn_detect_activate(AVFilterContext *filter_ctx)
+/**
+ * Release everything owned by the filter: the DNN context, any bounding
+ * boxes still queued on the FIFO, the FIFO itself, the anchors and labels.
+ */
static av_cold void dnn_detect_uninit(AVFilterContext *context)
{
DnnDetectContext *ctx = context->priv;
+    AVDetectionBBox *bbox;
ff_dnn_uninit(&ctx->dnnctx);
+    /* init can return before the FIFO is allocated, so it may be NULL here;
+     * av_fifo_can_read() must not be called on a NULL FIFO. */
+    if (ctx->bboxes_fifo) {
+        while (av_fifo_can_read(ctx->bboxes_fifo)) {
+            av_fifo_read(ctx->bboxes_fifo, &bbox, 1);
+            av_freep(&bbox);
+        }
+        av_fifo_freep2(&ctx->bboxes_fifo);
+    }
+    av_freep(&ctx->anchors);
free_detect_labels(ctx);
}