From patchwork Wed Jan 17 07:21:50 2024 Content-Type: text/plain; charset="utf-8" MIME-Version: 1.0 Content-Transfer-Encoding: 7bit X-Patchwork-Submitter: "Chen, Wenbin" X-Patchwork-Id: 45623 Delivered-To: ffmpegpatchwork2@gmail.com Received: by 2002:a05:6a20:c58a:b0:199:de12:6fa6 with SMTP id gn10csp185504pzb; Tue, 16 Jan 2024 23:22:18 -0800 (PST) X-Google-Smtp-Source: AGHT+IGYxL4D/8RPJ4wnaold1ORC4We2BC99iVMjkqGUx5ukgEss8BAqgKSwQIsMD1l9ZYY6pbhp X-Received: by 2002:a05:6402:717:b0:559:c6da:c889 with SMTP id w23-20020a056402071700b00559c6dac889mr667041edx.1.1705476137850; Tue, 16 Jan 2024 23:22:17 -0800 (PST) ARC-Seal: i=1; a=rsa-sha256; t=1705476137; cv=none; d=google.com; s=arc-20160816; b=t7cES3JOLTzAv+ktQ3NkSNw+mxjaNiBPwsNO+qDpH+8TvMOdzUa6mQWLHoqLf+g8j7 J7SEVX4h9F8tsntnEqk/fJ3qajiU+g9g2Ir4ohrwe1eTi2XAjJCXLX/JwCsheR8LLCEF C4I6EOiw+LfW5Jb3eIDK8vTxqNCK4sHEKHA7J2xyPn3U2QJNwWHH/zlcqPLwwlPDXQCO PuB04eVH4p2tqotR9M4LAtzRoTrCpisX/35JvxpqNuvUehApW8ohHXe43bpTnqphzbG8 7AJatA3hCvv23e/obc8eqRs/DLJ0NtmpVhYWA58RBXMxHmB3quKdU4voOFgyNlxEvu9K ydtQ== ARC-Message-Signature: i=1; a=rsa-sha256; c=relaxed/relaxed; d=google.com; s=arc-20160816; h=sender:errors-to:content-transfer-encoding:reply-to:list-subscribe :list-help:list-post:list-archive:list-unsubscribe:list-id :precedence:subject:mime-version:references:in-reply-to:message-id :date:to:from:dkim-signature:delivered-to; bh=ZAaGlxtKNtiOp7x7fBktGE2h7xT06Pte67oNC9ZQG2o=; fh=YOA8vD9MJZuwZ71F/05pj6KdCjf6jQRmzLS+CATXUQk=; b=wWWZkfbxhYKMZH401Zvicw3VwQqUynqknUL5wP8QwZQiXcSBGIJEmdHF8dBlHKeXPk /kDMaP6rti4Jy4cRipe0hj5niVHa8/rcQsm7cvX9vgGeFPVc8kZEHtTvb1ACOClRhu8k t9/jgzmyRxU3VJmaoe2NK146r/Jqm1SE3GbODauEO4y2He1B3cwkLZsPLe4BlKsJkeY9 O8I2kpN6qTBl/nDnVR+Kdkg/MVCZbcRL2kEKaI9Qpd5yJaiQUdpzuwhWCs/RFae997Yc aE6p51ajj5mbl/0ZAtU02U9/Uiz3WVB2ZgOBJO54NWET3eJ0J3XN3mz1sWVAeKdJftfk Chqw== ARC-Authentication-Results: i=1; mx.google.com; dkim=neutral (body hash did not verify) header.i=@intel.com header.s=Intel header.b=dvseoZtW; spf=pass (google.com: domain of ffmpeg-devel-bounces@ffmpeg.org designates 79.124.17.100 as permitted sender) smtp.mailfrom=ffmpeg-devel-bounces@ffmpeg.org Return-Path: Received: from ffbox0-bg.mplayerhq.hu (ffbox0-bg.ffmpeg.org. [79.124.17.100]) by mx.google.com with ESMTP id d8-20020a50ea88000000b005597580973asi1856080edo.186.2024.01.16.23.22.17; Tue, 16 Jan 2024 23:22:17 -0800 (PST) Received-SPF: pass (google.com: domain of ffmpeg-devel-bounces@ffmpeg.org designates 79.124.17.100 as permitted sender) client-ip=79.124.17.100; Authentication-Results: mx.google.com; dkim=neutral (body hash did not verify) header.i=@intel.com header.s=Intel header.b=dvseoZtW; spf=pass (google.com: domain of ffmpeg-devel-bounces@ffmpeg.org designates 79.124.17.100 as permitted sender) smtp.mailfrom=ffmpeg-devel-bounces@ffmpeg.org Received: from [127.0.1.1] (localhost [127.0.0.1]) by ffbox0-bg.mplayerhq.hu (Postfix) with ESMTP id 9923D68D048; Wed, 17 Jan 2024 09:22:10 +0200 (EET) X-Original-To: ffmpeg-devel@ffmpeg.org Delivered-To: ffmpeg-devel@ffmpeg.org Received: from mgamail.intel.com (mgamail.intel.com [192.198.163.9]) by ffbox0-bg.mplayerhq.hu (Postfix) with ESMTPS id D486968D01F for ; Wed, 17 Jan 2024 09:22:02 +0200 (EET) DKIM-Signature: v=1; a=rsa-sha256; c=relaxed/simple; d=intel.com; i=@intel.com; q=dns/txt; s=Intel; t=1705476128; x=1737012128; h=from:to:subject:date:message-id:in-reply-to:references: mime-version:content-transfer-encoding; bh=vtFT8w4pGINkEVHCOcZJfJEn3/lXzCkPcR7myd3LIk0=; b=dvseoZtWpEMpDpqFzYRB2JtjflpVEhXdO8ZZNDL4h/sq9bN8fgHzlm/s qenodVdLyeab3U2gPiGTcAOzPgzJfO7WX/VRvXjJESGCaBLy3AljuJFmt OTR9+n9uoN/7MoXFsOhCdVGj/vRRttIuXu4QnSicKWp1B07KBfgMnMXyX uewkzUiqWhBg9pr68/iXVMeGxti8c1t8+OMxgepjas/YXg1hFHnEuGLKf 0eZJ/3X+AOUxfU5Af+ljUmrNr65hH6DO9dsHyrN5H4Eg2PRR4/XsbZKza qUj/bbrxvjr0KEX8GePB2GJM2IodEitFHb3xS/KV0DLyv7ddCQE5tDXl8 g==; X-IronPort-AV: E=McAfee;i="6600,9927,10955"; a="6850315" X-IronPort-AV: E=Sophos;i="6.05,200,1701158400"; d="scan'208";a="6850315" Received: from fmsmga006.fm.intel.com ([10.253.24.20]) by fmvoesa103.fm.intel.com with ESMTP/TLS/ECDHE-RSA-AES256-GCM-SHA384; 16 Jan 2024 23:21:55 -0800 X-ExtLoop1: 1 X-IronPort-AV: E=McAfee;i="6600,9927,10955"; a="1031256785" X-IronPort-AV: E=Sophos;i="6.05,200,1701158400"; d="scan'208";a="1031256785" Received: from wenbin-z390-aorus-ultra.sh.intel.com ([10.239.156.43]) by fmsmga006.fm.intel.com with ESMTP; 16 Jan 2024 23:21:53 -0800 From: wenbin.chen-at-intel.com@ffmpeg.org To: ffmpeg-devel@ffmpeg.org Date: Wed, 17 Jan 2024 15:21:50 +0800 Message-Id: <20240117072151.2155795-2-wenbin.chen@intel.com> X-Mailer: git-send-email 2.34.1 In-Reply-To: <20240117072151.2155795-1-wenbin.chen@intel.com> References: <20240117072151.2155795-1-wenbin.chen@intel.com> MIME-Version: 1.0 Subject: [FFmpeg-devel] [PATCH 2/3] libavfilter/dnn_interface: use dims to represent shapes X-BeenThere: ffmpeg-devel@ffmpeg.org X-Mailman-Version: 2.1.29 Precedence: list List-Id: FFmpeg development discussions and patches List-Unsubscribe: , List-Archive: List-Post: List-Help: List-Subscribe: , Reply-To: FFmpeg development discussions and patches Errors-To: ffmpeg-devel-bounces@ffmpeg.org Sender: "ffmpeg-devel" X-TUID: ad+aW3d9ZrnY From: Wenbin Chen For detect and classify output, width and height make no sence, so change width, height to dims to represent the shape of tensor. Use layout and dims to get width, height and channel. Signed-off-by: Wenbin Chen --- libavfilter/dnn/dnn_backend_openvino.c | 80 ++++++++++++++------------ libavfilter/dnn/dnn_backend_tf.c | 32 +++++++---- libavfilter/dnn/dnn_io_proc.c | 30 +++++++--- libavfilter/dnn_interface.h | 17 +++++- libavfilter/vf_dnn_classify.c | 6 +- libavfilter/vf_dnn_detect.c | 50 ++++++++-------- libavfilter/vf_dnn_processing.c | 21 ++++--- 7 files changed, 146 insertions(+), 90 deletions(-) diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c index 590ddd586c..73b42c32b1 100644 --- a/libavfilter/dnn/dnn_backend_openvino.c +++ b/libavfilter/dnn/dnn_backend_openvino.c @@ -253,9 +253,9 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request) ov_shape_free(&input_shape); return ov2_map_error(status, NULL); } - input.height = dims[1]; - input.width = dims[2]; - input.channels = dims[3]; + for (int i = 0; i < input_shape.rank; i++) + input.dims[i] = dims[i]; + input.layout = DL_NHWC; input.dt = precision_to_datatype(precision); #else status = ie_infer_request_get_blob(request->infer_request, task->input_name, &input_blob); @@ -278,9 +278,9 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request) av_log(ctx, AV_LOG_ERROR, "Failed to get input blob buffer\n"); return DNN_GENERIC_ERROR; } - input.height = dims.dims[2]; - input.width = dims.dims[3]; - input.channels = dims.dims[1]; + for (int i = 0; i < input_shape.rank; i++) + input.dims[i] = dims[i]; + input.layout = DL_NCHW; input.data = blob_buffer.buffer; input.dt = precision_to_datatype(precision); #endif @@ -339,8 +339,8 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request) av_assert0(!"should not reach here"); break; } - input.data = (uint8_t *)input.data - + input.width * input.height * input.channels * get_datatype_size(input.dt); + input.data = (uint8_t *)input.data + + input.dims[1] * input.dims[2] * input.dims[3] * get_datatype_size(input.dt); } #if HAVE_OPENVINO2 ov_tensor_free(tensor); @@ -403,10 +403,11 @@ static void infer_completion_callback(void *args) goto end; } outputs[i].dt = precision_to_datatype(precision); - - outputs[i].channels = output_shape.rank > 2 ? dims[output_shape.rank - 3] : 1; - outputs[i].height = output_shape.rank > 1 ? dims[output_shape.rank - 2] : 1; - outputs[i].width = output_shape.rank > 0 ? dims[output_shape.rank - 1] : 1; + outputs[i].layout = DL_NCHW; + outputs[i].dims[0] = 1; + outputs[i].dims[1] = output_shape.rank > 2 ? dims[output_shape.rank - 3] : 1; + outputs[i].dims[2] = output_shape.rank > 1 ? dims[output_shape.rank - 2] : 1; + outputs[i].dims[3] = output_shape.rank > 0 ? dims[output_shape.rank - 1] : 1; av_assert0(request->lltask_count <= dims[0]); outputs[i].layout = ctx->options.layout; outputs[i].scale = ctx->options.scale; @@ -445,9 +446,9 @@ static void infer_completion_callback(void *args) return; } output.data = blob_buffer.buffer; - output.channels = dims.dims[1]; - output.height = dims.dims[2]; - output.width = dims.dims[3]; + output.layout = DL_NCHW; + for (int i = 0; i < 4; i++) + output.dims[i] = dims.dims[i]; av_assert0(request->lltask_count <= dims.dims[0]); output.dt = precision_to_datatype(precision); output.layout = ctx->options.layout; @@ -469,8 +470,10 @@ static void infer_completion_callback(void *args) ff_proc_from_dnn_to_frame(task->out_frame, outputs, ctx); } } else { - task->out_frame->width = outputs[0].width; - task->out_frame->height = outputs[0].height; + task->out_frame->width = + outputs[0].dims[dnn_get_width_idx_by_layout(outputs[0].layout)]; + task->out_frame->height = + outputs[0].dims[dnn_get_height_idx_by_layout(outputs[0].layout)]; } break; case DFT_ANALYTICS_DETECT: @@ -501,7 +504,8 @@ static void infer_completion_callback(void *args) av_freep(&request->lltasks[i]); for (int i = 0; i < ov_model->nb_outputs; i++) outputs[i].data = (uint8_t *)outputs[i].data + - outputs[i].width * outputs[i].height * outputs[i].channels * get_datatype_size(outputs[i].dt); + outputs[i].dims[1] * outputs[i].dims[2] * outputs[i].dims[3] * + get_datatype_size(outputs[i].dt); } end: #if HAVE_OPENVINO2 @@ -1085,7 +1089,6 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name) #if HAVE_OPENVINO2 ov_shape_t input_shape = {0}; ov_element_type_e precision; - int64_t* dims; ov_status_e status; if (input_name) status = ov_model_const_input_by_name(ov_model->ov_model, input_name, &ov_model->input_port); @@ -1105,16 +1108,18 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name) av_log(ctx, AV_LOG_ERROR, "Failed to get input port shape.\n"); return ov2_map_error(status, NULL); } - dims = input_shape.dims; - if (dims[1] <= 3) { // NCHW - input->channels = dims[1]; - input->height = input_resizable ? -1 : dims[2]; - input->width = input_resizable ? -1 : dims[3]; - } else { // NHWC - input->height = input_resizable ? -1 : dims[1]; - input->width = input_resizable ? -1 : dims[2]; - input->channels = dims[3]; + for (int i = 0; i < 4; i++) + input->dims[i] = input_shape.dims[i]; + if (input_resizable) { + input->dims[dnn_get_width_idx_by_layout(input->layout)] = -1; + input->dims[dnn_get_height_idx_by_layout(input->layout)] = -1; } + + if (input_shape.dims[1] <= 3) // NCHW + input->layout = DL_NCHW; + else // NHWC + input->layout = DL_NHWC; + input->dt = precision_to_datatype(precision); ov_shape_free(&input_shape); return 0; @@ -1144,15 +1149,18 @@ static int get_input_ov(void *model, DNNData *input, const char *input_name) return DNN_GENERIC_ERROR; } - if (dims[1] <= 3) { // NCHW - input->channels = dims[1]; - input->height = input_resizable ? -1 : dims[2]; - input->width = input_resizable ? -1 : dims[3]; - } else { // NHWC - input->height = input_resizable ? -1 : dims[1]; - input->width = input_resizable ? -1 : dims[2]; - input->channels = dims[3]; + for (int i = 0; i < 4; i++) + input->dims[i] = input_shape.dims[i]; + if (input_resizable) { + input->dims[dnn_get_width_idx_by_layout(input->layout)] = -1; + input->dims[dnn_get_height_idx_by_layout(input->layout)] = -1; } + + if (input_shape.dims[1] <= 3) // NCHW + input->layout = DL_NCHW; + else // NHWC + input->layout = DL_NHWC; + input->dt = precision_to_datatype(precision); return 0; } diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index 25046b58d9..27c5178bb5 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -251,7 +251,12 @@ static TF_Tensor *allocate_input_tensor(const DNNData *input) { TF_DataType dt; size_t size; - int64_t input_dims[] = {1, input->height, input->width, input->channels}; + int64_t input_dims[4] = { 0 }; + + input_dims[0] = 1; + input_dims[1] = input->dims[dnn_get_height_idx_by_layout(input->layout)]; + input_dims[2] = input->dims[dnn_get_width_idx_by_layout(input->layout)]; + input_dims[3] = input->dims[dnn_get_channel_idx_by_layout(input->layout)]; switch (input->dt) { case DNN_FLOAT: dt = TF_FLOAT; @@ -310,9 +315,9 @@ static int get_input_tf(void *model, DNNData *input, const char *input_name) // currently only NHWC is supported av_assert0(dims[0] == 1 || dims[0] == -1); - input->height = dims[1]; - input->width = dims[2]; - input->channels = dims[3]; + for (int i = 0; i < 4; i++) + input->dims[i] = dims[i]; + input->layout = DL_NHWC; return 0; } @@ -640,8 +645,8 @@ static int fill_model_input_tf(TFModel *tf_model, TFRequestItem *request) { } infer_request = request->infer_request; - input.height = task->in_frame->height; - input.width = task->in_frame->width; + input.dims[1] = task->in_frame->height; + input.dims[2] = task->in_frame->width; infer_request->tf_input = av_malloc(sizeof(TF_Output)); if (!infer_request->tf_input) { @@ -731,9 +736,12 @@ static void infer_completion_callback(void *args) { } for (uint32_t i = 0; i < task->nb_output; ++i) { - outputs[i].height = TF_Dim(infer_request->output_tensors[i], 1); - outputs[i].width = TF_Dim(infer_request->output_tensors[i], 2); - outputs[i].channels = TF_Dim(infer_request->output_tensors[i], 3); + outputs[i].dims[dnn_get_height_idx_by_layout(outputs[i].layout)] = + TF_Dim(infer_request->output_tensors[i], 1); + outputs[i].dims[dnn_get_width_idx_by_layout(outputs[i].layout)] = + TF_Dim(infer_request->output_tensors[i], 2); + outputs[i].dims[dnn_get_channel_idx_by_layout(outputs[i].layout)] = + TF_Dim(infer_request->output_tensors[i], 3); outputs[i].data = TF_TensorData(infer_request->output_tensors[i]); outputs[i].dt = (DNNDataType)TF_TensorType(infer_request->output_tensors[i]); } @@ -747,8 +755,10 @@ static void infer_completion_callback(void *args) { ff_proc_from_dnn_to_frame(task->out_frame, outputs, ctx); } } else { - task->out_frame->width = outputs[0].width; - task->out_frame->height = outputs[0].height; + task->out_frame->width = + outputs[0].dims[dnn_get_width_idx_by_layout(outputs[0].layout)]; + task->out_frame->height = + outputs[0].dims[dnn_get_height_idx_by_layout(outputs[0].layout)]; } break; case DFT_ANALYTICS_DETECT: diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c index ab656e8ed7..e5d6edb301 100644 --- a/libavfilter/dnn/dnn_io_proc.c +++ b/libavfilter/dnn/dnn_io_proc.c @@ -70,7 +70,7 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx) dst_data = (void **)frame->data; linesize[0] = frame->linesize[0]; if (output->layout == DL_NCHW) { - middle_data = av_malloc(plane_size * output->channels); + middle_data = av_malloc(plane_size * output->dims[1]); if (!middle_data) { ret = AVERROR(ENOMEM); goto err; @@ -209,7 +209,7 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx) src_data = (void **)frame->data; linesize[0] = frame->linesize[0]; if (input->layout == DL_NCHW) { - middle_data = av_malloc(plane_size * input->channels); + middle_data = av_malloc(plane_size * input->dims[1]); if (!middle_data) { ret = AVERROR(ENOMEM); goto err; @@ -346,6 +346,7 @@ int ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index int ret = 0; enum AVPixelFormat fmt; int left, top, width, height; + int width_idx, height_idx; const AVDetectionBBoxHeader *header; const AVDetectionBBox *bbox; AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); @@ -364,6 +365,9 @@ int ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index return AVERROR(ENOSYS); } + width_idx = dnn_get_width_idx_by_layout(input->layout); + height_idx = dnn_get_height_idx_by_layout(input->layout); + header = (const AVDetectionBBoxHeader *)sd->data; bbox = av_get_detection_bbox(header, bbox_index); @@ -374,17 +378,20 @@ int ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index fmt = get_pixel_format(input); sws_ctx = sws_getContext(width, height, frame->format, - input->width, input->height, fmt, + input->dims[width_idx], + input->dims[height_idx], fmt, SWS_FAST_BILINEAR, NULL, NULL, NULL); if (!sws_ctx) { av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion " "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n", av_get_pix_fmt_name(frame->format), width, height, - av_get_pix_fmt_name(fmt), input->width, input->height); + av_get_pix_fmt_name(fmt), + input->dims[width_idx], + input->dims[height_idx]); return AVERROR(EINVAL); } - ret = av_image_fill_linesizes(linesizes, fmt, input->width); + ret = av_image_fill_linesizes(linesizes, fmt, input->dims[width_idx]); if (ret < 0) { av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes"); sws_freeContext(sws_ctx); @@ -414,7 +421,7 @@ int ff_frame_to_dnn_detect(AVFrame *frame, DNNData *input, void *log_ctx) { struct SwsContext *sws_ctx; int linesizes[4]; - int ret = 0; + int ret = 0, width_idx, height_idx; enum AVPixelFormat fmt = get_pixel_format(input); /* (scale != 1 and scale != 0) or mean != 0 */ @@ -430,18 +437,23 @@ int ff_frame_to_dnn_detect(AVFrame *frame, DNNData *input, void *log_ctx) return AVERROR(ENOSYS); } + width_idx = dnn_get_width_idx_by_layout(input->layout); + height_idx = dnn_get_height_idx_by_layout(input->layout); + sws_ctx = sws_getContext(frame->width, frame->height, frame->format, - input->width, input->height, fmt, + input->dims[width_idx], + input->dims[height_idx], fmt, SWS_FAST_BILINEAR, NULL, NULL, NULL); if (!sws_ctx) { av_log(log_ctx, AV_LOG_ERROR, "Impossible to create scale context for the conversion " "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n", av_get_pix_fmt_name(frame->format), frame->width, frame->height, - av_get_pix_fmt_name(fmt), input->width, input->height); + av_get_pix_fmt_name(fmt), input->dims[width_idx], + input->dims[height_idx]); return AVERROR(EINVAL); } - ret = av_image_fill_linesizes(linesizes, fmt, input->width); + ret = av_image_fill_linesizes(linesizes, fmt, input->dims[width_idx]); if (ret < 0) { av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes"); sws_freeContext(sws_ctx); diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 183d8418b2..852d88baa8 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -64,7 +64,7 @@ typedef enum { typedef struct DNNData{ void *data; - int width, height, channels; + int dims[4]; // dt and order together decide the color format DNNDataType dt; DNNColorOrder order; @@ -134,4 +134,19 @@ typedef struct DNNModule{ // Initializes DNNModule depending on chosen backend. const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx); +static inline int dnn_get_width_idx_by_layout(DNNLayout layout) +{ + return layout == DL_NHWC ? 2 : 3; +} + +static inline int dnn_get_height_idx_by_layout(DNNLayout layout) +{ + return layout == DL_NHWC ? 1 : 2; +} + +static inline int dnn_get_channel_idx_by_layout(DNNLayout layout) +{ + return layout == DL_NHWC ? 3 : 1; +} + #endif diff --git a/libavfilter/vf_dnn_classify.c b/libavfilter/vf_dnn_classify.c index e88e59d09c..d180c3b461 100644 --- a/libavfilter/vf_dnn_classify.c +++ b/libavfilter/vf_dnn_classify.c @@ -68,8 +68,8 @@ static int dnn_classify_post_proc(AVFrame *frame, DNNData *output, uint32_t bbox uint32_t label_id; float confidence; AVFrameSideData *sd; - - if (output->channels <= 0) { + int output_size = output->dims[3] * output->dims[2] * output->dims[1]; + if (output_size <= 0) { return -1; } @@ -88,7 +88,7 @@ static int dnn_classify_post_proc(AVFrame *frame, DNNData *output, uint32_t bbox classifications = output->data; label_id = 0; confidence= classifications[0]; - for (int i = 1; i < output->channels; i++) { + for (int i = 1; i < output_size; i++) { if (classifications[i] > confidence) { label_id = i; confidence= classifications[i]; diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c index 249cbba0f7..caccbf7a12 100644 --- a/libavfilter/vf_dnn_detect.c +++ b/libavfilter/vf_dnn_detect.c @@ -166,14 +166,14 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out scale_w = cell_w; scale_h = cell_h; } else { - if (output[output_index].height != output[output_index].width && - output[output_index].height == output[output_index].channels) { + if (output[output_index].dims[2] != output[output_index].dims[3] && + output[output_index].dims[2] == output[output_index].dims[1]) { is_NHWC = 1; - cell_w = output[output_index].height; - cell_h = output[output_index].channels; + cell_w = output[output_index].dims[2]; + cell_h = output[output_index].dims[1]; } else { - cell_w = output[output_index].width; - cell_h = output[output_index].height; + cell_w = output[output_index].dims[3]; + cell_h = output[output_index].dims[2]; } scale_w = ctx->scale_width; scale_h = ctx->scale_height; @@ -205,14 +205,14 @@ static int dnn_detect_parse_yolo_output(AVFrame *frame, DNNData *output, int out return AVERROR(EINVAL); } - if (output[output_index].channels * output[output_index].width * - output[output_index].height % (box_size * cell_w * cell_h)) { + if (output[output_index].dims[1] * output[output_index].dims[2] * + output[output_index].dims[3] % (box_size * cell_w * cell_h)) { av_log(filter_ctx, AV_LOG_ERROR, "wrong cell_w, cell_h or nb_classes\n"); return AVERROR(EINVAL); } - detection_boxes = output[output_index].channels * - output[output_index].height * - output[output_index].width / box_size / cell_w / cell_h; + detection_boxes = output[output_index].dims[1] * + output[output_index].dims[2] * + output[output_index].dims[3] / box_size / cell_w / cell_h; anchors = anchors + (detection_boxes * output_index * 2); /** @@ -373,18 +373,18 @@ static int dnn_detect_post_proc_ssd(AVFrame *frame, DNNData *output, int nb_outp int scale_w = ctx->scale_width; int scale_h = ctx->scale_height; - if (nb_outputs == 1 && output->width == 7) { - proposal_count = output->height; - detect_size = output->width; + if (nb_outputs == 1 && output->dims[3] == 7) { + proposal_count = output->dims[2]; + detect_size = output->dims[3]; detections = output->data; - } else if (nb_outputs == 2 && output[0].width == 5) { - proposal_count = output[0].height; - detect_size = output[0].width; + } else if (nb_outputs == 2 && output[0].dims[3] == 5) { + proposal_count = output[0].dims[2]; + detect_size = output[0].dims[3]; detections = output[0].data; labels = output[1].data; - } else if (nb_outputs == 2 && output[1].width == 5) { - proposal_count = output[1].height; - detect_size = output[1].width; + } else if (nb_outputs == 2 && output[1].dims[3] == 5) { + proposal_count = output[1].dims[2]; + detect_size = output[1].dims[3]; detections = output[1].data; labels = output[0].data; } else { @@ -821,15 +821,19 @@ static int config_input(AVFilterLink *inlink) AVFilterContext *context = inlink->dst; DnnDetectContext *ctx = context->priv; DNNData model_input; - int ret; + int ret, width_idx, height_idx; ret = ff_dnn_get_input(&ctx->dnnctx, &model_input); if (ret != 0) { av_log(ctx, AV_LOG_ERROR, "could not get input from the model\n"); return ret; } - ctx->scale_width = model_input.width == -1 ? inlink->w : model_input.width; - ctx->scale_height = model_input.height == -1 ? inlink->h : model_input.height; + width_idx = dnn_get_width_idx_by_layout(model_input.layout); + height_idx = dnn_get_height_idx_by_layout(model_input.layout); + ctx->scale_width = model_input.dims[width_idx] == -1 ? inlink->w : + model_input.dims[width_idx]; + ctx->scale_height = model_input.dims[height_idx] == -1 ? inlink->h : + model_input.dims[height_idx]; return 0; } diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c index 6829e94585..0b70c8e024 100644 --- a/libavfilter/vf_dnn_processing.c +++ b/libavfilter/vf_dnn_processing.c @@ -77,22 +77,29 @@ static const enum AVPixelFormat pix_fmts[] = { "the frame's format %s does not match " \ "the model input channel %d\n", \ av_get_pix_fmt_name(fmt), \ - model_input->channels); + model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)]); static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLink *inlink) { AVFilterContext *ctx = inlink->dst; enum AVPixelFormat fmt = inlink->format; + int width_idx, height_idx; + width_idx = dnn_get_width_idx_by_layout(model_input->layout); + height_idx = dnn_get_height_idx_by_layout(model_input->layout); // the design is to add explicit scale filter before this filter - if (model_input->height != -1 && model_input->height != inlink->h) { + if (model_input->dims[height_idx] != -1 && + model_input->dims[height_idx] != inlink->h) { av_log(ctx, AV_LOG_ERROR, "the model requires frame height %d but got %d\n", - model_input->height, inlink->h); + model_input->dims[height_idx], + inlink->h); return AVERROR(EIO); } - if (model_input->width != -1 && model_input->width != inlink->w) { + if (model_input->dims[width_idx] != -1 && + model_input->dims[width_idx] != inlink->w) { av_log(ctx, AV_LOG_ERROR, "the model requires frame width %d but got %d\n", - model_input->width, inlink->w); + model_input->dims[width_idx], + inlink->w); return AVERROR(EIO); } if (model_input->dt != DNN_FLOAT) { @@ -103,7 +110,7 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin switch (fmt) { case AV_PIX_FMT_RGB24: case AV_PIX_FMT_BGR24: - if (model_input->channels != 3) { + if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 3) { LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); } @@ -116,7 +123,7 @@ static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin case AV_PIX_FMT_YUV410P: case AV_PIX_FMT_YUV411P: case AV_PIX_FMT_NV12: - if (model_input->channels != 1) { + if (model_input->dims[dnn_get_channel_idx_by_layout(model_input->layout)] != 1) { LOG_FORMAT_CHANNEL_MISMATCH(); return AVERROR(EIO); }