diff mbox series

[FFmpeg-devel,V7,3/6] lavfi/dnn: add post process for detection

Message ID 20210407141723.11527-3-yejun.guo@intel.com
State Accepted
Headers show
Series [FFmpeg-devel,V7,1/6] lavfi/dnn_backend_openvino.c: only allow DFT_PROCESS_FRAME to get output dim | expand

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Guo, Yejun April 7, 2021, 2:17 p.m. UTC
---
 libavfilter/dnn/dnn_backend_openvino.c | 30 ++++++++++++++++++++------
 libavfilter/dnn_filter_common.c        |  6 ++++++
 libavfilter/dnn_filter_common.h        |  1 +
 libavfilter/dnn_interface.h            |  3 +++
 4 files changed, 33 insertions(+), 7 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index 3bea2d526a..0757727a9c 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -236,16 +236,32 @@  static void infer_completion_callback(void *args)
     av_assert0(request->task_count >= 1);
     for (int i = 0; i < request->task_count; ++i) {
         task = request->tasks[i];
-        if (task->do_ioproc) {
-            if (task->ov_model->model->frame_post_proc != NULL) {
-                task->ov_model->model->frame_post_proc(task->out_frame, &output, task->ov_model->model->filter_ctx);
+
+        switch (task->ov_model->model->func_type) {
+        case DFT_PROCESS_FRAME:
+            if (task->do_ioproc) {
+                if (task->ov_model->model->frame_post_proc != NULL) {
+                    task->ov_model->model->frame_post_proc(task->out_frame, &output, task->ov_model->model->filter_ctx);
+                } else {
+                    ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
+                }
             } else {
-                ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
+                task->out_frame->width = output.width;
+                task->out_frame->height = output.height;
             }
-        } else {
-            task->out_frame->width = output.width;
-            task->out_frame->height = output.height;
+            break;
+        case DFT_ANALYTICS_DETECT:
+            if (!task->ov_model->model->detect_post_proc) {
+                av_log(ctx, AV_LOG_ERROR, "detect filter needs to provide post proc\n");
+                return;
+            }
+            task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
+            break;
+        default:
+            av_assert0(!"should not reach here");
+            break;
         }
+
         task->done = 1;
         output.data = (uint8_t *)output.data
                       + output.width * output.height * output.channels * get_datatype_size(output.dt);
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index dc5966332a..1b922455a3 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -71,6 +71,12 @@  int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePo
     return 0;
 }
 
+int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
+{
+    ctx->model->detect_post_proc = post_proc;
+    return 0;
+}
+
 DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
 {
     return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);
diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h
index c611d594dc..8deb18b39a 100644
--- a/libavfilter/dnn_filter_common.h
+++ b/libavfilter/dnn_filter_common.h
@@ -49,6 +49,7 @@  typedef struct DnnContext {
 
 int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
 int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
+int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
 DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
 DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
 DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 3c7846f1a5..ae5a488341 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -64,6 +64,7 @@  typedef struct DNNData{
 } DNNData;
 
 typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
+typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
 
 typedef struct DNNModel{
     // Stores model that can be different for different backends.
@@ -86,6 +87,8 @@  typedef struct DNNModel{
     // set the post process to transfer data from DNNData to AVFrame
     // the default implementation within DNN is used if it is not provided by the filter
     FramePrePostProc frame_post_proc;
+    // set the post process to interpret detect result from DNNData
+    DetectPostProc detect_post_proc;
 } DNNModel;
 
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.