diff mbox series

[FFmpeg-devel,V2,5/6] lavfi/dnn: add classify support with openvino backend

Message ID 20210429133657.23076-5-yejun.guo@intel.com
State Accepted
Headers show
Series [FFmpeg-devel,V2,1/6] lavfi/dnn_backend_openvino.c: unify code for infer request for sync/async
Related show

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Guo, Yejun April 29, 2021, 1:36 p.m. UTC
Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
---
 libavfilter/dnn/dnn_backend_openvino.c | 143 +++++++++++++++++++++----
 libavfilter/dnn/dnn_io_proc.c          |  60 +++++++++++
 libavfilter/dnn/dnn_io_proc.h          |   1 +
 libavfilter/dnn_filter_common.c        |  21 ++++
 libavfilter/dnn_filter_common.h        |   2 +
 libavfilter/dnn_interface.h            |  10 +-
 6 files changed, 218 insertions(+), 19 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index 4e58ff6d9c..1ff8a720b9 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -29,6 +29,7 @@ 
 #include "libavutil/avassert.h"
 #include "libavutil/opt.h"
 #include "libavutil/avstring.h"
+#include "libavutil/detection_bbox.h"
 #include "../internal.h"
 #include "queue.h"
 #include "safe_queue.h"
@@ -74,6 +75,7 @@  typedef struct TaskItem {
 // one task might have multiple inferences
 typedef struct InferenceItem {
     TaskItem *task;
+    uint32_t bbox_index;
 } InferenceItem;
 
 // one request for one call to openvino
@@ -182,12 +184,23 @@  static DNNReturnType fill_model_input_ov(OVModel *ov_model, RequestItem *request
         request->inferences[i] = inference;
         request->inference_count = i + 1;
         task = inference->task;
-        if (task->do_ioproc) {
-            if (ov_model->model->frame_pre_proc != NULL) {
-                ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
-            } else {
-                ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
+        switch (task->ov_model->model->func_type) {
+        case DFT_PROCESS_FRAME:
+        case DFT_ANALYTICS_DETECT:
+            if (task->do_ioproc) {
+                if (ov_model->model->frame_pre_proc != NULL) {
+                    ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx);
+                } else {
+                    ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx);
+                }
             }
+            break;
+        case DFT_ANALYTICS_CLASSIFY:
+            ff_frame_to_dnn_classify(task->in_frame, &input, inference->bbox_index, ctx);
+            break;
+        default:
+            av_assert0(!"should not reach here");
+            break;
         }
         input.data = (uint8_t *)input.data
                      + input.width * input.height * input.channels * get_datatype_size(input.dt);
@@ -276,6 +289,13 @@  static void infer_completion_callback(void *args)
             }
             task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
             break;
+        case DFT_ANALYTICS_CLASSIFY:
+            if (!task->ov_model->model->classify_post_proc) {
+                av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n");
+                return;
+            }
+            task->ov_model->model->classify_post_proc(task->out_frame, &output, request->inferences[i]->bbox_index, task->ov_model->model->filter_ctx);
+            break;
         default:
             av_assert0(!"should not reach here");
             break;
@@ -513,7 +533,44 @@  static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
     return DNN_ERROR;
 }
 
-static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue)
+static int contain_valid_detection_bbox(AVFrame *frame)
+{
+    AVFrameSideData *sd;
+    const AVDetectionBBoxHeader *header;
+    const AVDetectionBBox *bbox;
+
+    sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+    if (!sd) { // this frame has nothing detected
+        return 0;
+    }
+
+    if (!sd->size) {
+        return 0;
+    }
+
+    header = (const AVDetectionBBoxHeader *)sd->data;
+    if (!header->nb_bboxes) {
+        return 0;
+    }
+
+    for (uint32_t i = 0; i < header->nb_bboxes; i++) {
+        bbox = av_get_detection_bbox(header, i);
+        if (bbox->x < 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) {
+            return 0;
+        }
+        if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->width) {
+            return 0;
+        }
+
+        if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) {
+            return 0;
+        }
+    }
+
+    return 1;
+}
+
+static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue, DNNExecBaseParams *exec_params)
 {
     switch (func_type) {
     case DFT_PROCESS_FRAME:
@@ -532,6 +589,45 @@  static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, Task
         }
         return DNN_SUCCESS;
     }
+    case DFT_ANALYTICS_CLASSIFY:
+    {
+        const AVDetectionBBoxHeader *header;
+        AVFrame *frame = task->in_frame;
+        AVFrameSideData *sd;
+        DNNExecClassificationParams *params = (DNNExecClassificationParams *)exec_params;
+
+        task->inference_todo = 0;
+        task->inference_done = 0;
+
+        if (!contain_valid_detection_bbox(frame)) {
+            return DNN_SUCCESS;
+        }
+
+        sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+        header = (const AVDetectionBBoxHeader *)sd->data;
+
+        for (uint32_t i = 0; i < header->nb_bboxes; i++) {
+            InferenceItem *inference;
+            const AVDetectionBBox *bbox = av_get_detection_bbox(header, i);
+
+            if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) {
+                continue;
+            }
+
+            inference = av_malloc(sizeof(*inference));
+            if (!inference) {
+                return DNN_ERROR;
+            }
+            task->inference_todo++;
+            inference->task = task;
+            inference->bbox_index = i;
+            if (ff_queue_push_back(inference_queue, inference) < 0) {
+                av_freep(&inference);
+                return DNN_ERROR;
+            }
+        }
+        return DNN_SUCCESS;
+    }
     default:
         av_assert0(!"should not reach here");
         return DNN_ERROR;
@@ -598,7 +694,7 @@  static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu
     task.out_frame = out_frame;
     task.ov_model = ov_model;
 
-    if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
+    if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, NULL) != DNN_SUCCESS) {
         av_frame_free(&out_frame);
         av_frame_free(&in_frame);
         av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
@@ -690,6 +786,14 @@  DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
         return DNN_ERROR;
     }
 
+    if (model->func_type == DFT_ANALYTICS_CLASSIFY) {
+        // Once we add async support for tensorflow backend and native backend,
+        // we'll combine the two sync/async functions in dnn_interface.h to
+        // simplify the code in filter, and async will be an option within backends.
+        // so, do not support now, and classify filter will not call this function.
+        return DNN_ERROR;
+    }
+
     if (ctx->options.batch_size > 1) {
         avpriv_report_missing_feature(ctx, "batch mode for sync execution");
         return DNN_ERROR;
@@ -710,7 +814,7 @@  DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams *
     task.out_frame = exec_params->out_frame ? exec_params->out_frame : exec_params->in_frame;
     task.ov_model = ov_model;
 
-    if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) {
+    if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
         av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
         return DNN_ERROR;
     }
@@ -730,6 +834,7 @@  DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
     OVContext *ctx = &ov_model->ctx;
     RequestItem *request;
     TaskItem *task;
+    DNNReturnType ret;
 
     if (ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params) != 0) {
         return DNN_ERROR;
@@ -761,23 +866,25 @@  DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa
         return DNN_ERROR;
     }
 
-    if (extract_inference_from_task(ov_model->model->func_type, task, ov_model->inference_queue) != DNN_SUCCESS) {
+    if (extract_inference_from_task(model->func_type, task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) {
         av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n");
         return DNN_ERROR;
     }
 
-    if (ff_queue_size(ov_model->inference_queue) < ctx->options.batch_size) {
-        // not enough inference items queued for a batch
-        return DNN_SUCCESS;
-    }
+    while (ff_queue_size(ov_model->inference_queue) >= ctx->options.batch_size) {
+        request = ff_safe_queue_pop_front(ov_model->request_queue);
+        if (!request) {
+            av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+            return DNN_ERROR;
+        }
 
-    request = ff_safe_queue_pop_front(ov_model->request_queue);
-    if (!request) {
-        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
-        return DNN_ERROR;
+        ret = execute_model_ov(request, ov_model->inference_queue);
+        if (ret != DNN_SUCCESS) {
+            return ret;
+        }
     }
 
-    return execute_model_ov(request, ov_model->inference_queue);
+    return DNN_SUCCESS;
 }
 
 DNNAsyncStatusType ff_dnn_get_async_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out)
diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c
index e104cc5064..5f60d68078 100644
--- a/libavfilter/dnn/dnn_io_proc.c
+++ b/libavfilter/dnn/dnn_io_proc.c
@@ -22,6 +22,7 @@ 
 #include "libavutil/imgutils.h"
 #include "libswscale/swscale.h"
 #include "libavutil/avassert.h"
+#include "libavutil/detection_bbox.h"
 
 DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
 {
@@ -175,6 +176,65 @@  static enum AVPixelFormat get_pixel_format(DNNData *data)
     return AV_PIX_FMT_BGR24;
 }
 
+DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx)
+{
+    const AVPixFmtDescriptor *desc;
+    int offsetx[4], offsety[4];
+    uint8_t *bbox_data[4];
+    struct SwsContext *sws_ctx;
+    int linesizes[4];
+    enum AVPixelFormat fmt;
+    int left, top, width, height;
+    const AVDetectionBBoxHeader *header;
+    const AVDetectionBBox *bbox;
+    AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES);
+    av_assert0(sd);
+
+    header = (const AVDetectionBBoxHeader *)sd->data;
+    bbox = av_get_detection_bbox(header, bbox_index);
+
+    left = bbox->x;
+    width = bbox->w;
+    top = bbox->y;
+    height = bbox->h;
+
+    fmt = get_pixel_format(input);
+    sws_ctx = sws_getContext(width, height, frame->format,
+                             input->width, input->height, fmt,
+                             SWS_FAST_BILINEAR, NULL, NULL, NULL);
+    if (!sws_ctx) {
+        av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion "
+               "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+               av_get_pix_fmt_name(frame->format), width, height,
+               av_get_pix_fmt_name(fmt), input->width, input->height);
+        return DNN_ERROR;
+    }
+
+    if (av_image_fill_linesizes(linesizes, fmt, input->width) < 0) {
+        av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes");
+        sws_freeContext(sws_ctx);
+        return DNN_ERROR;
+    }
+
+    desc = av_pix_fmt_desc_get(frame->format);
+    offsetx[1] = offsetx[2] = AV_CEIL_RSHIFT(left, desc->log2_chroma_w);
+    offsetx[0] = offsetx[3] = left;
+
+    offsety[1] = offsety[2] = AV_CEIL_RSHIFT(top, desc->log2_chroma_h);
+    offsety[0] = offsety[3] = top;
+
+    for (int k = 0; frame->data[k]; k++)
+        bbox_data[k] = frame->data[k] + offsety[k] * frame->linesize[k] + offsetx[k];
+
+    sws_scale(sws_ctx, (const uint8_t *const *)&bbox_data, frame->linesize,
+                       0, height,
+                       (uint8_t *const *)(&input->data), linesizes);
+
+    sws_freeContext(sws_ctx);
+
+    return DNN_SUCCESS;
+}
+
 static DNNReturnType proc_from_frame_to_dnn_analytics(AVFrame *frame, DNNData *input, void *log_ctx)
 {
     struct SwsContext *sws_ctx;
diff --git a/libavfilter/dnn/dnn_io_proc.h b/libavfilter/dnn/dnn_io_proc.h
index 91ad3cb261..16dcdd6d1a 100644
--- a/libavfilter/dnn/dnn_io_proc.h
+++ b/libavfilter/dnn/dnn_io_proc.h
@@ -32,5 +32,6 @@ 
 
 DNNReturnType ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, DNNFunctionType func_type, void *log_ctx);
 DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx);
+DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx);
 
 #endif
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index c085884eb4..52c7a5392a 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -77,6 +77,12 @@  int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
     return 0;
 }
 
+int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc)
+{
+    ctx->model->classify_post_proc = post_proc;
+    return 0;
+}
+
 DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
 {
     return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);
@@ -112,6 +118,21 @@  DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
     return (ctx->dnn_module->execute_model_async)(ctx->model, &exec_params);
 }
 
+DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target)
+{
+    DNNExecClassificationParams class_params = {
+        {
+            .input_name     = ctx->model_inputname,
+            .output_names   = (const char **)&ctx->model_outputname,
+            .nb_output      = 1,
+            .in_frame       = in_frame,
+            .out_frame      = out_frame,
+        },
+        .target = target,
+    };
+    return (ctx->dnn_module->execute_model_async)(ctx->model, &class_params.base);
+}
+
 DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame)
 {
     return (ctx->dnn_module->get_async_result)(ctx->model, in_frame, out_frame);
diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h
index 8deb18b39a..e7736d2bac 100644
--- a/libavfilter/dnn_filter_common.h
+++ b/libavfilter/dnn_filter_common.h
@@ -50,10 +50,12 @@  typedef struct DnnContext {
 int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
 int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
 int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
+int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc);
 DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
 DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
 DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
 DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
+DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target);
 DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame);
 DNNReturnType ff_dnn_flush(DnnContext *ctx);
 void ff_dnn_uninit(DnnContext *ctx);
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 941670675d..799244ee14 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -52,7 +52,7 @@  typedef enum {
     DFT_NONE,
     DFT_PROCESS_FRAME,      // process the whole frame
     DFT_ANALYTICS_DETECT,   // detect from the whole frame
-    // we can add more such as detect_from_crop, classify_from_bbox, etc.
+    DFT_ANALYTICS_CLASSIFY, // classify for each bounding box
 }DNNFunctionType;
 
 typedef struct DNNData{
@@ -71,8 +71,14 @@  typedef struct DNNExecBaseParams {
     AVFrame *out_frame;
 } DNNExecBaseParams;
 
+typedef struct DNNExecClassificationParams {
+    DNNExecBaseParams base;
+    const char *target;
+} DNNExecClassificationParams;
+
 typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
 typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
+typedef int (*ClassifyPostProc)(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx);
 
 typedef struct DNNModel{
     // Stores model that can be different for different backends.
@@ -97,6 +103,8 @@  typedef struct DNNModel{
     FramePrePostProc frame_post_proc;
     // set the post process to interpret detect result from DNNData
     DetectPostProc detect_post_proc;
+    // set the post process to interpret classify result from DNNData
+    ClassifyPostProc classify_post_proc;
 } DNNModel;
 
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.