diff mbox

[FFmpeg-devel,3/4] libavfilter/dnn/dnn_backend_native: find the input operand according to input name

Message ID 1568951763-6118-1-git-send-email-yejun.guo@intel.com
State Accepted
Commit 75ca94f3cff8e86036010f496b975cf9c5a7ffb1
Headers show

Commit Message

Guo, Yejun Sept. 20, 2019, 3:56 a.m. UTC
Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
---
 libavfilter/dnn/dnn_backend_native.c | 39 +++++++++++++++++++++---------------
 1 file changed, 23 insertions(+), 16 deletions(-)

Comments

Pedro Arthur Sept. 20, 2019, 6:18 p.m. UTC | #1
Em sex, 20 de set de 2019 às 01:01, Guo, Yejun <yejun.guo@intel.com> escreveu:
>
> Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
> ---
>  libavfilter/dnn/dnn_backend_native.c | 39 +++++++++++++++++++++---------------
>  1 file changed, 23 insertions(+), 16 deletions(-)
>
> diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
> index 22a9a33..1b0aea2 100644
> --- a/libavfilter/dnn/dnn_backend_native.c
> +++ b/libavfilter/dnn/dnn_backend_native.c
> @@ -33,30 +33,37 @@
>  static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
>  {
>      ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
> +    DnnOperand *oprd = NULL;
>
>      if (network->layers_num <= 0 || network->operands_num <= 0)
>          return DNN_ERROR;
>
>      av_assert0(input->dt == DNN_FLOAT);
> +    for (int i = 0; i < network->operands_num; ++i) {
> +        oprd = &network->operands[i];
> +        if (strcmp(oprd->name, input_name) == 0) {
> +            if (oprd->type != DOT_INPUT)
> +                return DNN_ERROR;
> +            break;
> +        }
> +        oprd = NULL;
> +    }
>
> -    /**
> -     * as the first step, suppose network->operands[0] is the input operand.
> -     */
> -    network->operands[0].dims[0] = 1;
> -    network->operands[0].dims[1] = input->height;
> -    network->operands[0].dims[2] = input->width;
> -    network->operands[0].dims[3] = input->channels;
> -    network->operands[0].type = DOT_INPUT;
> -    network->operands[0].data_type = DNN_FLOAT;
> -    network->operands[0].isNHWC = 1;
> -
> -    av_freep(&network->operands[0].data);
> -    network->operands[0].length = calculate_operand_data_length(&network->operands[0]);
> -    network->operands[0].data = av_malloc(network->operands[0].length);
> -    if (!network->operands[0].data)
> +    if (!oprd)
> +        return DNN_ERROR;
> +
> +    oprd->dims[0] = 1;
> +    oprd->dims[1] = input->height;
> +    oprd->dims[2] = input->width;
> +    oprd->dims[3] = input->channels;
> +
> +    av_freep(&oprd->data);
> +    oprd->length = calculate_operand_data_length(oprd);
> +    oprd->data = av_malloc(oprd->length);
> +    if (!oprd->data)
>          return DNN_ERROR;
>
> -    input->data = network->operands[0].data;
> +    input->data = oprd->data;
>      return DNN_SUCCESS;
>  }
>
> --
> 2.7.4
>
LGTM, pushed.

> _______________________________________________
> ffmpeg-devel mailing list
> ffmpeg-devel@ffmpeg.org
> https://ffmpeg.org/mailman/listinfo/ffmpeg-devel
>
> To unsubscribe, visit link above, or email
> ffmpeg-devel-request@ffmpeg.org with subject "unsubscribe".
diff mbox

Patch

diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
index 22a9a33..1b0aea2 100644
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -33,30 +33,37 @@ 
 static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output)
 {
     ConvolutionalNetwork *network = (ConvolutionalNetwork *)model;
+    DnnOperand *oprd = NULL;
 
     if (network->layers_num <= 0 || network->operands_num <= 0)
         return DNN_ERROR;
 
     av_assert0(input->dt == DNN_FLOAT);
+    for (int i = 0; i < network->operands_num; ++i) {
+        oprd = &network->operands[i];
+        if (strcmp(oprd->name, input_name) == 0) {
+            if (oprd->type != DOT_INPUT)
+                return DNN_ERROR;
+            break;
+        }
+        oprd = NULL;
+    }
 
-    /**
-     * as the first step, suppose network->operands[0] is the input operand.
-     */
-    network->operands[0].dims[0] = 1;
-    network->operands[0].dims[1] = input->height;
-    network->operands[0].dims[2] = input->width;
-    network->operands[0].dims[3] = input->channels;
-    network->operands[0].type = DOT_INPUT;
-    network->operands[0].data_type = DNN_FLOAT;
-    network->operands[0].isNHWC = 1;
-
-    av_freep(&network->operands[0].data);
-    network->operands[0].length = calculate_operand_data_length(&network->operands[0]);
-    network->operands[0].data = av_malloc(network->operands[0].length);
-    if (!network->operands[0].data)
+    if (!oprd)
+        return DNN_ERROR;
+
+    oprd->dims[0] = 1;
+    oprd->dims[1] = input->height;
+    oprd->dims[2] = input->width;
+    oprd->dims[3] = input->channels;
+
+    av_freep(&oprd->data);
+    oprd->length = calculate_operand_data_length(oprd);
+    oprd->data = av_malloc(oprd->length);
+    if (!oprd->data)
         return DNN_ERROR;
 
-    input->data = network->operands[0].data;
+    input->data = oprd->data;
     return DNN_SUCCESS;
 }