From patchwork Fri Nov 22 07:50:11 2019 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: "Guo, Yejun" X-Patchwork-Id: 16385 Return-Path: X-Original-To: patchwork@ffaux-bg.ffmpeg.org Delivered-To: patchwork@ffaux-bg.ffmpeg.org Received: from ffbox0-bg.mplayerhq.hu (ffbox0-bg.ffmpeg.org [79.124.17.100]) by ffaux.localdomain (Postfix) with ESMTP id 490EB44605E for ; Fri, 22 Nov 2019 09:57:08 +0200 (EET) Received: from [127.0.1.1] (localhost [127.0.0.1]) by ffbox0-bg.mplayerhq.hu (Postfix) with ESMTP id 32D8468AD6C; Fri, 22 Nov 2019 09:57:08 +0200 (EET) X-Original-To: ffmpeg-devel@ffmpeg.org Delivered-To: ffmpeg-devel@ffmpeg.org Received: from mga14.intel.com (mga14.intel.com [192.55.52.115]) by ffbox0-bg.mplayerhq.hu (Postfix) with ESMTPS id 97FDA6880E8 for ; Fri, 22 Nov 2019 09:57:01 +0200 (EET) X-Amp-Result: SKIPPED(no attachment in message) X-Amp-File-Uploaded: False Received: from orsmga003.jf.intel.com ([10.7.209.27]) by fmsmga103.fm.intel.com with ESMTP/TLS/DHE-RSA-AES256-GCM-SHA384; 21 Nov 2019 23:56:59 -0800 X-ExtLoop1: 1 X-IronPort-AV: E=Sophos;i="5.69,229,1571727600"; d="scan'208";a="210180260" Received: from yguo18-skl-u1604.sh.intel.com ([10.239.13.25]) by orsmga003.jf.intel.com with ESMTP; 21 Nov 2019 23:56:58 -0800 From: "Guo, Yejun" To: ffmpeg-devel@ffmpeg.org Date: Fri, 22 Nov 2019 15:50:11 +0800 Message-Id: <1574409011-15833-1-git-send-email-yejun.guo@intel.com> X-Mailer: git-send-email 2.7.4 Subject: [FFmpeg-devel] [PATCH 3/4] avfilter/vf_dnn_processing: add format GRAY8 and GRAYF32 support X-BeenThere: ffmpeg-devel@ffmpeg.org X-Mailman-Version: 2.1.20 Precedence: list List-Id: FFmpeg development discussions and patches List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Reply-To: FFmpeg development discussions and patches Cc: yejun.guo@intel.com MIME-Version: 1.0 Errors-To: ffmpeg-devel-bounces@ffmpeg.org Sender: "ffmpeg-devel" Signed-off-by: Guo, Yejun --- 
doc/filters.texi | 8 ++- libavfilter/vf_dnn_processing.c | 147 ++++++++++++++++++++++++++++++---------- 2 files changed, 118 insertions(+), 37 deletions(-) diff --git a/doc/filters.texi b/doc/filters.texi index 1f86ae1..c3f7997 100644 --- a/doc/filters.texi +++ b/doc/filters.texi @@ -8992,7 +8992,13 @@ Set the input name of the dnn network. Set the output name of the dnn network. @item fmt -Set the pixel format for the Frame. Allowed values are @code{AV_PIX_FMT_RGB24}, and @code{AV_PIX_FMT_BGR24}. +Set the pixel format for the Frame; the value is determined by the input of the dnn network model. + +If the model handles RGB (or BGR) image and the data type of model input is uint8, fmt must be @code{AV_PIX_FMT_RGB24} (or @code{AV_PIX_FMT_BGR24}). +If the model handles RGB (or BGR) image and the data type of model input is float, fmt must be @code{AV_PIX_FMT_RGB24} (or @code{AV_PIX_FMT_BGR24}), and this filter will do data type conversion internally. +If the model handles GRAY image and the data type of model input is uint8, fmt must be @code{AV_PIX_FMT_GRAY8}. +If the model handles GRAY image and the data type of model input is float, fmt must be @code{AV_PIX_FMT_GRAYF32}. + Default value is @code{AV_PIX_FMT_RGB24}. 
@end table diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c index ce976ec..963dd5e 100644 --- a/libavfilter/vf_dnn_processing.c +++ b/libavfilter/vf_dnn_processing.c @@ -70,10 +70,12 @@ static av_cold int init(AVFilterContext *context) { DnnProcessingContext *ctx = context->priv; int supported = 0; - // as the first step, only rgb24 and bgr24 are supported + // to support more formats const enum AVPixelFormat supported_pixel_fmts[] = { AV_PIX_FMT_RGB24, AV_PIX_FMT_BGR24, + AV_PIX_FMT_GRAY8, + AV_PIX_FMT_GRAYF32, }; for (int i = 0; i < sizeof(supported_pixel_fmts) / sizeof(enum AVPixelFormat); ++i) { if (supported_pixel_fmts[i] == ctx->fmt) { @@ -156,14 +158,38 @@ static int config_input(AVFilterLink *inlink) return AVERROR(EIO); } - if (model_input.channels != 3) { - av_log(ctx, AV_LOG_ERROR, "the model requires input channels %d\n", - model_input.channels); - return AVERROR(EIO); - } - if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) { - av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n"); - return AVERROR(EIO); + if (ctx->fmt == AV_PIX_FMT_RGB24 || ctx->fmt == AV_PIX_FMT_BGR24) { + if (model_input.channels != 3) { + av_log(ctx, AV_LOG_ERROR, "channel number 3 is required, but the actual channel number is %d\n", + model_input.channels); + return AVERROR(EIO); + } + if (model_input.dt != DNN_FLOAT && model_input.dt != DNN_UINT8) { + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n"); + return AVERROR(EIO); + } + } else if (ctx->fmt == AV_PIX_FMT_GRAY8) { + if (model_input.channels != 1) { + av_log(ctx, AV_LOG_ERROR, "channel number 1 is required, but the actual channel number is %d\n", + model_input.channels); + return AVERROR(EIO); + } + if (model_input.dt != DNN_UINT8) { + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as uint8.\n"); + return AVERROR(EIO); + } + } else if (ctx->fmt == 
AV_PIX_FMT_GRAYF32) { + if (model_input.channels != 1) { + av_log(ctx, AV_LOG_ERROR, "channel number 1 is required, but the actual channel number is %d\n", + model_input.channels); + return AVERROR(EIO); + } + if (model_input.dt != DNN_FLOAT) { + av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float.\n"); + return AVERROR(EIO); + } + } else { + av_assert0(!"should not reach here."); } ctx->input.width = inlink->w; @@ -203,28 +229,49 @@ static int config_output(AVFilterLink *outlink) static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame) { - // extend this function to support more formats - av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24); - - if (dnn_input->dt == DNN_FLOAT) { - float *dnn_input_data = dnn_input->data; - for (int i = 0; i < frame->height; i++) { - for(int j = 0; j < frame->width * 3; j++) { - int k = i * frame->linesize[0] + j; - int t = i * frame->width * 3 + j; - dnn_input_data[t] = frame->data[0][k] / 255.0f; + if (frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24) { + if (dnn_input->dt == DNN_FLOAT) { + float *dnn_input_data = dnn_input->data; + for (int i = 0; i < frame->height; i++) { + for(int j = 0; j < frame->width * 3; j++) { + int k = i * frame->linesize[0] + j; + int t = i * frame->width * 3 + j; + dnn_input_data[t] = frame->data[0][k] / 255.0f; + } + } + } else { + uint8_t *dnn_input_data = dnn_input->data; + av_assert0(dnn_input->dt == DNN_UINT8); + for (int i = 0; i < frame->height; i++) { + for(int j = 0; j < frame->width * 3; j++) { + int k = i * frame->linesize[0] + j; + int t = i * frame->width * 3 + j; + dnn_input_data[t] = frame->data[0][k]; + } } } - } else { + } else if (frame->format == AV_PIX_FMT_GRAY8) { uint8_t *dnn_input_data = dnn_input->data; av_assert0(dnn_input->dt == DNN_UINT8); for (int i = 0; i < frame->height; i++) { - for(int j = 0; j < frame->width * 3; j++) { + for(int j = 0; j < frame->width; j++) { int k = i 
* frame->linesize[0] + j; - int t = i * frame->width * 3 + j; + int t = i * frame->width + j; dnn_input_data[t] = frame->data[0][k]; } } + } else if (frame->format == AV_PIX_FMT_GRAYF32) { + float *dnn_input_data = dnn_input->data; + av_assert0(dnn_input->dt == DNN_FLOAT); + for (int i = 0; i < frame->height; i++) { + for(int j = 0; j < frame->width; j++) { + int k = i * frame->linesize[0] + j * sizeof(float); + int t = i * frame->width + j; + dnn_input_data[t] = *(float*)(frame->data[0] + k); + } + } + } else { + av_assert0(!"should not reach here."); } return 0; @@ -232,28 +279,49 @@ static int copy_from_frame_to_dnn(DNNData *dnn_input, const AVFrame *frame) static int copy_from_dnn_to_frame(AVFrame *frame, const DNNData *dnn_output) { - // extend this function to support more formats - av_assert0(frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24); - - if (dnn_output->dt == DNN_FLOAT) { - float *dnn_output_data = dnn_output->data; - for (int i = 0; i < frame->height; i++) { - for(int j = 0; j < frame->width * 3; j++) { - int k = i * frame->linesize[0] + j; - int t = i * frame->width * 3 + j; - frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); + if (frame->format == AV_PIX_FMT_RGB24 || frame->format == AV_PIX_FMT_BGR24) { + if (dnn_output->dt == DNN_FLOAT) { + float *dnn_output_data = dnn_output->data; + for (int i = 0; i < frame->height; i++) { + for(int j = 0; j < frame->width * 3; j++) { + int k = i * frame->linesize[0] + j; + int t = i * frame->width * 3 + j; + frame->data[0][k] = av_clip_uintp2((int)(dnn_output_data[t] * 255.0f), 8); + } + } + } else { + uint8_t *dnn_output_data = dnn_output->data; + av_assert0(dnn_output->dt == DNN_UINT8); + for (int i = 0; i < frame->height; i++) { + for(int j = 0; j < frame->width * 3; j++) { + int k = i * frame->linesize[0] + j; + int t = i * frame->width * 3 + j; + frame->data[0][k] = dnn_output_data[t]; + } } } - } else { + } else if (frame->format == AV_PIX_FMT_GRAY8) { 
uint8_t *dnn_output_data = dnn_output->data; av_assert0(dnn_output->dt == DNN_UINT8); for (int i = 0; i < frame->height; i++) { - for(int j = 0; j < frame->width * 3; j++) { + for(int j = 0; j < frame->width; j++) { int k = i * frame->linesize[0] + j; - int t = i * frame->width * 3 + j; + int t = i * frame->width + j; frame->data[0][k] = dnn_output_data[t]; } } + } else if (frame->format == AV_PIX_FMT_GRAYF32) { + float *dnn_output_data = dnn_output->data; + av_assert0(dnn_output->dt == DNN_FLOAT); + for (int i = 0; i < frame->height; i++) { + for(int j = 0; j < frame->width; j++) { + int k = i * frame->linesize[0] + j * sizeof(float); + int t = i * frame->width + j; + *(float*)(frame->data[0] + k) = dnn_output_data[t]; + } + } + } else { + av_assert0(!"should not reach here."); } return 0; @@ -275,7 +343,14 @@ static int filter_frame(AVFilterLink *inlink, AVFrame *in) av_frame_free(&in); return AVERROR(EIO); } - av_assert0(ctx->output.channels == 3); + + if (ctx->fmt == AV_PIX_FMT_RGB24 || ctx->fmt == AV_PIX_FMT_BGR24) { + av_assert0(ctx->output.channels == 3); + } else if (ctx->fmt == AV_PIX_FMT_GRAY8 || ctx->fmt == AV_PIX_FMT_GRAYF32) { + av_assert0(ctx->output.channels == 1); + } else { + av_assert0(!"should not reach here"); + } out = ff_get_video_buffer(outlink, outlink->w, outlink->h); if (!out) {