diff mbox series

[FFmpeg-devel,V2,2/4] lavfi/dnn_backend_tensorflow: add multiple outputs support

Message ID 20210506084610.23487-2-ting.fu@intel.com
State Accepted
Headers show
Series [FFmpeg-devel,V2,1/4] dnn: add DCO_RGB color order to enum DNNColorOrder
Related show

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Ting Fu May 6, 2021, 8:46 a.m. UTC
Signed-off-by: Ting Fu <ting.fu@intel.com>
---
 libavfilter/dnn/dnn_backend_tf.c | 49 ++++++++++++++---------------
 libavfilter/dnn_filter_common.c  | 53 ++++++++++++++++++++++++++------
 libavfilter/dnn_filter_common.h  |  6 ++--
 libavfilter/vf_derain.c          |  2 +-
 libavfilter/vf_sr.c              |  2 +-
 5 files changed, 75 insertions(+), 37 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 45da29ae70..b6b1812cd9 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -155,7 +155,7 @@  static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
     TF_DeleteStatus(status);
 
     // currently only NHWC is supported
-    av_assert0(dims[0] == 1);
+    av_assert0(dims[0] == 1 || dims[0] == -1);
     input->height = dims[1];
     input->width = dims[2];
     input->channels = dims[3];
@@ -707,7 +707,7 @@  static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
     TF_Output *tf_outputs;
     TFModel *tf_model = model->model;
     TFContext *ctx = &tf_model->ctx;
-    DNNData input, output;
+    DNNData input, *outputs;
     TF_Tensor **output_tensors;
     TF_Output tf_input;
     TF_Tensor *input_tensor;
@@ -738,14 +738,6 @@  static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
         }
     }
 
-    if (nb_output != 1) {
-        // currently, the filter does not need multiple outputs,
-        // so we just pending the support until we really need it.
-        TF_DeleteTensor(input_tensor);
-        avpriv_report_missing_feature(ctx, "multiple outputs");
-        return DNN_ERROR;
-    }
-
     tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
     if (tf_outputs == NULL) {
         TF_DeleteTensor(input_tensor);
@@ -785,23 +777,31 @@  static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
         return DNN_ERROR;
     }
 
+    outputs = av_malloc_array(nb_output, sizeof(*outputs));
+    if (!outputs) {
+        TF_DeleteTensor(input_tensor);
+        av_freep(&tf_outputs);
+        av_freep(&output_tensors);
+        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for *outputs\n"); \
+        return DNN_ERROR;
+    }
+
     for (uint32_t i = 0; i < nb_output; ++i) {
-        output.height = TF_Dim(output_tensors[i], 1);
-        output.width = TF_Dim(output_tensors[i], 2);
-        output.channels = TF_Dim(output_tensors[i], 3);
-        output.data = TF_TensorData(output_tensors[i]);
-        output.dt = TF_TensorType(output_tensors[i]);
-
-        if (do_ioproc) {
-            if (tf_model->model->frame_post_proc != NULL) {
-                tf_model->model->frame_post_proc(out_frame, &output, tf_model->model->filter_ctx);
-            } else {
-                ff_proc_from_dnn_to_frame(out_frame, &output, ctx);
-            }
+        outputs[i].height = TF_Dim(output_tensors[i], 1);
+        outputs[i].width = TF_Dim(output_tensors[i], 2);
+        outputs[i].channels = TF_Dim(output_tensors[i], 3);
+        outputs[i].data = TF_TensorData(output_tensors[i]);
+        outputs[i].dt = TF_TensorType(output_tensors[i]);
+    }
+    if (do_ioproc) {
+        if (tf_model->model->frame_post_proc != NULL) {
+            tf_model->model->frame_post_proc(out_frame, outputs, tf_model->model->filter_ctx);
         } else {
-            out_frame->width = output.width;
-            out_frame->height = output.height;
+            ff_proc_from_dnn_to_frame(out_frame, outputs, ctx);
         }
+    } else {
+        out_frame->width = outputs[0].width;
+        out_frame->height = outputs[0].height;
     }
 
     for (uint32_t i = 0; i < nb_output; ++i) {
@@ -812,6 +812,7 @@  static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
     TF_DeleteTensor(input_tensor);
     av_freep(&output_tensors);
     av_freep(&tf_outputs);
+    av_freep(&outputs);
     return DNN_SUCCESS;
 }
 
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index 52c7a5392a..0ed0ac2e30 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -17,6 +17,39 @@ 
  */
 
 #include "dnn_filter_common.h"
+#include "libavutil/avstring.h"
+
+#define MAX_SUPPORTED_OUTPUTS_NB 4
+
+static char **separate_output_names(const char *expr, const char *val_sep, int *separated_nb)
+{
+    char *val, **parsed_vals = NULL;
+    int val_num = 0;
+    if (!expr || !val_sep || !separated_nb) {
+        return NULL;
+    }
+
+    parsed_vals = av_mallocz_array(MAX_SUPPORTED_OUTPUTS_NB, sizeof(*parsed_vals));
+    if (!parsed_vals) {
+        return NULL;
+    }
+
+    do {
+        val = av_get_token(&expr, val_sep);
+        if(val) {
+            parsed_vals[val_num] = val;
+            val_num++;
+        }
+        if (*expr) {
+            expr++;
+        }
+    } while(*expr);
+
+    parsed_vals[val_num] = NULL;
+    *separated_nb = val_num;
+
+    return parsed_vals;
+}
 
 int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
@@ -28,8 +61,10 @@  int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *fil
         av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
         return AVERROR(EINVAL);
     }
-    if (!ctx->model_outputname) {
-        av_log(filter_ctx, AV_LOG_ERROR, "output name of the model network is not specified\n");
+
+    ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
+    if (!ctx->model_outputnames) {
+        av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
         return AVERROR(EINVAL);
     }
 
@@ -91,15 +126,15 @@  DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
 DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
 {
     return ctx->model->get_output(ctx->model->model, ctx->model_inputname, input_width, input_height,
-                                    ctx->model_outputname, output_width, output_height);
+                                    (const char *)ctx->model_outputnames[0], output_width, output_height);
 }
 
 DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame)
 {
     DNNExecBaseParams exec_params = {
         .input_name     = ctx->model_inputname,
-        .output_names   = (const char **)&ctx->model_outputname,
-        .nb_output      = 1,
+        .output_names   = (const char **)ctx->model_outputnames,
+        .nb_output      = ctx->nb_outputs,
         .in_frame       = in_frame,
         .out_frame      = out_frame,
     };
@@ -110,8 +145,8 @@  DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF
 {
     DNNExecBaseParams exec_params = {
         .input_name     = ctx->model_inputname,
-        .output_names   = (const char **)&ctx->model_outputname,
-        .nb_output      = 1,
+        .output_names   = (const char **)ctx->model_outputnames,
+        .nb_output      = ctx->nb_outputs,
         .in_frame       = in_frame,
         .out_frame      = out_frame,
     };
@@ -123,8 +158,8 @@  DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_f
     DNNExecClassificationParams class_params = {
         {
             .input_name     = ctx->model_inputname,
-            .output_names   = (const char **)&ctx->model_outputname,
-            .nb_output      = 1,
+            .output_names   = (const char **)ctx->model_outputnames,
+            .nb_output      = ctx->nb_outputs,
             .in_frame       = in_frame,
             .out_frame      = out_frame,
         },
diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h
index e7736d2bac..e3a396d74a 100644
--- a/libavfilter/dnn_filter_common.h
+++ b/libavfilter/dnn_filter_common.h
@@ -30,10 +30,12 @@  typedef struct DnnContext {
     char *model_filename;
     DNNBackendType backend_type;
     char *model_inputname;
-    char *model_outputname;
+    char *model_outputnames_string;
+    uint32_t nb_outputs;
     char *backend_options;
     int async;
 
+    char **model_outputnames;
     DNNModule *dnn_module;
     DNNModel *model;
 } DnnContext;
@@ -41,7 +43,7 @@  typedef struct DnnContext {
 #define DNN_COMMON_OPTIONS \
     { "model",              "path to model file",         OFFSET(model_filename),   AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },\
     { "input",              "input name of the model",    OFFSET(model_inputname),  AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },\
-    { "output",             "output name of the model",   OFFSET(model_outputname), AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },\
+    { "output",             "output name of the model",   OFFSET(model_outputnames_string), AV_OPT_TYPE_STRING, { .str = NULL }, 0, 0, FLAGS },\
     { "backend_configs",    "backend configs",            OFFSET(backend_options),  AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },\
     { "options",            "backend configs",            OFFSET(backend_options),  AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },\
     { "async",              "use DNN async inference",    OFFSET(async),            AV_OPT_TYPE_BOOL,      { .i64 = 1},     0, 1, FLAGS},
diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c
index 76c4ef414f..5037f3a5f7 100644
--- a/libavfilter/vf_derain.c
+++ b/libavfilter/vf_derain.c
@@ -50,7 +50,7 @@  static const AVOption derain_options[] = {
 #endif
     { "model",       "path to model file",          OFFSET(dnnctx.model_filename),   AV_OPT_TYPE_STRING,    { .str = NULL }, 0, 0, FLAGS },
     { "input",       "input name of the model",     OFFSET(dnnctx.model_inputname),  AV_OPT_TYPE_STRING,    { .str = "x" },  0, 0, FLAGS },
-    { "output",      "output name of the model",    OFFSET(dnnctx.model_outputname), AV_OPT_TYPE_STRING,    { .str = "y" },  0, 0, FLAGS },
+    { "output",      "output name of the model",    OFFSET(dnnctx.model_outputnames_string), AV_OPT_TYPE_STRING,    { .str = "y" },  0, 0, FLAGS },
     { NULL }
 };
 
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 4360439ca6..f930b38748 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -54,7 +54,7 @@  static const AVOption sr_options[] = {
     { "scale_factor", "scale factor for SRCNN model", OFFSET(scale_factor), AV_OPT_TYPE_INT, { .i64 = 2 }, 2, 4, FLAGS },
     { "model", "path to model file specifying network architecture and its parameters", OFFSET(dnnctx.model_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
     { "input",       "input name of the model",     OFFSET(dnnctx.model_inputname),  AV_OPT_TYPE_STRING,    { .str = "x" },  0, 0, FLAGS },
-    { "output",      "output name of the model",    OFFSET(dnnctx.model_outputname), AV_OPT_TYPE_STRING,    { .str = "y" },  0, 0, FLAGS },
+    { "output",      "output name of the model",    OFFSET(dnnctx.model_outputnames_string), AV_OPT_TYPE_STRING,    { .str = "y" },  0, 0, FLAGS },
     { NULL }
 };