diff mbox series

[FFmpeg-devel,WIP,v2,6/9] avfilter/dnn_backend_tf: Simplify memory allocation

Message ID tencent_F495C4C9098EA6A5441308E843F3E9F37005@qq.com
State New
Headers show
Series [FFmpeg-devel,WIP,v2,1/9] avfilter/dnn: Refactor DNN parameter configuration system | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished

Commit Message

Zhao Zhili April 28, 2024, 6:46 a.m. UTC
From: Zhao Zhili <zhilizhao@tencent.com>

---
 libavfilter/dnn/dnn_backend_tf.c | 33 +++++++++++++-------------------
 1 file changed, 13 insertions(+), 20 deletions(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index 3b4de6d13f..c7716e696d 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -37,8 +37,8 @@ 
 #include <tensorflow/c/c_api.h>
 
 typedef struct TFModel {
+    DNNModel model;
     DnnContext *ctx;
-    DNNModel *model;
     TF_Graph *graph;
     TF_Session *session;
     TF_Status *status;
@@ -518,7 +518,7 @@  static void dnn_free_model_tf(DNNModel **model)
         TF_DeleteStatus(tf_model->status);
     }
     av_freep(&tf_model);
-    av_freep(&model);
+    *model = NULL;
 }
 
 static DNNModel *dnn_load_model_tf(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
@@ -526,18 +526,11 @@  static DNNModel *dnn_load_model_tf(DnnContext *ctx, DNNFunctionType func_type, A
     DNNModel *model = NULL;
     TFModel *tf_model = NULL;
 
-    model = av_mallocz(sizeof(DNNModel));
-    if (!model){
-        return NULL;
-    }
-
     tf_model = av_mallocz(sizeof(TFModel));
-    if (!tf_model){
-        av_freep(&model);
+    if (!tf_model)
         return NULL;
-    }
+    model = &tf_model->model;
     model->model = tf_model;
-    tf_model->model = model;
     tf_model->ctx = ctx;
 
     if (load_tf_model(tf_model, ctx->model_filename) != 0){
@@ -650,11 +643,11 @@  static int fill_model_input_tf(TFModel *tf_model, TFRequestItem *request) {
     }
     input.data = (float *)TF_TensorData(infer_request->input_tensor);
 
-    switch (tf_model->model->func_type) {
+    switch (tf_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         if (task->do_ioproc) {
-            if (tf_model->model->frame_pre_proc != NULL) {
-                tf_model->model->frame_pre_proc(task->in_frame, &input, tf_model->model->filter_ctx);
+            if (tf_model->model.frame_pre_proc != NULL) {
+                tf_model->model.frame_pre_proc(task->in_frame, &input, tf_model->model.filter_ctx);
             } else {
                 ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
             }
@@ -664,7 +657,7 @@  static int fill_model_input_tf(TFModel *tf_model, TFRequestItem *request) {
         ff_frame_to_dnn_detect(task->in_frame, &input, ctx);
         break;
     default:
-        avpriv_report_missing_feature(ctx, "model function type %d", tf_model->model->func_type);
+        avpriv_report_missing_feature(ctx, "model function type %d", tf_model->model.func_type);
         break;
     }
 
@@ -724,12 +717,12 @@  static void infer_completion_callback(void *args) {
         outputs[i].data = TF_TensorData(infer_request->output_tensors[i]);
         outputs[i].dt = (DNNDataType)TF_TensorType(infer_request->output_tensors[i]);
     }
-    switch (tf_model->model->func_type) {
+    switch (tf_model->model.func_type) {
     case DFT_PROCESS_FRAME:
         //it only support 1 output if it's frame in & frame out
         if (task->do_ioproc) {
-            if (tf_model->model->frame_post_proc != NULL) {
-                tf_model->model->frame_post_proc(task->out_frame, outputs, tf_model->model->filter_ctx);
+            if (tf_model->model.frame_post_proc != NULL) {
+                tf_model->model.frame_post_proc(task->out_frame, outputs, tf_model->model.filter_ctx);
             } else {
                 ff_proc_from_dnn_to_frame(task->out_frame, outputs, ctx);
             }
@@ -741,11 +734,11 @@  static void infer_completion_callback(void *args) {
         }
         break;
     case DFT_ANALYTICS_DETECT:
-        if (!tf_model->model->detect_post_proc) {
+        if (!tf_model->model.detect_post_proc) {
             av_log(ctx, AV_LOG_ERROR, "Detect filter needs provide post proc\n");
             return;
         }
-        tf_model->model->detect_post_proc(task->in_frame, outputs, task->nb_output, tf_model->model->filter_ctx);
+        tf_model->model.detect_post_proc(task->in_frame, outputs, task->nb_output, tf_model->model.filter_ctx);
         break;
     default:
         av_log(ctx, AV_LOG_ERROR, "Tensorflow backend does not support this kind of dnn filter now\n");