diff mbox series

[FFmpeg-devel,V2,04/10] dnn: add function type for model

Message ID 20210210093432.9135-4-yejun.guo@intel.com
State Accepted
Headers show
Series [FFmpeg-devel,V2,01/10] dnn_backend_openvino.c: fix mismatch between ffmpeg(NHWC) and openvino(NCHW)
Related show

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Guo Yejun Feb. 10, 2021, 9:34 a.m. UTC
So the backend knows the usage of model is for frame processing,
detect, classify, etc. Each function type has different behavior
in backend when handling the input/output data of the model.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
---
 libavfilter/dnn/dnn_backend_native.c   |  3 ++-
 libavfilter/dnn/dnn_backend_native.h   |  2 +-
 libavfilter/dnn/dnn_backend_openvino.c |  3 ++-
 libavfilter/dnn/dnn_backend_openvino.h |  2 +-
 libavfilter/dnn/dnn_backend_tf.c       |  5 +++--
 libavfilter/dnn/dnn_backend_tf.h       |  2 +-
 libavfilter/dnn_filter_common.c        |  4 ++--
 libavfilter/dnn_filter_common.h        |  2 +-
 libavfilter/dnn_interface.h            | 11 ++++++++++-
 libavfilter/vf_derain.c                |  2 +-
 libavfilter/vf_dnn_processing.c        |  2 +-
 libavfilter/vf_sr.c                    |  2 +-
 12 files changed, 26 insertions(+), 14 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
index 87f3568cc2..be6451367a 100644
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -112,7 +112,7 @@  static DNNReturnType get_output_native(void *model, const char *input_name, int
 // layers_num,layer_type,layer_parameterss,layer_type,layer_parameters...
 // For CONV layer: activation_function, input_num, output_num, kernel_size, kernel, biases
 // For DEPTH_TO_SPACE layer: block_size
-DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, AVFilterContext *filter_ctx)
+DNNModel *ff_dnn_load_model_native(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
 {
     DNNModel *model = NULL;
     char header_expected[] = "FFMPEGDNNNATIVE";
@@ -256,6 +256,7 @@  DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
     model->get_input = &get_input_native;
     model->get_output = &get_output_native;
     model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
 
     return model;
 
diff --git a/libavfilter/dnn/dnn_backend_native.h b/libavfilter/dnn/dnn_backend_native.h
index 5c8ce82b35..d313c48f3a 100644
--- a/libavfilter/dnn/dnn_backend_native.h
+++ b/libavfilter/dnn/dnn_backend_native.h
@@ -128,7 +128,7 @@  typedef struct NativeModel{
     int32_t operands_num;
 } NativeModel;
 
-DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+DNNModel *ff_dnn_load_model_native(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
 
 DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                           const char **output_names, uint32_t nb_output, AVFrame *out_frame);
diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index ed41b721fc..7c1abb3eeb 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -524,7 +524,7 @@  static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu
     return ret;
 }
 
-DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, AVFilterContext *filter_ctx)
+DNNModel *ff_dnn_load_model_ov(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
 {
     DNNModel *model = NULL;
     OVModel *ov_model = NULL;
@@ -572,6 +572,7 @@  DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
     model->get_output = &get_output_ov;
     model->options = options;
     model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
 
     return model;
 
diff --git a/libavfilter/dnn/dnn_backend_openvino.h b/libavfilter/dnn/dnn_backend_openvino.h
index 23b819440e..a484a7be32 100644
--- a/libavfilter/dnn/dnn_backend_openvino.h
+++ b/libavfilter/dnn/dnn_backend_openvino.h
@@ -29,7 +29,7 @@ 
 
 #include "../dnn_interface.h"
 
-DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+DNNModel *ff_dnn_load_model_ov(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
 
 DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                       const char **output_names, uint32_t nb_output, AVFrame *out_frame);
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 71a2a308b5..e7e5f221f3 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -580,7 +580,7 @@  static DNNReturnType load_native_model(TFModel *tf_model, const char *model_file
     DNNModel *model = NULL;
     NativeModel *native_model;
 
-    model = ff_dnn_load_model_native(model_filename, NULL, NULL);
+    model = ff_dnn_load_model_native(model_filename, DFT_PROCESS_FRAME, NULL, NULL);
     if (!model){
         av_log(ctx, AV_LOG_ERROR, "Failed to load native model\n");
         return DNN_ERROR;
@@ -664,7 +664,7 @@  static DNNReturnType load_native_model(TFModel *tf_model, const char *model_file
     return DNN_SUCCESS;
 }
 
-DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, AVFilterContext *filter_ctx)
+DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx)
 {
     DNNModel *model = NULL;
     TFModel *tf_model = NULL;
@@ -705,6 +705,7 @@  DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     model->get_output = &get_output_tf;
     model->options = options;
     model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
 
     return model;
 }
diff --git a/libavfilter/dnn/dnn_backend_tf.h b/libavfilter/dnn/dnn_backend_tf.h
index cac8936729..8cec04748e 100644
--- a/libavfilter/dnn/dnn_backend_tf.h
+++ b/libavfilter/dnn/dnn_backend_tf.h
@@ -29,7 +29,7 @@ 
 
 #include "../dnn_interface.h"
 
-DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+DNNModel *ff_dnn_load_model_tf(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
 
 DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                       const char **output_names, uint32_t nb_output, AVFrame *out_frame);
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index 5d0d7d3b90..413adba406 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -18,7 +18,7 @@ 
 
 #include "dnn_filter_common.h"
 
-int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx)
+int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
     if (!ctx->model_filename) {
         av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
@@ -43,7 +43,7 @@  int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx)
         return AVERROR(EINVAL);
     }
 
-    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, filter_ctx);
+    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, func_type, ctx->backend_options, filter_ctx);
     if (!ctx->model) {
         av_log(filter_ctx, AV_LOG_ERROR, "could not load DNN model\n");
         return AVERROR(EINVAL);
diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h
index ab49a992ed..79c4d3efe3 100644
--- a/libavfilter/dnn_filter_common.h
+++ b/libavfilter/dnn_filter_common.h
@@ -47,7 +47,7 @@  typedef struct DnnContext {
     { "async",              "use DNN async inference",    OFFSET(async),            AV_OPT_TYPE_BOOL,      { .i64 = 1},     0, 1, FLAGS},
 
 
-int ff_dnn_init(DnnContext *ctx, AVFilterContext *filter_ctx);
+int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
 DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
 DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
 DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index ff338ea084..2fb9b15676 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -43,6 +43,13 @@  typedef enum {
     DAST_SUCCESS            // got a result frame successfully
 } DNNAsyncStatusType;
 
+typedef enum {
+    DFT_NONE,
+    DFT_PROCESS_FRAME,      // process the whole frame
+    DFT_ANALYTICS_DETECT,   // detect from the whole frame
+    // we can add more such as detect_from_crop, classify_from_bbox, etc.
+}DNNFunctionType;
+
 typedef struct DNNData{
     void *data;
     DNNDataType dt;
@@ -56,6 +63,8 @@  typedef struct DNNModel{
     const char *options;
     // Stores FilterContext used for the interaction between AVFrame and DNNData
     AVFilterContext *filter_ctx;
+    // Stores function type of the model
+    DNNFunctionType func_type;
     // Gets model input information
     // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
     DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
@@ -73,7 +82,7 @@  typedef struct DNNModel{
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
 typedef struct DNNModule{
     // Loads model and parameters from given file. Returns NULL if it is not possible.
-    DNNModel *(*load_model)(const char *model_filename, const char *options, AVFilterContext *filter_ctx);
+    DNNModel *(*load_model)(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
     // Executes model with specified input and output. Returns DNN_ERROR otherwise.
     DNNReturnType (*execute_model)(const DNNModel *model, const char *input_name, AVFrame *in_frame,
                                    const char **output_names, uint32_t nb_output, AVFrame *out_frame);
diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c
index ec9853d957..7814fc1e03 100644
--- a/libavfilter/vf_derain.c
+++ b/libavfilter/vf_derain.c
@@ -100,7 +100,7 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 static av_cold int init(AVFilterContext *ctx)
 {
     DRContext *dr_context = ctx->priv;
-    return ff_dnn_init(&dr_context->dnnctx, ctx);
+    return ff_dnn_init(&dr_context->dnnctx, DFT_PROCESS_FRAME, ctx);
 }
 
 static av_cold void uninit(AVFilterContext *ctx)
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index 08ebf122c9..88e95e8ae3 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -62,7 +62,7 @@  AVFILTER_DEFINE_CLASS(dnn_processing);
 static av_cold int init(AVFilterContext *context)
 {
     DnnProcessingContext *ctx = context->priv;
-    return ff_dnn_init(&ctx->dnnctx, context);
+    return ff_dnn_init(&ctx->dnnctx, DFT_PROCESS_FRAME, context);
 }
 
 static int query_formats(AVFilterContext *context)
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 20334a84c4..45f941acdb 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -63,7 +63,7 @@  AVFILTER_DEFINE_CLASS(sr);
 static av_cold int init(AVFilterContext *context)
 {
     SRContext *sr_context = context->priv;
-    return ff_dnn_init(&sr_context->dnnctx, context);
+    return ff_dnn_init(&sr_context->dnnctx, DFT_PROCESS_FRAME, context);
 }
 
 static int query_formats(AVFilterContext *context)