diff mbox

[FFmpeg-devel,V2,3/7] libavfilter/dnn: remove limit for the name of DNN model input/output

Message ID 1556158448-15121-1-git-send-email-yejun.guo@intel.com
State New
Headers show

Commit Message

Guo, Yejun April 25, 2019, 2:14 a.m. UTC
remove the requirment that the name of DNN model input/output
should be "x"/"y",

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
---
 libavfilter/dnn_backend_native.c |  2 +-
 libavfilter/dnn_backend_tf.c     | 10 +++++-----
 libavfilter/dnn_interface.h      |  2 +-
 libavfilter/vf_sr.c              |  4 ++--
 4 files changed, 9 insertions(+), 9 deletions(-)

Comments

Pedro Arthur April 29, 2019, 5:36 p.m. UTC | #1
Em qua, 24 de abr de 2019 às 23:14, Guo, Yejun <yejun.guo@intel.com> escreveu:
>
> remove the requirment that the name of DNN model input/output
> should be "x"/"y",
>
> Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
> ---
>  libavfilter/dnn_backend_native.c |  2 +-
>  libavfilter/dnn_backend_tf.c     | 10 +++++-----
>  libavfilter/dnn_interface.h      |  2 +-
>  libavfilter/vf_sr.c              |  4 ++--
>  4 files changed, 9 insertions(+), 9 deletions(-)
>
> diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c
> index 70d857f..fe43116 100644
> --- a/libavfilter/dnn_backend_native.c
> +++ b/libavfilter/dnn_backend_native.c
> @@ -25,7 +25,7 @@
>
>  #include "dnn_backend_native.h"
>
> -static DNNReturnType set_input_output_native(void *model, DNNData *input, DNNData *output)
> +static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
>  {
>      ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
>      InputParams *input_params;
> diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
> index 9e0c127..a838907 100644
> --- a/libavfilter/dnn_backend_tf.c
> +++ b/libavfilter/dnn_backend_tf.c
> @@ -76,7 +76,7 @@ static TF_Buffer *read_graph(const char *model_filename)
>      return graph_buf;
>  }
>
> -static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output)
> +static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
>  {
>      TFModel *tf_model = (TFModel *)model;
>      int64_t input_dims[] = {1, input->height, input->width, input->channels};
> @@ -84,8 +84,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
>      const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
>      TF_Tensor *output_tensor;
>
> -    // Input operation should be named 'x'
> -    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
> +    // Input operation
> +    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
>      if (!tf_model->input.oper){
>          return DNN_ERROR;
>      }
> @@ -100,8 +100,8 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
>      }
>      input->data = (float *)TF_TensorData(tf_model->input_tensor);
>
> -    // Output operation should be named 'y'
> -    tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
> +    // Output operation
> +    tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, output_name);
>      if (!tf_model->output.oper){
>          return DNN_ERROR;
>      }
> diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
> index e367343..0390e39 100644
> --- a/libavfilter/dnn_interface.h
> +++ b/libavfilter/dnn_interface.h
> @@ -40,7 +40,7 @@ typedef struct DNNModel{
>      void *model;
>      // Sets model input and output, while allocating additional memory for intermediate calculations.
>      // Should be called at least once before model execution.
> -    DNNReturnType (*set_input_output)(void *model, DNNData *input, DNNData *output);
> +    DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name);
>  } DNNModel;
>
>  // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
> diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
> index 9bb0fc5..085ac19 100644
> --- a/libavfilter/vf_sr.c
> +++ b/libavfilter/vf_sr.c
> @@ -122,7 +122,7 @@ static int config_props(AVFilterLink *inlink)
>      sr_context->input.height = inlink->h * sr_context->scale_factor;
>      sr_context->input.channels = 1;
>
> -    result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output);
> +    result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
>      if (result != DNN_SUCCESS){
>          av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
>          return AVERROR(EIO);
> @@ -131,7 +131,7 @@ static int config_props(AVFilterLink *inlink)
>      if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
>          sr_context->input.width = inlink->w;
>          sr_context->input.height = inlink->h;
> -        result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output);
> +        result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
>          if (result != DNN_SUCCESS){
>              av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
>              return AVERROR(EIO);
> --
> 2.7.4
>

LGTM.

> _______________________________________________
> ffmpeg-devel mailing list
> ffmpeg-devel@ffmpeg.org
> https://ffmpeg.org/mailman/listinfo/ffmpeg-devel
>
> To unsubscribe, visit link above, or email
> ffmpeg-devel-request@ffmpeg.org with subject "unsubscribe".
diff mbox

Patch

diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c
index 70d857f..fe43116 100644
--- a/libavfilter/dnn_backend_native.c
+++ b/libavfilter/dnn_backend_native.c
@@ -25,7 +25,7 @@ 
 
 #include "dnn_backend_native.h"
 
-static DNNReturnType set_input_output_native(void *model, DNNData *input, DNNData *output)
+static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
 {
     ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
     InputParams *input_params;
diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c
index 9e0c127..a838907 100644
--- a/libavfilter/dnn_backend_tf.c
+++ b/libavfilter/dnn_backend_tf.c
@@ -76,7 +76,7 @@  static TF_Buffer *read_graph(const char *model_filename)
     return graph_buf;
 }
 
-static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *output)
+static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name)
 {
     TFModel *tf_model = (TFModel *)model;
     int64_t input_dims[] = {1, input->height, input->width, input->channels};
@@ -84,8 +84,8 @@  static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
     const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
     TF_Tensor *output_tensor;
 
-    // Input operation should be named 'x'
-    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, "x");
+    // Input operation
+    tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
     if (!tf_model->input.oper){
         return DNN_ERROR;
     }
@@ -100,8 +100,8 @@  static DNNReturnType set_input_output_tf(void *model, DNNData *input, DNNData *o
     }
     input->data = (float *)TF_TensorData(tf_model->input_tensor);
 
-    // Output operation should be named 'y'
-    tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, "y");
+    // Output operation
+    tf_model->output.oper = TF_GraphOperationByName(tf_model->graph, output_name);
     if (!tf_model->output.oper){
         return DNN_ERROR;
     }
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index e367343..0390e39 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -40,7 +40,7 @@  typedef struct DNNModel{
     void *model;
     // Sets model input and output, while allocating additional memory for intermediate calculations.
     // Should be called at least once before model execution.
-    DNNReturnType (*set_input_output)(void *model, DNNData *input, DNNData *output);
+    DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, DNNData *output, const char *output_name);
 } DNNModel;
 
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 9bb0fc5..085ac19 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -122,7 +122,7 @@  static int config_props(AVFilterLink *inlink)
     sr_context->input.height = inlink->h * sr_context->scale_factor;
     sr_context->input.channels = 1;
 
-    result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output);
+    result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
     if (result != DNN_SUCCESS){
         av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
         return AVERROR(EIO);
@@ -131,7 +131,7 @@  static int config_props(AVFilterLink *inlink)
     if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
         sr_context->input.width = inlink->w;
         sr_context->input.height = inlink->h;
-        result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, &sr_context->output);
+        result = (sr_context->model->set_input_output)(sr_context->model->model, &sr_context->input, "x", &sr_context->output, "y");
         if (result != DNN_SUCCESS){
             av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
             return AVERROR(EIO);