@@ -236,16 +236,32 @@ static void infer_completion_callback(void *args)
av_assert0(request->task_count >= 1);
for (int i = 0; i < request->task_count; ++i) {
task = request->tasks[i];
- if (task->do_ioproc) {
- if (task->ov_model->model->frame_post_proc != NULL) {
- task->ov_model->model->frame_post_proc(task->out_frame, &output, task->ov_model->model->filter_ctx);
+
+ switch (task->ov_model->model->func_type) {
+ case DFT_PROCESS_FRAME:
+ if (task->do_ioproc) {
+ if (task->ov_model->model->frame_post_proc != NULL) {
+ task->ov_model->model->frame_post_proc(task->out_frame, &output, task->ov_model->model->filter_ctx);
+ } else {
+ ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
+ }
} else {
- ff_proc_from_dnn_to_frame(task->out_frame, &output, ctx);
+ task->out_frame->width = output.width;
+ task->out_frame->height = output.height;
}
- } else {
- task->out_frame->width = output.width;
- task->out_frame->height = output.height;
+ break;
+ case DFT_ANALYTICS_DETECT:
+ if (!task->ov_model->model->detect_post_proc) {
+ av_log(ctx, AV_LOG_ERROR, "detect filter needs to provide post proc\n");
+ return;
+ }
+ task->ov_model->model->detect_post_proc(task->out_frame, &output, 1, task->ov_model->model->filter_ctx);
+ break;
+ default:
+ av_assert0(!"should not reach here");
+ break;
}
+
task->done = 1;
output.data = (uint8_t *)output.data
+ output.width * output.height * output.channels * get_datatype_size(output.dt);
@@ -71,6 +71,12 @@ int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePo
return 0;
}
+int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc)
+{
+ ctx->model->detect_post_proc = post_proc;
+ return 0;
+}
+
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input)
{
return ctx->model->get_input(ctx->model->model, input, ctx->model_inputname);
@@ -49,6 +49,7 @@ typedef struct DnnContext {
int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx);
int ff_dnn_set_frame_proc(DnnContext *ctx, FramePrePostProc pre_proc, FramePrePostProc post_proc);
+int ff_dnn_set_detect_post_proc(DnnContext *ctx, DetectPostProc post_proc);
DNNReturnType ff_dnn_get_input(DnnContext *ctx, DNNData *input);
DNNReturnType ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height);
DNNReturnType ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame);
@@ -64,6 +64,7 @@ typedef struct DNNData{
} DNNData;
typedef int (*FramePrePostProc)(AVFrame *frame, DNNData *model, AVFilterContext *filter_ctx);
+typedef int (*DetectPostProc)(AVFrame *frame, DNNData *output, uint32_t nb, AVFilterContext *filter_ctx);
typedef struct DNNModel{
// Stores model that can be different for different backends.
@@ -86,6 +87,8 @@ typedef struct DNNModel{
// set the post process to transfer data from DNNData to AVFrame
// the default implementation within DNN is used if it is not provided by the filter
FramePrePostProc frame_post_proc;
+ // set the post process to interpret detect result from DNNData
+ DetectPostProc detect_post_proc;
} DNNModel;
// Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.