diff mbox series

[FFmpeg-devel,3/4] lavfi/dnn_backend_tensorflow: support detect model

Message ID 20210430030711.30216-3-ting.fu@intel.com
State Accepted
Headers show
Series [FFmpeg-devel,1/4] dnn: add DCO_RGB color order to enum DNNColorOrder
Related show

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Ting Fu April 30, 2021, 3:07 a.m. UTC
Signed-off-by: Ting Fu <ting.fu@intel.com>
---
 libavfilter/dnn/dnn_backend_tf.c | 39 ++++++++++++++++++++++++++------
 libavfilter/vf_dnn_detect.c      | 32 +++++++++++++++++++++++++-
 2 files changed, 63 insertions(+), 8 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 5c85b562c4..8fb2ae8583 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -793,15 +793,40 @@  static DNNReturnType execute_model_tf(const DNNModel *model, const char *input_n
         outputs[i].data = TF_TensorData(output_tensors[i]);
         outputs[i].dt = TF_TensorType(output_tensors[i]);
     }
-    if (do_ioproc) {
-        if (tf_model->model->frame_post_proc != NULL) {
-            tf_model->model->frame_post_proc(out_frame, outputs, tf_model->model->filter_ctx);
+    switch (model->func_type) {
+    case DFT_PROCESS_FRAME:
+        //it only support 1 output if it's frame in & frame out
+        if (do_ioproc) {
+            if (tf_model->model->frame_post_proc != NULL) {
+                tf_model->model->frame_post_proc(out_frame, outputs, tf_model->model->filter_ctx);
+            } else {
+                ff_proc_from_dnn_to_frame(out_frame, outputs, ctx);
+            }
         } else {
-            ff_proc_from_dnn_to_frame(out_frame, outputs, ctx);
+            out_frame->width = outputs[0].width;
+            out_frame->height = outputs[0].height;
+        }
+        break;
+    case DFT_ANALYTICS_DETECT:
+        if (!model->detect_post_proc) {
+            av_log(ctx, AV_LOG_ERROR, "Detect filter needs provide post proc\n");
+            return DNN_ERROR;
+        }
+        model->detect_post_proc(out_frame, outputs, nb_output, model->filter_ctx);
+        break;
+    default:
+        for (uint32_t i = 0; i < nb_output; ++i) {
+            if (output_tensors[i]) {
+                TF_DeleteTensor(output_tensors[i]);
+            }
         }
-    } else {
-        out_frame->width = outputs[0].width;
-        out_frame->height = outputs[0].height;
+        TF_DeleteTensor(input_tensor);
+        av_freep(&output_tensors);
+        av_freep(&tf_outputs);
+        av_freep(&outputs);
+
+        av_log(ctx, AV_LOG_ERROR, "Tensorflow backend does not support this kind of dnn filter now\n");
+        return DNN_ERROR;
     }
 
     for (uint32_t i = 0; i < nb_output; ++i) {
diff --git a/libavfilter/vf_dnn_detect.c b/libavfilter/vf_dnn_detect.c
index 1dbe4f29a4..7d39acb653 100644
--- a/libavfilter/vf_dnn_detect.c
+++ b/libavfilter/vf_dnn_detect.c
@@ -203,10 +203,40 @@  static int read_detect_label_file(AVFilterContext *context)
     return 0;
 }
 
+static int check_output_nb(DnnDetectContext *ctx, DNNBackendType backend_type, int output_nb)
+{
+    switch(backend_type) {
+    case DNN_TF:
+        if (output_nb != 4) {
+            av_log(ctx, AV_LOG_ERROR, "Only support tensorflow detect model with 4 outputs, \
+                                       but get %d instead\n", output_nb);
+            return AVERROR(EINVAL);
+        }
+        return 0;
+    case DNN_OV:
+        if (output_nb != 1) {
+            av_log(ctx, AV_LOG_ERROR, "Dnn detect filter with openvino backend needs 1 output only, \
+                                       but get %d instead\n", output_nb);
+            return AVERROR(EINVAL);
+        }
+        return 0;
+    default:
+        avpriv_report_missing_feature(ctx, "Dnn detect filter does not support current backend\n");
+        return AVERROR(EINVAL);
+    }
+    return 0;
+}
+
 static av_cold int dnn_detect_init(AVFilterContext *context)
 {
     DnnDetectContext *ctx = context->priv;
-    int ret = ff_dnn_init(&ctx->dnnctx, DFT_ANALYTICS_DETECT, context);
+    DnnContext *dnn_ctx = &ctx->dnnctx;
+    int ret;
+
+    ret = ff_dnn_init(&ctx->dnnctx, DFT_ANALYTICS_DETECT, context);
+    if (ret < 0)
+        return ret;
+    ret = check_output_nb(ctx, dnn_ctx->backend_type, dnn_ctx->nb_outputs);
     if (ret < 0)
         return ret;
     ff_dnn_set_detect_post_proc(&ctx->dnnctx, dnn_detect_post_proc);