diff mbox series

[FFmpeg-devel,2/4] dnn: change dnn interface to replace DNNData* with AVFrame*

Message ID 20200914062841.22082-1-yejun.guo@intel.com
State New
Headers show
Series [FFmpeg-devel,1/4] dnn: add userdata for load model parameter
Related show

Checks

Context Check Description
andriy/make success Make finished
andriy/make_fate success Make fate finished

Commit Message

Guo, Yejun Sept. 14, 2020, 6:28 a.m. UTC
Currently, every filter needs to provide code to transfer data from
AVFrame* to model input (DNNData*), and also from model output
(DNNData*) to AVFrame*. Actually, such transfer can be implemented
within DNN module, and so filter can focus on its own business logic.

DNN module also exports the function pointer pre_proc and post_proc
in struct DNNModel, just in case that a filter has its special logic
to transfer data between AVFrame* and DNNData*. The default implementation
within DNN module is used if the filter does not set pre/post_proc.

Signed-off-by: Guo, Yejun <yejun.guo@intel.com>
---
 configure                              |   2 +-
 libavfilter/dnn/Makefile               |   1 +
 libavfilter/dnn/dnn_backend_native.c   |  53 ++++--
 libavfilter/dnn/dnn_backend_native.h   |   3 +-
 libavfilter/dnn/dnn_backend_openvino.c |  71 +++++---
 libavfilter/dnn/dnn_backend_openvino.h |   2 +-
 libavfilter/dnn/dnn_backend_tf.c       |  90 ++++++----
 libavfilter/dnn/dnn_backend_tf.h       |   2 +-
 libavfilter/dnn/dnn_io_proc.c          | 135 ++++++++++++++
 libavfilter/dnn/dnn_io_proc.h          |  36 ++++
 libavfilter/dnn_interface.h            |  17 +-
 libavfilter/vf_derain.c                |  59 ++----
 libavfilter/vf_dnn_processing.c        | 240 +++++--------------------
 libavfilter/vf_sr.c                    | 166 +++++++----------
 14 files changed, 451 insertions(+), 426 deletions(-)
 create mode 100644 libavfilter/dnn/dnn_io_proc.c
 create mode 100644 libavfilter/dnn/dnn_io_proc.h
diff mbox series

Patch

diff --git a/configure b/configure
index 5d68695192..39fabb4ad5 100755
--- a/configure
+++ b/configure
@@ -2628,6 +2628,7 @@  cbs_vp9_select="cbs"
 dct_select="rdft"
 dirac_parse_select="golomb"
 dnn_suggest="libtensorflow libopenvino"
+dnn_deps="swscale"
 error_resilience_select="me_cmp"
 faandct_deps="faan"
 faandct_select="fdctdsp"
@@ -3532,7 +3533,6 @@  derain_filter_select="dnn"
 deshake_filter_select="pixelutils"
 deshake_opencl_filter_deps="opencl"
 dilation_opencl_filter_deps="opencl"
-dnn_processing_filter_deps="swscale"
 dnn_processing_filter_select="dnn"
 drawtext_filter_deps="libfreetype"
 drawtext_filter_suggest="libfontconfig libfribidi"
diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile
index e0957073ee..ee08cc5243 100644
--- a/libavfilter/dnn/Makefile
+++ b/libavfilter/dnn/Makefile
@@ -1,4 +1,5 @@ 
 OBJS-$(CONFIG_DNN)                           += dnn/dnn_interface.o
+OBJS-$(CONFIG_DNN)                           += dnn/dnn_io_proc.o
 OBJS-$(CONFIG_DNN)                           += dnn/dnn_backend_native.o
 OBJS-$(CONFIG_DNN)                           += dnn/dnn_backend_native_layers.o
 OBJS-$(CONFIG_DNN)                           += dnn/dnn_backend_native_layer_avgpool.o
diff --git a/libavfilter/dnn/dnn_backend_native.c b/libavfilter/dnn/dnn_backend_native.c
index 830ec19c80..14e878b6b8 100644
--- a/libavfilter/dnn/dnn_backend_native.c
+++ b/libavfilter/dnn/dnn_backend_native.c
@@ -27,6 +27,7 @@ 
 #include "libavutil/avassert.h"
 #include "dnn_backend_native_layer_conv2d.h"
 #include "dnn_backend_native_layers.h"
+#include "dnn_io_proc.h"
 
 #define OFFSET(x) offsetof(NativeContext, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM
@@ -69,11 +70,12 @@  static DNNReturnType get_input_native(void *model, DNNData *input, const char *i
     return DNN_ERROR;
 }
 
-static DNNReturnType set_input_native(void *model, DNNData *input, const char *input_name)
+static DNNReturnType set_input_native(void *model, AVFrame *frame, const char *input_name)
 {
     NativeModel *native_model = (NativeModel *)model;
     NativeContext *ctx = &native_model->ctx;
     DnnOperand *oprd = NULL;
+    DNNData input;
 
     if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
         av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
@@ -97,10 +99,8 @@  static DNNReturnType set_input_native(void *model, DNNData *input, const char *i
         return DNN_ERROR;
     }
 
-    oprd->dims[0] = 1;
-    oprd->dims[1] = input->height;
-    oprd->dims[2] = input->width;
-    oprd->dims[3] = input->channels;
+    oprd->dims[1] = frame->height;
+    oprd->dims[2] = frame->width;
 
     av_freep(&oprd->data);
     oprd->length = calculate_operand_data_length(oprd);
@@ -114,7 +114,16 @@  static DNNReturnType set_input_native(void *model, DNNData *input, const char *i
         return DNN_ERROR;
     }
 
-    input->data = oprd->data;
+    input.height = oprd->dims[1];
+    input.width = oprd->dims[2];
+    input.channels = oprd->dims[3];
+    input.data = oprd->data;
+    input.dt = oprd->data_type;
+    if (native_model->model->pre_proc != NULL) {
+        native_model->model->pre_proc(frame, &input, native_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(frame, &input, ctx);
+    }
 
     return DNN_SUCCESS;
 }
@@ -185,6 +194,7 @@  DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *optio
     if (av_opt_set_from_string(&native_model->ctx, model->options, NULL, "=", "&") < 0)
         goto fail;
     model->model = (void *)native_model;
+    native_model->model = model;
 
 #if !HAVE_PTHREAD_CANCEL
     if (native_model->ctx.options.conv2d_threads > 1){
@@ -275,11 +285,19 @@  fail:
     return NULL;
 }
 
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     NativeModel *native_model = (NativeModel *)model->model;
     NativeContext *ctx = &native_model->ctx;
     int32_t layer;
+    DNNData output;
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just pending the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+        return DNN_ERROR;
+    }
 
     if (native_model->layers_num <= 0 || native_model->operands_num <= 0) {
         av_log(ctx, AV_LOG_ERROR, "No operands or layers in model\n");
@@ -317,11 +335,22 @@  DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *output
             return DNN_ERROR;
         }
 
-        outputs[i].data = oprd->data;
-        outputs[i].height = oprd->dims[1];
-        outputs[i].width = oprd->dims[2];
-        outputs[i].channels = oprd->dims[3];
-        outputs[i].dt = oprd->data_type;
+        output.data = oprd->data;
+        output.height = oprd->dims[1];
+        output.width = oprd->dims[2];
+        output.channels = oprd->dims[3];
+        output.dt = oprd->data_type;
+
+        if (out_frame->width != output.width || out_frame->height != output.height) {
+            out_frame->width = output.width;
+            out_frame->height = output.height;
+        } else {
+            if (native_model->model->post_proc != NULL) {
+                native_model->model->post_proc(out_frame, &output, native_model->model->userdata);
+            } else {
+                proc_from_dnn_to_frame(out_frame, &output, ctx);
+            }
+        }
     }
 
     return DNN_SUCCESS;
diff --git a/libavfilter/dnn/dnn_backend_native.h b/libavfilter/dnn/dnn_backend_native.h
index 33634118a8..553438bd22 100644
--- a/libavfilter/dnn/dnn_backend_native.h
+++ b/libavfilter/dnn/dnn_backend_native.h
@@ -119,6 +119,7 @@  typedef struct NativeContext {
 // Represents simple feed-forward convolutional network.
 typedef struct NativeModel{
     NativeContext ctx;
+    DNNModel *model;
     Layer *layers;
     int32_t layers_num;
     DnnOperand *operands;
@@ -127,7 +128,7 @@  typedef struct NativeModel{
 
 DNNModel *ff_dnn_load_model_native(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+DNNReturnType ff_dnn_execute_model_native(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_native(DNNModel **model);
 
diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c
index 01e1a1d4c8..b1bad3f659 100644
--- a/libavfilter/dnn/dnn_backend_openvino.c
+++ b/libavfilter/dnn/dnn_backend_openvino.c
@@ -24,6 +24,7 @@ 
  */
 
 #include "dnn_backend_openvino.h"
+#include "dnn_io_proc.h"
 #include "libavformat/avio.h"
 #include "libavutil/avassert.h"
 #include "libavutil/opt.h"
@@ -42,6 +43,7 @@  typedef struct OVContext {
 
 typedef struct OVModel{
     OVContext ctx;
+    DNNModel *model;
     ie_core_t *core;
     ie_network_t *network;
     ie_executable_network_t *exe_network;
@@ -131,7 +133,7 @@  static DNNReturnType get_input_ov(void *model, DNNData *input, const char *input
     return DNN_ERROR;
 }
 
-static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input_name)
+static DNNReturnType set_input_ov(void *model, AVFrame *frame, const char *input_name)
 {
     OVModel *ov_model = (OVModel *)model;
     OVContext *ctx = &ov_model->ctx;
@@ -139,10 +141,7 @@  static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input
     dimensions_t dims;
     precision_e precision;
     ie_blob_buffer_t blob_buffer;
-
-    status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request);
-    if (status != OK)
-        goto err;
+    DNNData input;
 
     status = ie_infer_request_get_blob(ov_model->infer_request, input_name, &ov_model->input_blob);
     if (status != OK)
@@ -153,23 +152,26 @@  static DNNReturnType set_input_ov(void *model, DNNData *input, const char *input
     if (status != OK)
         goto err;
 
-    av_assert0(input->channels == dims.dims[1]);
-    av_assert0(input->height   == dims.dims[2]);
-    av_assert0(input->width    == dims.dims[3]);
-    av_assert0(input->dt       == precision_to_datatype(precision));
-
     status = ie_blob_get_buffer(ov_model->input_blob, &blob_buffer);
     if (status != OK)
         goto err;
-    input->data = blob_buffer.buffer;
+
+    input.height = dims.dims[2];
+    input.width = dims.dims[3];
+    input.channels = dims.dims[1];
+    input.data = blob_buffer.buffer;
+    input.dt = precision_to_datatype(precision);
+    if (ov_model->model->pre_proc != NULL) {
+        ov_model->model->pre_proc(frame, &input, ov_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(frame, &input, ctx);
+    }
 
     return DNN_SUCCESS;
 
 err:
     if (ov_model->input_blob)
         ie_blob_free(&ov_model->input_blob);
-    if (ov_model->infer_request)
-        ie_infer_request_free(&ov_model->infer_request);
     av_log(ctx, AV_LOG_ERROR, "Failed to create inference instance or get input data/dims/precision/memory\n");
     return DNN_ERROR;
 }
@@ -184,7 +186,7 @@  DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
     ie_config_t config = {NULL, NULL, NULL};
     ie_available_devices_t a_dev;
 
-    model = av_malloc(sizeof(DNNModel));
+    model = av_mallocz(sizeof(DNNModel));
     if (!model){
         return NULL;
     }
@@ -192,6 +194,7 @@  DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
     ov_model = av_mallocz(sizeof(OVModel));
     if (!ov_model)
         goto err;
+    ov_model->model = model;
     ov_model->ctx.class = &dnn_openvino_class;
     ctx = &ov_model->ctx;
 
@@ -226,6 +229,10 @@  DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options,
         goto err;
     }
 
+    status = ie_exec_network_create_infer_request(ov_model->exe_network, &ov_model->infer_request);
+    if (status != OK)
+        goto err;
+
     model->model = (void *)ov_model;
     model->set_input = &set_input_ov;
     model->get_input = &get_input_ov;
@@ -238,6 +245,8 @@  err:
     if (model)
         av_freep(&model);
     if (ov_model) {
+        if (ov_model->infer_request)
+            ie_infer_request_free(&ov_model->infer_request);
         if (ov_model->exe_network)
             ie_exec_network_free(&ov_model->exe_network);
         if (ov_model->network)
@@ -249,7 +258,7 @@  err:
     return NULL;
 }
 
-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     char *model_output_name = NULL;
     char *all_output_names = NULL;
@@ -258,8 +267,18 @@  DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, c
     ie_blob_buffer_t blob_buffer;
     OVModel *ov_model = (OVModel *)model->model;
     OVContext *ctx = &ov_model->ctx;
-    IEStatusCode status = ie_infer_request_infer(ov_model->infer_request);
+    IEStatusCode status;
     size_t model_output_count = 0;
+    DNNData output;
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just pending the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+        return DNN_ERROR;
+    }
+
+    status = ie_infer_request_infer(ov_model->infer_request);
     if (status != OK) {
         av_log(ctx, AV_LOG_ERROR, "Failed to start synchronous model inference\n");
         return DNN_ERROR;
@@ -296,11 +315,21 @@  DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, c
             return DNN_ERROR;
         }
 
-        outputs[i].channels = dims.dims[1];
-        outputs[i].height   = dims.dims[2];
-        outputs[i].width    = dims.dims[3];
-        outputs[i].dt       = precision_to_datatype(precision);
-        outputs[i].data     = blob_buffer.buffer;
+        output.channels = dims.dims[1];
+        output.height   = dims.dims[2];
+        output.width    = dims.dims[3];
+        output.dt       = precision_to_datatype(precision);
+        output.data     = blob_buffer.buffer;
+        if (out_frame->width != output.width || out_frame->height != output.height) {
+            out_frame->width = output.width;
+            out_frame->height = output.height;
+        } else {
+            if (ov_model->model->post_proc != NULL) {
+                ov_model->model->post_proc(out_frame, &output, ov_model->model->userdata);
+            } else {
+                proc_from_dnn_to_frame(out_frame, &output, ctx);
+            }
+        }
     }
 
     return DNN_SUCCESS;
diff --git a/libavfilter/dnn/dnn_backend_openvino.h b/libavfilter/dnn/dnn_backend_openvino.h
index f69bc5ca0c..efb349cb49 100644
--- a/libavfilter/dnn/dnn_backend_openvino.h
+++ b/libavfilter/dnn/dnn_backend_openvino.h
@@ -31,7 +31,7 @@ 
 
 DNNModel *ff_dnn_load_model_ov(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+DNNReturnType ff_dnn_execute_model_ov(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_ov(DNNModel **model);
 
diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c
index bac7d8c420..c2d8c06931 100644
--- a/libavfilter/dnn/dnn_backend_tf.c
+++ b/libavfilter/dnn/dnn_backend_tf.c
@@ -31,6 +31,7 @@ 
 #include "libavutil/avassert.h"
 #include "dnn_backend_native_layer_pad.h"
 #include "dnn_backend_native_layer_maximum.h"
+#include "dnn_io_proc.h"
 
 #include <tensorflow/c/c_api.h>
 
@@ -40,13 +41,12 @@  typedef struct TFContext {
 
 typedef struct TFModel{
     TFContext ctx;
+    DNNModel *model;
     TF_Graph *graph;
     TF_Session *session;
     TF_Status *status;
     TF_Output input;
     TF_Tensor *input_tensor;
-    TF_Tensor **output_tensors;
-    uint32_t nb_output;
 } TFModel;
 
 static const AVClass dnn_tensorflow_class = {
@@ -152,13 +152,19 @@  static DNNReturnType get_input_tf(void *model, DNNData *input, const char *input
     return DNN_SUCCESS;
 }
 
-static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input_name)
+static DNNReturnType set_input_tf(void *model, AVFrame *frame, const char *input_name)
 {
     TFModel *tf_model = (TFModel *)model;
     TFContext *ctx = &tf_model->ctx;
+    DNNData input;
     TF_SessionOptions *sess_opts;
     const TF_Operation *init_op = TF_GraphOperationByName(tf_model->graph, "init");
 
+    if (get_input_tf(model, &input, input_name) != DNN_SUCCESS)
+        return DNN_ERROR;
+    input.height = frame->height;
+    input.width = frame->width;
+
     // Input operation
     tf_model->input.oper = TF_GraphOperationByName(tf_model->graph, input_name);
     if (!tf_model->input.oper){
@@ -169,12 +175,18 @@  static DNNReturnType set_input_tf(void *model, DNNData *input, const char *input
     if (tf_model->input_tensor){
         TF_DeleteTensor(tf_model->input_tensor);
     }
-    tf_model->input_tensor = allocate_input_tensor(input);
+    tf_model->input_tensor = allocate_input_tensor(&input);
     if (!tf_model->input_tensor){
         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for input tensor\n");
         return DNN_ERROR;
     }
-    input->data = (float *)TF_TensorData(tf_model->input_tensor);
+    input.data = (float *)TF_TensorData(tf_model->input_tensor);
+
+    if (tf_model->model->pre_proc != NULL) {
+        tf_model->model->pre_proc(frame, &input, tf_model->model->userdata);
+    } else {
+        proc_from_frame_to_dnn(frame, &input, ctx);
+    }
 
     // session
     if (tf_model->session){
@@ -591,7 +603,7 @@  DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     DNNModel *model = NULL;
     TFModel *tf_model = NULL;
 
-    model = av_malloc(sizeof(DNNModel));
+    model = av_mallocz(sizeof(DNNModel));
     if (!model){
         return NULL;
     }
@@ -602,6 +614,7 @@  DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
         return NULL;
     }
     tf_model->ctx.class = &dnn_tensorflow_class;
+    tf_model->model = model;
 
     if (load_tf_model(tf_model, model_filename) != DNN_SUCCESS){
         if (load_native_model(tf_model, model_filename) != DNN_SUCCESS){
@@ -621,11 +634,20 @@  DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options,
     return model;
 }
 
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output)
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame)
 {
     TF_Output *tf_outputs;
     TFModel *tf_model = (TFModel *)model->model;
     TFContext *ctx = &tf_model->ctx;
+    DNNData output;
+    TF_Tensor **output_tensors;
+
+    if (nb_output != 1) {
+        // currently, the filter does not need multiple outputs,
+        // so we just pending the support until we really need it.
+        av_log(ctx, AV_LOG_ERROR, "do not support multiple outputs\n");
+        return DNN_ERROR;
+    }
 
     tf_outputs = av_malloc_array(nb_output, sizeof(*tf_outputs));
     if (tf_outputs == NULL) {
@@ -633,18 +655,8 @@  DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c
         return DNN_ERROR;
     }
 
-    if (tf_model->output_tensors) {
-        for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
-            if (tf_model->output_tensors[i]) {
-                TF_DeleteTensor(tf_model->output_tensors[i]);
-                tf_model->output_tensors[i] = NULL;
-            }
-        }
-    }
-    av_freep(&tf_model->output_tensors);
-    tf_model->nb_output = nb_output;
-    tf_model->output_tensors = av_mallocz_array(nb_output, sizeof(*tf_model->output_tensors));
-    if (!tf_model->output_tensors) {
+    output_tensors = av_mallocz_array(nb_output, sizeof(*output_tensors));
+    if (!output_tensors) {
         av_freep(&tf_outputs);
         av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for output tensor\n"); \
         return DNN_ERROR;
@@ -654,6 +666,7 @@  DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c
         tf_outputs[i].oper = TF_GraphOperationByName(tf_model->graph, output_names[i]);
         if (!tf_outputs[i].oper) {
             av_freep(&tf_outputs);
+            av_freep(&output_tensors);
             av_log(ctx, AV_LOG_ERROR, "Could not find output \"%s\" in model\n", output_names[i]); \
             return DNN_ERROR;
         }
@@ -662,22 +675,40 @@  DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, c
 
     TF_SessionRun(tf_model->session, NULL,
                   &tf_model->input, &tf_model->input_tensor, 1,
-                  tf_outputs, tf_model->output_tensors, nb_output,
+                  tf_outputs, output_tensors, nb_output,
                   NULL, 0, NULL, tf_model->status);
     if (TF_GetCode(tf_model->status) != TF_OK) {
         av_freep(&tf_outputs);
+        av_freep(&output_tensors);
         av_log(ctx, AV_LOG_ERROR, "Failed to run session when executing model\n");
         return DNN_ERROR;
     }
 
     for (uint32_t i = 0; i < nb_output; ++i) {
-        outputs[i].height = TF_Dim(tf_model->output_tensors[i], 1);
-        outputs[i].width = TF_Dim(tf_model->output_tensors[i], 2);
-        outputs[i].channels = TF_Dim(tf_model->output_tensors[i], 3);
-        outputs[i].data = TF_TensorData(tf_model->output_tensors[i]);
-        outputs[i].dt = TF_TensorType(tf_model->output_tensors[i]);
+        output.height = TF_Dim(output_tensors[i], 1);
+        output.width = TF_Dim(output_tensors[i], 2);
+        output.channels = TF_Dim(output_tensors[i], 3);
+        output.data = TF_TensorData(output_tensors[i]);
+        output.dt = TF_TensorType(output_tensors[i]);
+
+        if (out_frame->width != output.width || out_frame->height != output.height) {
+            out_frame->width = output.width;
+            out_frame->height = output.height;
+        } else {
+            if (tf_model->model->post_proc != NULL) {
+                tf_model->model->post_proc(out_frame, &output, tf_model->model->userdata);
+            } else {
+                proc_from_dnn_to_frame(out_frame, &output, ctx);
+            }
+        }
     }
 
+    for (uint32_t i = 0; i < nb_output; ++i) {
+        if (output_tensors[i]) {
+            TF_DeleteTensor(output_tensors[i]);
+        }
+    }
+    av_freep(&output_tensors);
     av_freep(&tf_outputs);
     return DNN_SUCCESS;
 }
@@ -701,15 +732,6 @@  void ff_dnn_free_model_tf(DNNModel **model)
         if (tf_model->input_tensor){
             TF_DeleteTensor(tf_model->input_tensor);
         }
-        if (tf_model->output_tensors) {
-            for (uint32_t i = 0; i < tf_model->nb_output; ++i) {
-                if (tf_model->output_tensors[i]) {
-                    TF_DeleteTensor(tf_model->output_tensors[i]);
-                    tf_model->output_tensors[i] = NULL;
-                }
-            }
-        }
-        av_freep(&tf_model->output_tensors);
         av_freep(&tf_model);
         av_freep(model);
     }
diff --git a/libavfilter/dnn/dnn_backend_tf.h b/libavfilter/dnn/dnn_backend_tf.h
index 1cf5cc9e76..f379e83d8d 100644
--- a/libavfilter/dnn/dnn_backend_tf.h
+++ b/libavfilter/dnn/dnn_backend_tf.h
@@ -31,7 +31,7 @@ 
 
 DNNModel *ff_dnn_load_model_tf(const char *model_filename, const char *options, void *userdata);
 
-DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+DNNReturnType ff_dnn_execute_model_tf(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
 
 void ff_dnn_free_model_tf(DNNModel **model);
 
diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c
new file mode 100644
index 0000000000..8ce1959b42
--- /dev/null
+++ b/libavfilter/dnn/dnn_io_proc.c
@@ -0,0 +1,135 @@ 
+/*
+ * Copyright (c) 2020
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+#include "dnn_io_proc.h"
+#include "libavutil/imgutils.h"
+#include "libswscale/swscale.h"
+
+DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
+{
+    struct SwsContext *sws_ctx;
+    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+    if (output->dt != DNN_FLOAT) {
+        av_log(log_ctx, AV_LOG_ERROR, "do not support data type rather than DNN_FLOAT\n");
+        return DNN_ERROR;
+    }
+
+    switch (frame->format) {
+    case AV_PIX_FMT_RGB24:
+    case AV_PIX_FMT_BGR24:
+        sws_ctx = sws_getContext(frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
+                           (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, 0, frame->height,
+                           (uint8_t * const*)frame->data, frame->linesize);
+        sws_freeContext(sws_ctx);
+        return DNN_SUCCESS;
+    case AV_PIX_FMT_GRAYF32:
+        av_image_copy_plane(frame->data[0], frame->linesize[0],
+                            output->data, bytewidth,
+                            bytewidth, frame->height);
+        return DNN_SUCCESS;
+    case AV_PIX_FMT_YUV420P:
+    case AV_PIX_FMT_YUV422P:
+    case AV_PIX_FMT_YUV444P:
+    case AV_PIX_FMT_YUV410P:
+    case AV_PIX_FMT_YUV411P:
+    case AV_PIX_FMT_GRAY8:
+        sws_ctx = sws_getContext(frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
+                           (const int[4]){frame->width * sizeof(float), 0, 0, 0}, 0, frame->height,
+                           (uint8_t * const*)frame->data, frame->linesize);
+        sws_freeContext(sws_ctx);
+        return DNN_SUCCESS;
+    default:
+        av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format);
+        return DNN_ERROR;
+    }
+
+    return DNN_SUCCESS;
+}
+
+DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
+{
+    struct SwsContext *sws_ctx;
+    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+    if (input->dt != DNN_FLOAT) {
+        av_log(log_ctx, AV_LOG_ERROR, "do not support data type rather than DNN_FLOAT\n");
+        return DNN_ERROR;
+    }
+
+    switch (frame->format) {
+    case AV_PIX_FMT_RGB24:
+    case AV_PIX_FMT_BGR24:
+        sws_ctx = sws_getContext(frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 frame->width * 3,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t **)frame->data,
+                           frame->linesize, 0, frame->height,
+                           (uint8_t * const*)(&input->data),
+                           (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
+        sws_freeContext(sws_ctx);
+        break;
+    case AV_PIX_FMT_GRAYF32:
+        av_image_copy_plane(input->data, bytewidth,
+                            frame->data[0], frame->linesize[0],
+                            bytewidth, frame->height);
+        break;
+    case AV_PIX_FMT_YUV420P:
+    case AV_PIX_FMT_YUV422P:
+    case AV_PIX_FMT_YUV444P:
+    case AV_PIX_FMT_YUV410P:
+    case AV_PIX_FMT_YUV411P:
+    case AV_PIX_FMT_GRAY8:
+        sws_ctx = sws_getContext(frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAY8,
+                                 frame->width,
+                                 frame->height,
+                                 AV_PIX_FMT_GRAYF32,
+                                 0, NULL, NULL, NULL);
+        sws_scale(sws_ctx, (const uint8_t **)frame->data,
+                           frame->linesize, 0, frame->height,
+                           (uint8_t * const*)(&input->data),
+                           (const int [4]){frame->width * sizeof(float), 0, 0, 0});
+        sws_freeContext(sws_ctx);
+        break;
+    default:
+        av_log(log_ctx, AV_LOG_ERROR, "do not support frame format %d\n", frame->format);
+        return DNN_ERROR;
+    }
+
+    return DNN_SUCCESS;
+}
diff --git a/libavfilter/dnn/dnn_io_proc.h b/libavfilter/dnn/dnn_io_proc.h
new file mode 100644
index 0000000000..4c7dc7c1a2
--- /dev/null
+++ b/libavfilter/dnn/dnn_io_proc.h
@@ -0,0 +1,36 @@ 
+/*
+ * Copyright (c) 2020
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+/**
+ * @file
+ * DNN input&output process between AVFrame and DNNData.
+ */
+
+
+#ifndef AVFILTER_DNN_DNN_IO_PROC_H
+#define AVFILTER_DNN_DNN_IO_PROC_H
+
+#include "../dnn_interface.h"
+#include "libavutil/frame.h"
+
+DNNReturnType proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx);
+DNNReturnType proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx);
+
+#endif
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 702c8306e0..6debc50607 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -27,6 +27,7 @@ 
 #define AVFILTER_DNN_INTERFACE_H
 
 #include <stdint.h>
+#include "libavutil/frame.h"
 
 typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
 
@@ -50,17 +51,23 @@  typedef struct DNNModel{
     // Gets model input information
     // Just reuse struct DNNData here, actually the DNNData.data field is not needed.
     DNNReturnType (*get_input)(void *model, DNNData *input, const char *input_name);
-    // Sets model input and output.
-    // Should be called at least once before model execution.
-    DNNReturnType (*set_input)(void *model, DNNData *input, const char *input_name);
+    // Sets model input.
+    // Should be called every time before model execution.
+    DNNReturnType (*set_input)(void *model, AVFrame *frame, const char *input_name);
+    // set the pre process to transfer data from AVFrame to DNNData
+    // the default implementation within DNN is used if it is not provided by the filter
+    int (*pre_proc)(AVFrame *frame_in, DNNData *model_input, void *user_data);
+    // set the post process to transfer data from DNNData to AVFrame
+    // the default implementation within DNN is used if it is not provided by the filter
+    int (*post_proc)(AVFrame *frame_out, DNNData *model_output, void *user_data);
 } DNNModel;
 
 // Stores pointers to functions for loading, executing, freeing DNN models for one of the backends.
 typedef struct DNNModule{
     // Loads model and parameters from given file. Returns NULL if it is not possible.
     DNNModel *(*load_model)(const char *model_filename, const char *options, void *userdata);
-    // Executes model with specified input and output. Returns DNN_ERROR otherwise.
-    DNNReturnType (*execute_model)(const DNNModel *model, DNNData *outputs, const char **output_names, uint32_t nb_output);
+    // Executes model with specified output. Returns DNN_ERROR otherwise.
+    DNNReturnType (*execute_model)(const DNNModel *model, const char **output_names, uint32_t nb_output, AVFrame *out_frame);
     // Frees memory allocated for model.
     void (*free_model)(DNNModel **model);
 } DNNModule;
diff --git a/libavfilter/vf_derain.c b/libavfilter/vf_derain.c
index c251d55ee7..a59cd6e941 100644
--- a/libavfilter/vf_derain.c
+++ b/libavfilter/vf_derain.c
@@ -39,11 +39,8 @@  typedef struct DRContext {
     DNNBackendType     backend_type;
     DNNModule         *dnn_module;
     DNNModel          *model;
-    DNNData            input;
-    DNNData            output;
 } DRContext;
 
-#define CLIP(x, min, max) (x < min ? min : (x > max ? max : x))
 #define OFFSET(x) offsetof(DRContext, x)
 #define FLAGS AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM
 static const AVOption derain_options[] = {
@@ -74,25 +71,6 @@  static int query_formats(AVFilterContext *ctx)
     return ff_set_common_formats(ctx, formats);
 }
 
-static int config_inputs(AVFilterLink *inlink)
-{
-    AVFilterContext *ctx          = inlink->dst;
-    DRContext *dr_context         = ctx->priv;
-    DNNReturnType result;
-
-    dr_context->input.width    = inlink->w;
-    dr_context->input.height   = inlink->h;
-    dr_context->input.channels = 3;
-
-    result = (dr_context->model->set_input)(dr_context->model->model, &dr_context->input, "x");
-    if (result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
-        return AVERROR(EIO);
-    }
-
-    return 0;
-}
-
 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 {
     AVFilterContext *ctx  = inlink->dst;
@@ -100,43 +78,30 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     DRContext *dr_context = ctx->priv;
     DNNReturnType dnn_result;
     const char *model_output_name = "y";
+    AVFrame *out;
 
-    AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
+    dnn_result = (dr_context->model->set_input)(dr_context->model->model, in, "x");
+    if (dnn_result != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
+        av_frame_free(&in);
+        return AVERROR(EIO);
+    }
+
+    out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     if (!out) {
         av_log(ctx, AV_LOG_ERROR, "could not allocate memory for output frame\n");
         av_frame_free(&in);
         return AVERROR(ENOMEM);
     }
-
     av_frame_copy_props(out, in);
 
-    for (int i = 0; i < in->height; i++){
-        for(int j = 0; j < in->width * 3; j++){
-            int k = i * in->linesize[0] + j;
-            int t = i * in->width * 3 + j;
-            ((float *)dr_context->input.data)[t] = in->data[0][k] / 255.0;
-        }
-    }
-
-    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &dr_context->output, &model_output_name, 1);
+    dnn_result = (dr_context->dnn_module->execute_model)(dr_context->model, &model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
         av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+        av_frame_free(&in);
         return AVERROR(EIO);
     }
 
-    out->height = dr_context->output.height;
-    out->width  = dr_context->output.width;
-    outlink->h  = dr_context->output.height;
-    outlink->w  = dr_context->output.width;
-
-    for (int i = 0; i < out->height; i++){
-        for(int j = 0; j < out->width * 3; j++){
-            int k = i * out->linesize[0] + j;
-            int t = i * out->width * 3 + j;
-            out->data[0][k] = CLIP((int)((((float *)dr_context->output.data)[t]) * 255), 0, 255);
-        }
-    }
-
     av_frame_free(&in);
 
     return ff_filter_frame(outlink, out);
@@ -146,7 +111,6 @@  static av_cold int init(AVFilterContext *ctx)
 {
     DRContext *dr_context = ctx->priv;
 
-    dr_context->input.dt = DNN_FLOAT;
     dr_context->dnn_module = ff_get_dnn_module(dr_context->backend_type);
     if (!dr_context->dnn_module) {
         av_log(ctx, AV_LOG_ERROR, "could not create DNN module for requested backend\n");
@@ -184,7 +148,6 @@  static const AVFilterPad derain_inputs[] = {
     {
         .name         = "default",
         .type         = AVMEDIA_TYPE_VIDEO,
-        .config_props = config_inputs,
         .filter_frame = filter_frame,
     },
     { NULL }
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index f120bf9df4..d7462bc828 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -46,12 +46,6 @@  typedef struct DnnProcessingContext {
     DNNModule *dnn_module;
     DNNModel *model;
 
-    // input & output of the model at execution time
-    DNNData input;
-    DNNData output;
-
-    struct SwsContext *sws_gray8_to_grayf32;
-    struct SwsContext *sws_grayf32_to_gray8;
     struct SwsContext *sws_uv_scale;
     int sws_uv_height;
 } DnnProcessingContext;
@@ -103,7 +97,7 @@  static av_cold int init(AVFilterContext *context)
         return AVERROR(EINVAL);
     }
 
-    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, NULL);
+    ctx->model = (ctx->dnn_module->load_model)(ctx->model_filename, ctx->backend_options, ctx);
     if (!ctx->model) {
         av_log(ctx, AV_LOG_ERROR, "could not load DNN model\n");
         return AVERROR(EINVAL);
@@ -148,6 +142,10 @@  static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
                                    model_input->width, inlink->w);
         return AVERROR(EIO);
     }
+    if (model_input->dt != DNN_FLOAT) {
+        av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32.\n");
+        return AVERROR(EIO);
+    }
 
     switch (fmt) {
     case AV_PIX_FMT_RGB24:
@@ -156,20 +154,6 @@  static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
             LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }
-        if (model_input->dt != DNN_FLOAT && model_input->dt != DNN_UINT8) {
-            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type as float32 and uint8.\n");
-            return AVERROR(EIO);
-        }
-        return 0;
-    case AV_PIX_FMT_GRAY8:
-        if (model_input->channels != 1) {
-            LOG_FORMAT_CHANNEL_MISMATCH();
-            return AVERROR(EIO);
-        }
-        if (model_input->dt != DNN_UINT8) {
-            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type uint8.\n");
-            return AVERROR(EIO);
-        }
         return 0;
     case AV_PIX_FMT_GRAYF32:
     case AV_PIX_FMT_YUV420P:
@@ -181,10 +165,6 @@  static int check_modelinput_inlink(const DNNData *model_input, const AVFilterLin
             LOG_FORMAT_CHANNEL_MISMATCH();
             return AVERROR(EIO);
         }
-        if (model_input->dt != DNN_FLOAT) {
-            av_log(ctx, AV_LOG_ERROR, "only support dnn models with input data type float32.\n");
-            return AVERROR(EIO);
-        }
         return 0;
     default:
         av_log(ctx, AV_LOG_ERROR, "%s not supported.\n", av_get_pix_fmt_name(fmt));
@@ -213,74 +193,24 @@  static int config_input(AVFilterLink *inlink)
         return check;
     }
 
-    ctx->input.width    = inlink->w;
-    ctx->input.height   = inlink->h;
-    ctx->input.channels = model_input.channels;
-    ctx->input.dt = model_input.dt;
-
-    result = (ctx->model->set_input)(ctx->model->model,
-                                     &ctx->input, ctx->model_inputname);
-    if (result != DNN_SUCCESS) {
-        av_log(ctx, AV_LOG_ERROR, "could not set input and output for the model\n");
-        return AVERROR(EIO);
-    }
-
     return 0;
 }
 
-static int prepare_sws_context(AVFilterLink *outlink)
+static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
+{
+    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
+    av_assert0(desc);
+    return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
+}
+
+static int prepare_uv_scale(AVFilterLink *outlink)
 {
     AVFilterContext *context = outlink->src;
     DnnProcessingContext *ctx = context->priv;
     AVFilterLink *inlink = context->inputs[0];
     enum AVPixelFormat fmt = inlink->format;
-    DNNDataType input_dt  = ctx->input.dt;
-    DNNDataType output_dt = ctx->output.dt;
-
-    switch (fmt) {
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-        if (input_dt == DNN_FLOAT) {
-            ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w * 3,
-                                                       inlink->h,
-                                                       AV_PIX_FMT_GRAY8,
-                                                       inlink->w * 3,
-                                                       inlink->h,
-                                                       AV_PIX_FMT_GRAYF32,
-                                                       0, NULL, NULL, NULL);
-        }
-        if (output_dt == DNN_FLOAT) {
-            ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w * 3,
-                                                       outlink->h,
-                                                       AV_PIX_FMT_GRAYF32,
-                                                       outlink->w * 3,
-                                                       outlink->h,
-                                                       AV_PIX_FMT_GRAY8,
-                                                       0, NULL, NULL, NULL);
-        }
-        return 0;
-    case AV_PIX_FMT_YUV420P:
-    case AV_PIX_FMT_YUV422P:
-    case AV_PIX_FMT_YUV444P:
-    case AV_PIX_FMT_YUV410P:
-    case AV_PIX_FMT_YUV411P:
-        av_assert0(input_dt == DNN_FLOAT);
-        av_assert0(output_dt == DNN_FLOAT);
-        ctx->sws_gray8_to_grayf32 = sws_getContext(inlink->w,
-                                                   inlink->h,
-                                                   AV_PIX_FMT_GRAY8,
-                                                   inlink->w,
-                                                   inlink->h,
-                                                   AV_PIX_FMT_GRAYF32,
-                                                   0, NULL, NULL, NULL);
-        ctx->sws_grayf32_to_gray8 = sws_getContext(outlink->w,
-                                                   outlink->h,
-                                                   AV_PIX_FMT_GRAYF32,
-                                                   outlink->w,
-                                                   outlink->h,
-                                                   AV_PIX_FMT_GRAY8,
-                                                   0, NULL, NULL, NULL);
 
+    if (isPlanarYUV(fmt)) {
         if (inlink->w != outlink->w || inlink->h != outlink->h) {
             const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(fmt);
             int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
@@ -292,10 +222,6 @@  static int prepare_sws_context(AVFilterLink *outlink)
                                                SWS_BICUBIC, NULL, NULL, NULL);
             ctx->sws_uv_height = sws_src_h;
         }
-        return 0;
-    default:
-        //do nothing
-        break;
     }
 
     return 0;
@@ -306,120 +232,34 @@  static int config_output(AVFilterLink *outlink)
     AVFilterContext *context = outlink->src;
     DnnProcessingContext *ctx = context->priv;
     DNNReturnType result;
+    AVFilterLink *inlink = context->inputs[0];
+    AVFrame *out = NULL;
 
-    // have a try run in case that the dnn model resize the frame
-    result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1);
-    if (result != DNN_SUCCESS){
-        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+    result = (ctx->model->set_input)(ctx->model->model, fake_in, ctx->model_inputname);
+    if (result != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
         return AVERROR(EIO);
     }
 
-    outlink->w = ctx->output.width;
-    outlink->h = ctx->output.height;
-
-    prepare_sws_context(outlink);
-
-    return 0;
-}
-
-static int copy_from_frame_to_dnn(DnnProcessingContext *ctx, const AVFrame *frame)
-{
-    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
-    DNNData *dnn_input = &ctx->input;
-
-    switch (frame->format) {
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-        if (dnn_input->dt == DNN_FLOAT) {
-            sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
-                      0, frame->height, (uint8_t * const*)(&dnn_input->data),
-                      (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
-        } else {
-            av_assert0(dnn_input->dt == DNN_UINT8);
-            av_image_copy_plane(dnn_input->data, bytewidth,
-                                frame->data[0], frame->linesize[0],
-                                bytewidth, frame->height);
-        }
-        return 0;
-    case AV_PIX_FMT_GRAY8:
-    case AV_PIX_FMT_GRAYF32:
-        av_image_copy_plane(dnn_input->data, bytewidth,
-                            frame->data[0], frame->linesize[0],
-                            bytewidth, frame->height);
-        return 0;
-    case AV_PIX_FMT_YUV420P:
-    case AV_PIX_FMT_YUV422P:
-    case AV_PIX_FMT_YUV444P:
-    case AV_PIX_FMT_YUV410P:
-    case AV_PIX_FMT_YUV411P:
-        sws_scale(ctx->sws_gray8_to_grayf32, (const uint8_t **)frame->data, frame->linesize,
-                  0, frame->height, (uint8_t * const*)(&dnn_input->data),
-                  (const int [4]){frame->width * sizeof(float), 0, 0, 0});
-        return 0;
-    default:
+    // have a try run in case that the dnn model resize the frame
+    out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    if (result != DNN_SUCCESS){
+        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
         return AVERROR(EIO);
     }
 
-    return 0;
-}
+    outlink->w = out->width;
+    outlink->h = out->height;
 
-static int copy_from_dnn_to_frame(DnnProcessingContext *ctx, AVFrame *frame)
-{
-    int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
-    DNNData *dnn_output = &ctx->output;
-
-    switch (frame->format) {
-    case AV_PIX_FMT_RGB24:
-    case AV_PIX_FMT_BGR24:
-        if (dnn_output->dt == DNN_FLOAT) {
-            sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
-                      (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0},
-                      0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
-
-        } else {
-            av_assert0(dnn_output->dt == DNN_UINT8);
-            av_image_copy_plane(frame->data[0], frame->linesize[0],
-                                dnn_output->data, bytewidth,
-                                bytewidth, frame->height);
-        }
-        return 0;
-    case AV_PIX_FMT_GRAY8:
-        // it is possible that data type of dnn output is float32,
-        // need to add support for such case when needed.
-        av_assert0(dnn_output->dt == DNN_UINT8);
-        av_image_copy_plane(frame->data[0], frame->linesize[0],
-                            dnn_output->data, bytewidth,
-                            bytewidth, frame->height);
-        return 0;
-    case AV_PIX_FMT_GRAYF32:
-        av_assert0(dnn_output->dt == DNN_FLOAT);
-        av_image_copy_plane(frame->data[0], frame->linesize[0],
-                            dnn_output->data, bytewidth,
-                            bytewidth, frame->height);
-        return 0;
-    case AV_PIX_FMT_YUV420P:
-    case AV_PIX_FMT_YUV422P:
-    case AV_PIX_FMT_YUV444P:
-    case AV_PIX_FMT_YUV410P:
-    case AV_PIX_FMT_YUV411P:
-        sws_scale(ctx->sws_grayf32_to_gray8, (const uint8_t *[4]){(const uint8_t *)dnn_output->data, 0, 0, 0},
-                  (const int[4]){frame->width * sizeof(float), 0, 0, 0},
-                  0, frame->height, (uint8_t * const*)frame->data, frame->linesize);
-        return 0;
-    default:
-        return AVERROR(EIO);
-    }
+    av_frame_free(&fake_in);
+    av_frame_free(&out);
+    prepare_uv_scale(outlink);
 
     return 0;
 }
 
-static av_always_inline int isPlanarYUV(enum AVPixelFormat pix_fmt)
-{
-    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(pix_fmt);
-    av_assert0(desc);
-    return !(desc->flags & AV_PIX_FMT_FLAG_RGB) && desc->nb_components == 3;
-}
-
 static int copy_uv_planes(DnnProcessingContext *ctx, AVFrame *out, const AVFrame *in)
 {
     const AVPixFmtDescriptor *desc;
@@ -453,11 +293,9 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
     DNNReturnType dnn_result;
     AVFrame *out;
 
-    copy_from_frame_to_dnn(ctx, in);
-
-    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, &ctx->output, (const char **)&ctx->model_outputname, 1);
-    if (dnn_result != DNN_SUCCESS){
-        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+    dnn_result = (ctx->model->set_input)(ctx->model->model, in, ctx->model_inputname);
+    if (dnn_result != DNN_SUCCESS) {
+        av_log(ctx, AV_LOG_ERROR, "could not set input for the model\n");
         av_frame_free(&in);
         return AVERROR(EIO);
     }
@@ -467,9 +305,15 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         av_frame_free(&in);
         return AVERROR(ENOMEM);
     }
-
     av_frame_copy_props(out, in);
-    copy_from_dnn_to_frame(ctx, out);
+
+    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&ctx->model_outputname, 1, out);
+    if (dnn_result != DNN_SUCCESS){
+        av_log(ctx, AV_LOG_ERROR, "failed to execute model\n");
+        av_frame_free(&in);
+        av_frame_free(&out);
+        return AVERROR(EIO);
+    }
 
     if (isPlanarYUV(in->format))
         copy_uv_planes(ctx, out, in);
@@ -482,8 +326,6 @@  static av_cold void uninit(AVFilterContext *ctx)
 {
     DnnProcessingContext *context = ctx->priv;
 
-    sws_freeContext(context->sws_gray8_to_grayf32);
-    sws_freeContext(context->sws_grayf32_to_gray8);
     sws_freeContext(context->sws_uv_scale);
 
     if (context->dnn_module)
diff --git a/libavfilter/vf_sr.c b/libavfilter/vf_sr.c
index 445777f0c6..2eda8c3219 100644
--- a/libavfilter/vf_sr.c
+++ b/libavfilter/vf_sr.c
@@ -41,11 +41,10 @@  typedef struct SRContext {
     DNNBackendType backend_type;
     DNNModule *dnn_module;
     DNNModel *model;
-    DNNData input;
-    DNNData output;
     int scale_factor;
-    struct SwsContext *sws_contexts[3];
-    int sws_slice_h, sws_input_linesize, sws_output_linesize;
+    struct SwsContext *sws_uv_scale;
+    int sws_uv_height;
+    struct SwsContext *sws_pre_scale;
 } SRContext;
 
 #define OFFSET(x) offsetof(SRContext, x)
@@ -87,11 +86,6 @@  static av_cold int init(AVFilterContext *context)
         return AVERROR(EIO);
     }
 
-    sr_context->input.dt = DNN_FLOAT;
-    sr_context->sws_contexts[0] = NULL;
-    sr_context->sws_contexts[1] = NULL;
-    sr_context->sws_contexts[2] = NULL;
-
     return 0;
 }
 
@@ -111,95 +105,63 @@  static int query_formats(AVFilterContext *context)
     return ff_set_common_formats(context, formats_list);
 }
 
-static int config_props(AVFilterLink *inlink)
+static int config_output(AVFilterLink *outlink)
 {
-    AVFilterContext *context = inlink->dst;
-    SRContext *sr_context = context->priv;
-    AVFilterLink *outlink = context->outputs[0];
+    AVFilterContext *context = outlink->src;
+    SRContext *ctx = context->priv;
     DNNReturnType result;
-    int sws_src_h, sws_src_w, sws_dst_h, sws_dst_w;
+    AVFilterLink *inlink = context->inputs[0];
+    AVFrame *out = NULL;
     const char *model_output_name = "y";
 
-    sr_context->input.width = inlink->w * sr_context->scale_factor;
-    sr_context->input.height = inlink->h * sr_context->scale_factor;
-    sr_context->input.channels = 1;
-
-    result = (sr_context->model->set_input)(sr_context->model->model, &sr_context->input, "x");
-    if (result != DNN_SUCCESS){
-        av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
+    AVFrame *fake_in = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+    result = (ctx->model->set_input)(ctx->model->model, fake_in, "x");
+    if (result != DNN_SUCCESS) {
+        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
         return AVERROR(EIO);
     }
 
-    result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1);
+    // have a try run in case that the dnn model resize the frame
+    out = ff_get_video_buffer(inlink, inlink->w, inlink->h);
+    result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
     if (result != DNN_SUCCESS){
         av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
         return AVERROR(EIO);
     }
 
-    if (sr_context->input.height != sr_context->output.height || sr_context->input.width != sr_context->output.width){
-        sr_context->input.width = inlink->w;
-        sr_context->input.height = inlink->h;
-        result = (sr_context->model->set_input)(sr_context->model->model, &sr_context->input, "x");
-        if (result != DNN_SUCCESS){
-            av_log(context, AV_LOG_ERROR, "could not set input and output for the model\n");
-            return AVERROR(EIO);
-        }
-        result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1);
-        if (result != DNN_SUCCESS){
-            av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
-            return AVERROR(EIO);
-        }
-        sr_context->scale_factor = 0;
-    }
-    outlink->h = sr_context->output.height;
-    outlink->w = sr_context->output.width;
-    sr_context->sws_contexts[1] = sws_getContext(sr_context->input.width, sr_context->input.height, AV_PIX_FMT_GRAY8,
-                                                 sr_context->input.width, sr_context->input.height, AV_PIX_FMT_GRAYF32,
-                                                 0, NULL, NULL, NULL);
-    sr_context->sws_input_linesize = sr_context->input.width << 2;
-    sr_context->sws_contexts[2] = sws_getContext(sr_context->output.width, sr_context->output.height, AV_PIX_FMT_GRAYF32,
-                                                 sr_context->output.width, sr_context->output.height, AV_PIX_FMT_GRAY8,
-                                                 0, NULL, NULL, NULL);
-    sr_context->sws_output_linesize = sr_context->output.width << 2;
-    if (!sr_context->sws_contexts[1] || !sr_context->sws_contexts[2]){
-        av_log(context, AV_LOG_ERROR, "could not create SwsContext for conversions\n");
-        return AVERROR(ENOMEM);
-    }
-    if (sr_context->scale_factor){
-        sr_context->sws_contexts[0] = sws_getContext(inlink->w, inlink->h, inlink->format,
-                                                     outlink->w, outlink->h, outlink->format,
-                                                     SWS_BICUBIC, NULL, NULL, NULL);
-        if (!sr_context->sws_contexts[0]){
-            av_log(context, AV_LOG_ERROR, "could not create SwsContext for scaling\n");
-            return AVERROR(ENOMEM);
-        }
-        sr_context->sws_slice_h = inlink->h;
-    } else {
+    if (fake_in->width != out->width || fake_in->height != out->height) {
+        //espcn
+        outlink->w = out->width;
+        outlink->h = out->height;
         if (inlink->format != AV_PIX_FMT_GRAY8){
             const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get(inlink->format);
-            sws_src_h = AV_CEIL_RSHIFT(sr_context->input.height, desc->log2_chroma_h);
-            sws_src_w = AV_CEIL_RSHIFT(sr_context->input.width, desc->log2_chroma_w);
-            sws_dst_h = AV_CEIL_RSHIFT(sr_context->output.height, desc->log2_chroma_h);
-            sws_dst_w = AV_CEIL_RSHIFT(sr_context->output.width, desc->log2_chroma_w);
-
-            sr_context->sws_contexts[0] = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8,
-                                                         sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8,
-                                                         SWS_BICUBIC, NULL, NULL, NULL);
-            if (!sr_context->sws_contexts[0]){
-                av_log(context, AV_LOG_ERROR, "could not create SwsContext for scaling\n");
-                return AVERROR(ENOMEM);
-            }
-            sr_context->sws_slice_h = sws_src_h;
+            int sws_src_h = AV_CEIL_RSHIFT(inlink->h, desc->log2_chroma_h);
+            int sws_src_w = AV_CEIL_RSHIFT(inlink->w, desc->log2_chroma_w);
+            int sws_dst_h = AV_CEIL_RSHIFT(outlink->h, desc->log2_chroma_h);
+            int sws_dst_w = AV_CEIL_RSHIFT(outlink->w, desc->log2_chroma_w);
+            ctx->sws_uv_scale = sws_getContext(sws_src_w, sws_src_h, AV_PIX_FMT_GRAY8,
+                                               sws_dst_w, sws_dst_h, AV_PIX_FMT_GRAY8,
+                                               SWS_BICUBIC, NULL, NULL, NULL);
+            ctx->sws_uv_height = sws_src_h;
         }
+    } else {
+        //srcnn
+        outlink->w = out->width * ctx->scale_factor;
+        outlink->h = out->height * ctx->scale_factor;
+        ctx->sws_pre_scale = sws_getContext(inlink->w, inlink->h, inlink->format,
+                                        outlink->w, outlink->h, outlink->format,
+                                        SWS_BICUBIC, NULL, NULL, NULL);
     }
 
+    av_frame_free(&fake_in);
+    av_frame_free(&out);
     return 0;
 }
 
 static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 {
     AVFilterContext *context = inlink->dst;
-    SRContext *sr_context = context->priv;
+    SRContext *ctx = context->priv;
     AVFilterLink *outlink = context->outputs[0];
     AVFrame *out = ff_get_video_buffer(outlink, outlink->w, outlink->h);
     DNNReturnType dnn_result;
@@ -211,45 +173,44 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
         return AVERROR(ENOMEM);
     }
     av_frame_copy_props(out, in);
-    out->height = sr_context->output.height;
-    out->width = sr_context->output.width;
-    if (sr_context->scale_factor){
-        sws_scale(sr_context->sws_contexts[0], (const uint8_t **)in->data, in->linesize,
-                  0, sr_context->sws_slice_h, out->data, out->linesize);
 
-        sws_scale(sr_context->sws_contexts[1], (const uint8_t **)out->data, out->linesize,
-                  0, out->height, (uint8_t * const*)(&sr_context->input.data),
-                  (const int [4]){sr_context->sws_input_linesize, 0, 0, 0});
+    if (ctx->sws_pre_scale) {
+        sws_scale(ctx->sws_pre_scale,
+                    (const uint8_t **)in->data, in->linesize, 0, in->height,
+                    out->data, out->linesize);
+        dnn_result = (ctx->model->set_input)(ctx->model->model, out, "x");
     } else {
-        if (sr_context->sws_contexts[0]){
-            sws_scale(sr_context->sws_contexts[0], (const uint8_t **)(in->data + 1), in->linesize + 1,
-                      0, sr_context->sws_slice_h, out->data + 1, out->linesize + 1);
-            sws_scale(sr_context->sws_contexts[0], (const uint8_t **)(in->data + 2), in->linesize + 2,
-                      0, sr_context->sws_slice_h, out->data + 2, out->linesize + 2);
-        }
+        dnn_result = (ctx->model->set_input)(ctx->model->model, in, "x");
+    }
 
-        sws_scale(sr_context->sws_contexts[1], (const uint8_t **)in->data, in->linesize,
-                  0, in->height, (uint8_t * const*)(&sr_context->input.data),
-                  (const int [4]){sr_context->sws_input_linesize, 0, 0, 0});
+    if (dnn_result != DNN_SUCCESS) {
+        av_frame_free(&in);
+        av_frame_free(&out);
+        av_log(context, AV_LOG_ERROR, "could not set input for the model\n");
+        return AVERROR(EIO);
     }
-    av_frame_free(&in);
 
-    dnn_result = (sr_context->dnn_module->execute_model)(sr_context->model, &sr_context->output, &model_output_name, 1);
+    dnn_result = (ctx->dnn_module->execute_model)(ctx->model, (const char **)&model_output_name, 1, out);
     if (dnn_result != DNN_SUCCESS){
-        av_log(context, AV_LOG_ERROR, "failed to execute loaded model\n");
+        av_log(ctx, AV_LOG_ERROR, "failed to execute loaded model\n");
+        av_frame_free(&in);
+        av_frame_free(&out);
         return AVERROR(EIO);
     }
 
-    sws_scale(sr_context->sws_contexts[2], (const uint8_t *[4]){(const uint8_t *)sr_context->output.data, 0, 0, 0},
-              (const int[4]){sr_context->sws_output_linesize, 0, 0, 0},
-              0, out->height, (uint8_t * const*)out->data, out->linesize);
+    if (ctx->sws_uv_scale) {
+        sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 1), in->linesize + 1,
+                  0, ctx->sws_uv_height, out->data + 1, out->linesize + 1);
+        sws_scale(ctx->sws_uv_scale, (const uint8_t **)(in->data + 2), in->linesize + 2,
+                  0, ctx->sws_uv_height, out->data + 2, out->linesize + 2);
+    }
 
+    av_frame_free(&in);
     return ff_filter_frame(outlink, out);
 }
 
 static av_cold void uninit(AVFilterContext *context)
 {
-    int i;
     SRContext *sr_context = context->priv;
 
     if (sr_context->dnn_module){
@@ -257,16 +218,14 @@  static av_cold void uninit(AVFilterContext *context)
         av_freep(&sr_context->dnn_module);
     }
 
-    for (i = 0; i < 3; ++i){
-        sws_freeContext(sr_context->sws_contexts[i]);
-    }
+    sws_freeContext(sr_context->sws_uv_scale);
+    sws_freeContext(sr_context->sws_pre_scale);
 }
 
 static const AVFilterPad sr_inputs[] = {
     {
         .name         = "default",
         .type         = AVMEDIA_TYPE_VIDEO,
-        .config_props = config_props,
         .filter_frame = filter_frame,
     },
     { NULL }
@@ -275,6 +234,7 @@  static const AVFilterPad sr_inputs[] = {
 static const AVFilterPad sr_outputs[] = {
     {
         .name = "default",
+        .config_props = config_output,
         .type = AVMEDIA_TYPE_VIDEO,
     },
     { NULL }