From patchwork Thu Apr 29 13:36:56 2021
X-Patchwork-Submitter: "Guo, Yejun"
X-Patchwork-Id: 27476
From: "Guo, Yejun"
To: ffmpeg-devel@ffmpeg.org
Date: Thu, 29 Apr 2021 21:36:56 +0800
Message-Id: <20210429133657.23076-5-yejun.guo@intel.com>
X-Mailer: git-send-email 2.17.1
In-Reply-To: <20210429133657.23076-1-yejun.guo@intel.com>
References: 
<20210429133657.23076-1-yejun.guo@intel.com>
Subject: [FFmpeg-devel] [PATCH V2 5/6] lavfi/dnn: add classify support with openvino backend
Cc: yejun.guo@intel.com

Signed-off-by: Guo, Yejun --- libavfilter/dnn/dnn_backend_openvino.c | 143 +++++++++++++++++++++---- libavfilter/dnn/dnn_io_proc.c | 60 +++++++++++ libavfilter/dnn/dnn_io_proc.h | 1 + libavfilter/dnn_filter_common.c | 21 ++++ libavfilter/dnn_filter_common.h | 2 + libavfilter/dnn_interface.h | 10 +- 6 files changed, 218 insertions(+), 19 deletions(-) diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c index 4e58ff6d9c..1ff8a720b9 100644 --- a/libavfilter/dnn/dnn_backend_openvino.c +++ b/libavfilter/dnn/dnn_backend_openvino.c @@ -29,6 +29,7 @@ #include "libavutil/avassert.h" #include "libavutil/opt.h" #include "libavutil/avstring.h" +#include "libavutil/detection_bbox.h" #include "../internal.h" #include "queue.h" #include "safe_queue.h" @@ -74,6 +75,7 @@ typedef struct TaskItem { // one task might have multiple inferences typedef struct InferenceItem { TaskItem *task; + uint32_t bbox_index; } InferenceItem; // one request for one call to openvino @@ -182,12 +184,23 @@ static DNNReturnType fill_model_input_ov(OVModel *ov_model, RequestItem *request request->inferences[i] = inference; request->inference_count = i + 1; task = inference->task; - if (task->do_ioproc) { - if (ov_model->model->frame_pre_proc != NULL) { - ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx); - } else { - ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx); + switch 
(task->ov_model->model->func_type) { + case DFT_PROCESS_FRAME: + case DFT_ANALYTICS_DETECT: + if (task->do_ioproc) { + if (ov_model->model->frame_pre_proc != NULL) { + ov_model->model->frame_pre_proc(task->in_frame, &input, ov_model->model->filter_ctx); + } else { + ff_proc_from_frame_to_dnn(task->in_frame, &input, ov_model->model->func_type, ctx); + } } + break; + case DFT_ANALYTICS_CLASSIFY: + ff_frame_to_dnn_classify(task->in_frame, &input, inference->bbox_index, ctx); + break; + default: + av_assert0(!"should not reach here"); + break; } input.data = (uint8_t *)input.data + input.width * input.height * input.channels * get_datatype_size(input.dt); @@ -276,6 +289,13 @@ static void infer_completion_callback(void *args) } task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx); break; + case DFT_ANALYTICS_CLASSIFY: + if (!task->ov_model->model->classify_post_proc) { + av_log(ctx, AV_LOG_ERROR, "classify filter needs to provide post proc\n"); + return; + } + task->ov_model->model->classify_post_proc(task->out_frame, &output, request->inferences[i]->bbox_index, task->ov_model->model->filter_ctx); + break; default: av_assert0(!"should not reach here"); break; @@ -513,7 +533,44 @@ static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input return DNN_ERROR; } -static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue) +static int contain_valid_detection_bbox(AVFrame *frame) +{ + AVFrameSideData *sd; + const AVDetectionBBoxHeader *header; + const AVDetectionBBox *bbox; + + sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); + if (!sd) { // this frame has nothing detected + return 0; + } + + if (!sd->size) { + return 0; + } + + header = (const AVDetectionBBoxHeader *)sd->data; + if (!header->nb_bboxes) { + return 0; + } + + for (uint32_t i = 0; i < header->nb_bboxes; i++) { + bbox = av_get_detection_bbox(header, i); + if (bbox->x 
< 0 || bbox->w < 0 || bbox->x + bbox->w >= frame->width) { + return 0; + } + if (bbox->y < 0 || bbox->h < 0 || bbox->y + bbox->h >= frame->height) { + return 0; + } + + if (bbox->classify_count == AV_NUM_DETECTION_BBOX_CLASSIFY) { + return 0; + } + } + + return 1; +} + -static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue) +static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, TaskItem *task, Queue *inference_queue, DNNExecBaseParams *exec_params) { switch (func_type) { case DFT_PROCESS_FRAME: @@ -532,6 +589,45 @@ static DNNReturnType extract_inference_from_task(DNNFunctionType func_type, Task } return DNN_SUCCESS; } + case DFT_ANALYTICS_CLASSIFY: + { + const AVDetectionBBoxHeader *header; + AVFrame *frame = task->in_frame; + AVFrameSideData *sd; + DNNExecClassificationParams *params = (DNNExecClassificationParams *)exec_params; + + task->inference_todo = 0; + task->inference_done = 0; + + if (!contain_valid_detection_bbox(frame)) { + return DNN_SUCCESS; + } + + sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); + header = (const AVDetectionBBoxHeader *)sd->data; + + for (uint32_t i = 0; i < header->nb_bboxes; i++) { + InferenceItem *inference; + const AVDetectionBBox *bbox = av_get_detection_bbox(header, i); + + if (av_strncasecmp(bbox->detect_label, params->target, sizeof(bbox->detect_label)) != 0) { + continue; + } + + inference = av_malloc(sizeof(*inference)); + if (!inference) { + return DNN_ERROR; + } + task->inference_todo++; + inference->task = task; + inference->bbox_index = i; + if (ff_queue_push_back(inference_queue, inference) < 0) { + av_freep(&inference); + return DNN_ERROR; + } + } + return DNN_SUCCESS; + } default: av_assert0(!"should not reach here"); return DNN_ERROR; @@ -598,7 +694,7 @@ static DNNReturnType get_output_ov(void *model, const char *input_name, int inpu task.out_frame = out_frame; task.ov_model = ov_model; - if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) { + if 
(extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, NULL) != DNN_SUCCESS) { av_frame_free(&out_frame); av_frame_free(&in_frame); av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n"); @@ -690,6 +786,14 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams * return DNN_ERROR; } + if (model->func_type == DFT_ANALYTICS_CLASSIFY) { + // Once we add async support for tensorflow backend and native backend, + // we'll combine the two sync/async functions in dnn_interface.h to + // simplify the code in filter, and async will be an option within backends. + // so, do not support now, and classify filter will not call this function. + return DNN_ERROR; + } + if (ctx->options.batch_size > 1) { avpriv_report_missing_feature(ctx, "batch mode for sync execution"); return DNN_ERROR; @@ -710,7 +814,7 @@ DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNExecBaseParams * task.out_frame = exec_params->out_frame ? 
exec_params->out_frame : exec_params->in_frame; task.ov_model = ov_model; - if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue) != DNN_SUCCESS) { + if (extract_inference_from_task(ov_model->model->func_type, &task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n"); return DNN_ERROR; } @@ -730,6 +834,7 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa OVContext *ctx = &ov_model->ctx; RequestItem *request; TaskItem *task; + DNNReturnType ret; if (ff_check_exec_params(ctx, DNN_OV, model->func_type, exec_params) != 0) { return DNN_ERROR; @@ -761,23 +866,25 @@ DNNReturnType ff_dnn_execute_model_async_ov(const DNNModel *model, DNNExecBasePa return DNN_ERROR; } - if (extract_inference_from_task(ov_model->model->func_type, task, ov_model->inference_queue) != DNN_SUCCESS) { + if (extract_inference_from_task(model->func_type, task, ov_model->inference_queue, exec_params) != DNN_SUCCESS) { av_log(ctx, AV_LOG_ERROR, "unable to extract inference from task.\n"); return DNN_ERROR; } - if (ff_queue_size(ov_model->inference_queue) < ctx->options.batch_size) { - // not enough inference items queued for a batch - return DNN_SUCCESS; - } + while (ff_queue_size(ov_model->inference_queue) >= ctx->options.batch_size) { + request = ff_safe_queue_pop_front(ov_model->request_queue); + if (!request) { + av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); + return DNN_ERROR; + } - request = ff_safe_queue_pop_front(ov_model->request_queue); - if (!request) { - av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); - return DNN_ERROR; + ret = execute_model_ov(request, ov_model->inference_queue); + if (ret != DNN_SUCCESS) { + return ret; + } } - return execute_model_ov(request, ov_model->inference_queue); + return DNN_SUCCESS; } DNNAsyncStatusType ff_dnn_get_async_result_ov(const DNNModel *model, AVFrame **in, AVFrame **out) diff 
--git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c index e104cc5064..5f60d68078 100644 --- a/libavfilter/dnn/dnn_io_proc.c +++ b/libavfilter/dnn/dnn_io_proc.c @@ -22,6 +22,7 @@ #include "libavutil/imgutils.h" #include "libswscale/swscale.h" #include "libavutil/avassert.h" +#include "libavutil/detection_bbox.h" DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx) { @@ -175,6 +176,65 @@ static enum AVPixelFormat get_pixel_format(DNNData *data) return AV_PIX_FMT_BGR24; } +DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx) +{ + const AVPixFmtDescriptor *desc; + int offsetx[4], offsety[4]; + uint8_t *bbox_data[4]; + struct SwsContext *sws_ctx; + int linesizes[4]; + enum AVPixelFormat fmt; + int left, top, width, height; + const AVDetectionBBoxHeader *header; + const AVDetectionBBox *bbox; + AVFrameSideData *sd = av_frame_get_side_data(frame, AV_FRAME_DATA_DETECTION_BBOXES); + av_assert0(sd); + + header = (const AVDetectionBBoxHeader *)sd->data; + bbox = av_get_detection_bbox(header, bbox_index); + + left = bbox->x; + width = bbox->w; + top = bbox->y; + height = bbox->h; + + fmt = get_pixel_format(input); + sws_ctx = sws_getContext(width, height, frame->format, + input->width, input->height, fmt, + SWS_FAST_BILINEAR, NULL, NULL, NULL); + if (!sws_ctx) { + av_log(log_ctx, AV_LOG_ERROR, "Failed to create scale context for the conversion " + "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n", + av_get_pix_fmt_name(frame->format), width, height, + av_get_pix_fmt_name(fmt), input->width, input->height); + return DNN_ERROR; + } + + if (av_image_fill_linesizes(linesizes, fmt, input->width) < 0) { + av_log(log_ctx, AV_LOG_ERROR, "unable to get linesizes with av_image_fill_linesizes\n"); + sws_freeContext(sws_ctx); + return DNN_ERROR; + } + + desc = av_pix_fmt_desc_get(frame->format); + offsetx[1] = offsetx[2] = AV_CEIL_RSHIFT(left, desc->log2_chroma_w); + offsetx[0] = offsetx[3] = left; 
+ + offsety[1] = offsety[2] = AV_CEIL_RSHIFT(top, desc->log2_chroma_h); + offsety[0] = offsety[3] = top; + + for (int k = 0; frame->data[k]; k++) + bbox_data[k] = frame->data[k] + offsety[k] * frame->linesize[k] + offsetx[k]; + + sws_scale(sws_ctx, (const uint8_t *const *)&bbox_data, frame->linesize, + 0, height, + (uint8_t *const *)(&input->data), linesizes); + + sws_freeContext(sws_ctx); + + return DNN_SUCCESS; +} + static DNNReturnType proc_from_frame_to_dnn_analytics(AVFrame *frame, DNNData *input, void *log_ctx) { struct SwsContext *sws_ctx; diff --git a/libavfilter/dnn/dnn_io_proc.h b/libavfilter/dnn/dnn_io_proc.h index 91ad3cb261..16dcdd6d1a 100644 --- a/libavfilter/dnn/dnn_io_proc.h +++ b/libavfilter/dnn/dnn_io_proc.h @@ -32,5 +32,6 @@ DNNReturnType ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, DNNFunctionType func_type, void *log_ctx); DNNReturnType ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx); +DNNReturnType ff_frame_to_dnn_classify(AVFrame *frame, DNNData *input, uint32_t bbox_index, void *log_ctx); #endif diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c index c085884eb4..52c7a5392a 100644 --- a/libavfilter/dnn_filter_common.c +++ b/libavfilter/dnn_filter_common.c @@ -77,6 +77,12 @@ int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc) return 0; } +int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc) +{ + ctx->model->classify_post_proc = post_proc; + return 0; +} + DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input) { return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname); @@ -112,6 +118,21 @@ DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVF return (ctx->dnn_module->execute_model_async)(ctx->model, &exec_params); } +DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target) +{ + DNNExecClassificationParams 
class_params = { + { + .input_name = ctx->model_inputname, + .output_names = (const char **)&ctx->model_outputname, + .nb_output = 1, + .in_frame = in_frame, + .out_frame = out_frame, + }, + .target = target, + }; + return (ctx->dnn_module->execute_model_async)(ctx->model, &class_params.base); +} + DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame) { return (ctx->dnn_module->get_async_result)(ctx->model, in_frame, out_frame); diff --git a/libavfilter/dnn_filter_common.h b/libavfilter/dnn_filter_common.h index 8deb18b39a..e7736d2bac 100644 --- a/libavfilter/dnn_filter_common.h +++ b/libavfilter/dnn_filter_common.h @@ -50,10 +50,12 @@ typedef struct DnnContext { int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx); int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc); int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc); +int ff_dnn_set_classify_post_proc(DnnContext *ctx, ClassifyPostProc post_proc); DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input); DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height); DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame); DNNReturnType ff_dnn_execute_model_async(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame); +DNNReturnType ff_dnn_execute_model_classification(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame, char *target); DNNAsyncStatusType ff_dnn_get_async_result(DnnContext *ctx, AVFrame **in_frame, AVFrame **out_frame); DNNReturnType ff_dnn_flush(DnnContext *ctx); void ff_dnn_uninit(DnnContext *ctx); diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index 941670675d..799244ee14 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -52,7 +52,7 @@ typedef enum { DFT_NONE, DFT_PROCESS_FRAME, // process 
the whole frame DFT_ANALYTICS_DETECT, // detect from the whole frame - // we can add more such as detect_from_crop, classify_from_bbox, etc. + DFT_ANALYTICS_CLASSIFY, // classify for each bounding box }DNNFunctionType; typedef struct DNNData{ @@ -71,8 +71,14 @@ typedef struct DNNExecBaseParams { AVFrame *out_frame; } DNNExecBaseParams; +typedef struct DNNExecClassificationParams { + DNNExecBaseParams base; + const char *target; +} DNNExecClassificationParams; + typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx); typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx); +typedef int (*ClassifyPostProc)(AVFrame *frame, DNNData *output, uint32_t bbox_index, AVFilterContext *filter_ctx); typedef struct DNNModel{ // Stores model that can be different for different backends. @@ -97,6 +103,8 @@ typedef struct DNNModel{ FramePrePostProc frame_post_proc; // set the post process to interpret detect result from DNNData DetectPostProc detect_post_proc; + // set the post process to interpret classify result from DNNData + ClassifyPostProc classify_post_proc; } DNNModel; // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
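
Note on the bounding-box check: contain_valid_detection_bbox() rejects a frame as soon as one of its boxes falls outside the frame. Below is a minimal standalone sketch of that geometry check, not the patch's code. It assumes a simplified `Box` struct in place of AVDetectionBBox (hypothetical, geometry fields only) and two hypothetical helpers, `box_in_frame` and `boxes_all_valid`. The vertical extent is compared against the frame height. The side-data lookup and the classify_count limit are left out.

```c
#include <assert.h>
#include <stddef.h>

/* Hypothetical stand-in for AVDetectionBBox; only the geometry is kept. */
typedef struct Box {
    int x, y, w, h;
} Box;

/* Returns 1 if the box lies inside a frame_w x frame_h frame, else 0. */
static int box_in_frame(const Box *b, int frame_w, int frame_h)
{
    if (b->x < 0 || b->w < 0 || b->x + b->w >= frame_w)
        return 0;
    /* the vertical extent is checked against the height, not the width */
    if (b->y < 0 || b->h < 0 || b->y + b->h >= frame_h)
        return 0;
    return 1;
}

/* Returns 1 only if there is at least one box and every box is in the frame. */
static int boxes_all_valid(const Box *boxes, size_t n, int frame_w, int frame_h)
{
    assert(boxes || n == 0); /* a NULL array is only allowed when empty */
    if (n == 0)
        return 0;
    for (size_t i = 0; i < n; i++) {
        if (!box_in_frame(&boxes[i], frame_w, frame_h))
            return 0;
    }
    return 1;
}
```

For a 640x360 frame, a box with y = 300 and h = 200 is rejected by this sketch because 300 + 200 reaches past the height, even though 500 is still smaller than the width.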