From patchwork Thu Apr 25 02:14:42 2019 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: "Guo, Yejun" X-Patchwork-Id: 12902 Return-Path: X-Original-To: patchwork@ffaux-bg.ffmpeg.org Delivered-To: patchwork@ffaux-bg.ffmpeg.org Received: from ffbox0-bg.mplayerhq.hu (ffbox0-bg.ffmpeg.org [79.124.17.100]) by ffaux.localdomain (Postfix) with ESMTP id 22E7844897B for ; Thu, 25 Apr 2019 05:14:50 +0300 (EEST) Received: from [127.0.1.1] (localhost [127.0.0.1]) by ffbox0-bg.mplayerhq.hu (Postfix) with ESMTP id 1029268A893; Thu, 25 Apr 2019 05:14:50 +0300 (EEST) X-Original-To: ffmpeg-devel@ffmpeg.org Delivered-To: ffmpeg-devel@ffmpeg.org Received: from mga02.intel.com (mga02.intel.com [134.134.136.20]) by ffbox0-bg.mplayerhq.hu (Postfix) with ESMTPS id ADD376804C2 for ; Thu, 25 Apr 2019 05:14:47 +0300 (EEST) X-Amp-Result: SKIPPED(no attachment in message) X-Amp-File-Uploaded: False Received: from fmsmga005.fm.intel.com ([10.253.24.32]) by orsmga101.jf.intel.com with ESMTP/TLS/DHE-RSA-AES256-GCM-SHA384; 24 Apr 2019 19:14:45 -0700 X-ExtLoop1: 1 X-IronPort-AV: E=Sophos;i="5.60,392,1549958400"; d="scan'208";a="340548209" Received: from yguo18-skl-u1604.sh.intel.com ([10.239.13.25]) by fmsmga005.fm.intel.com with ESMTP; 24 Apr 2019 19:14:45 -0700 From: "Guo, Yejun" To: ffmpeg-devel@ffmpeg.org Date: Thu, 25 Apr 2019 10:14:42 +0800 Message-Id: <1556158482-15306-1-git-send-email-yejun.guo@intel.com> X-Mailer: git-send-email 2.7.4 Subject: [FFmpeg-devel] [PATCH V2 7/7] libavfilter/dnn: add more data type support for dnn model input X-BeenThere: ffmpeg-devel@ffmpeg.org X-Mailman-Version: 2.1.20 Precedence: list List-Id: FFmpeg development discussions and patches List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Reply-To: FFmpeg development discussions and patches Cc: yejun.guo@intel.com MIME-Version: 1.0 Errors-To: ffmpeg-devel-bounces@ffmpeg.org Sender: "ffmpeg-devel" currently, only float is supported as model input, actually, there are other data types, this patch adds uint8. Signed-off-by: Guo, Yejun --- libavfilter/dnn_backend_native.c | 4 +++- libavfilter/dnn_backend_tf.c | 28 ++++++++++++++++++++++++---- libavfilter/dnn_interface.h | 10 +++++++++- libavfilter/vf_sr.c | 4 +++- 4 files changed, 39 insertions(+), 7 deletions(-) diff --git a/libavfilter/dnn_backend_native.c b/libavfilter/dnn_backend_native.c index 8a83c63..06fbdf3 100644 --- a/libavfilter/dnn_backend_native.c +++ b/libavfilter/dnn_backend_native.c @@ -24,8 +24,9 @@ */ #include "dnn_backend_native.h" +#include "libavutil/avassert.h" -static DNNReturnType set_input_output_native(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) +static DNNReturnType set_input_output_native(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) { ConvolutionalNetwork *network = (ConvolutionalNetwork *)model; InputParams *input_params; @@ -45,6 +46,7 @@ static DNNReturnType set_input_output_native(void *model, DNNData *input, const if (input->data){ av_freep(&input->data); } + av_assert0(input->dt == DNN_FLOAT); network->layers[0].output = input->data = av_malloc(cur_height * cur_width * cur_channels * sizeof(float)); if (!network->layers[0].output){ return DNN_ERROR; diff --git a/libavfilter/dnn_backend_tf.c b/libavfilter/dnn_backend_tf.c index ca6472d..ba959ae 100644 --- a/libavfilter/dnn_backend_tf.c +++ b/libavfilter/dnn_backend_tf.c @@ -79,10 +79,31 @@ static TF_Buffer *read_graph(const char *model_filename) return graph_buf; } -static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output) +static TF_Tensor *allocate_input_tensor(const DNNInputData *input) { - TFModel *tf_model = (TFModel *)model; + TF_DataType dt; + size_t size; int64_t input_dims[] = {1, input->height, input->width, input->channels}; + switch (input->dt) { + case DNN_FLOAT: + dt = TF_FLOAT; + size = sizeof(float); + break; + case DNN_UINT8: + dt = TF_UINT8; + size = sizeof(char); + break; + default: + av_assert0(!"should not reach here"); + } + + return TF_AllocateTensor(dt, input_dims, 4, + input_dims[1] * input_dims[2] * input_dims[3] * size); +} + +static DNNReturnType set_input_output_tf(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output) +{ + TFModel *tf_model = (TFModel *)model; TF_SessionOptions *sess_opts; const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init"); @@ -95,8 +116,7 @@ static DNNReturnType set_input_output_tf(void *model, DNNData *input, const char if (tf_model->input_tensor){ TF_DeleteTensor(tf_model->input_tensor); } - tf_model->input_tensor = TF_AllocateTensor(TF_FLOAT, input_dims, 4, - input_dims[1] * input_dims[2] * input_dims[3] * sizeof(float)); + tf_model->input_tensor = allocate_input_tensor(input); if (!tf_model->input_tensor){ return DNN_ERROR; } diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 73d226e..c24df0e 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -32,6 +32,14 @@ typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType; typedef enum {DNN_NATIVE, DNN_TF} DNNBackendType; +typedef enum {DNN_FLOAT, DNN_UINT8} DNNDataType; + +typedef struct DNNInputData{ + void *data; + DNNDataType dt; + int width, height, channels; +} DNNInputData; + typedef struct DNNData{ float *data; int width, height, channels; @@ -42,7 +50,7 @@ typedef struct DNNModel{ void *model; // Sets model input and output. // Should be called at least once before model execution. - DNNReturnType (*set_input_output)(void *model, DNNData *input, const char *input_name, const char **output_names, uint32_t nb_output); + DNNReturnType (*set_input_output)(void *model, DNNInputData *input, const char *input_name, const char **output_names, uint32_t nb_output); } DNNModel; // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends. diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c index b4d4165..c0d7126 100644 --- a/libavfilter/vf_sr.c +++ b/libavfilter/vf_sr.c @@ -40,7 +40,8 @@ typedef struct SRContext { DNNBackendType backend_type; DNNModule *dnn_module; DNNModel *model; - DNNData input, output; + DNNInputData input; + DNNData output; int scale_factor; struct SwsContext *sws_contexts[3]; int sws_slice_h, sws_input_linesize, sws_output_linesize; @@ -87,6 +88,7 @@ static av_cold int init(AVFilterContext *context) return AVERROR(EIO); } + sr_context->input.dt = DNN_FLOAT; sr_context->sws_contexts[0] = NULL; sr_context->sws_contexts[1] = NULL; sr_context->sws_contexts[2] = NULL;