diff mbox series

[FFmpeg-devel,1/2] avfilter/dnn/dnn_backend_trt: Update with master and sign-off

Message ID 20210725115843.8235-1-xiaoweiw@nvidia.com
State New
Headers show
Series [FFmpeg-devel,1/2] avfilter/dnn/dnn_backend_trt: Update with master and sign-off
Related show

Checks

Context Check Description
andriy/x86_make fail Make failed
andriy/PPC64_make warning Make failed

Commit Message

Xiaowei Wang July 25, 2021, 11:58 a.m. UTC
Signed-off-by: Xiaowei Wang <xiaoweiw@nvidia.com>
---
 configure                              |   6 +-
 libavfilter/dnn/Makefile               |   1 +
 libavfilter/dnn/dnn_backend_tensorrt.c |  77 +++
 libavfilter/dnn/dnn_backend_tensorrt.h |  72 +++
 libavfilter/dnn/dnn_interface.c        |  10 +
 libavfilter/dnn/dnn_io_proc_trt.cu     |  55 ++
 libavfilter/dnn/trt_class_wrapper.cpp  | 731 +++++++++++++++++++++++++
 libavfilter/dnn/trt_class_wrapper.h    |  49 ++
 libavfilter/dnn_interface.h            |   2 +-
 libavfilter/vf_dnn_processing.c        |   3 +
 10 files changed, 1004 insertions(+), 2 deletions(-)
 create mode 100644 libavfilter/dnn/dnn_backend_tensorrt.c
 create mode 100644 libavfilter/dnn/dnn_backend_tensorrt.h
 create mode 100644 libavfilter/dnn/dnn_io_proc_trt.cu
 create mode 100644 libavfilter/dnn/trt_class_wrapper.cpp
 create mode 100644 libavfilter/dnn/trt_class_wrapper.h
diff mbox series

Patch

diff --git a/configure b/configure
index b124411609..e496a66621 100755
--- a/configure
+++ b/configure
@@ -272,6 +272,8 @@  External library support:
   --enable-libsvtav1       enable AV1 encoding via SVT [no]
   --enable-libtensorflow   enable TensorFlow as a DNN module backend
                            for DNN based filters like sr [no]
+  --enable-libtensorrt     enable TensorRT as a DNN module backend
+                           for DNN based filters like sr [no]
   --enable-libtesseract    enable Tesseract, needed for ocr filter [no]
   --enable-libtheora       enable Theora encoding via libtheora [no]
   --enable-libtls          enable LibreSSL (via libtls), needed for https support
@@ -1839,6 +1841,7 @@  EXTERNAL_LIBRARY_LIST="
     libssh
     libsvtav1
     libtensorflow
+    libtensorrt
     libtesseract
     libtheora
     libtwolame
@@ -2660,7 +2663,7 @@  cbs_mpeg2_select="cbs"
 cbs_vp9_select="cbs"
 dct_select="rdft"
 dirac_parse_select="golomb"
-dnn_suggest="libtensorflow libopenvino"
+dnn_suggest="libtensorflow libopenvino libtensorrt"
 dnn_deps="avformat swscale"
 error_resilience_select="me_cmp"
 faandct_deps="faan"
@@ -6487,6 +6490,7 @@  enabled libspeex          && require_pkg_config libspeex speex speex/speex.h spe
 enabled libsrt            && require_pkg_config libsrt "srt >= 1.3.0" srt/srt.h srt_socket
 enabled libsvtav1         && require_pkg_config libsvtav1 "SvtAv1Enc >= 0.8.4" EbSvtAv1Enc.h svt_av1_enc_init_handle
 enabled libtensorflow     && require libtensorflow tensorflow/c/c_api.h TF_Version -ltensorflow
+enabled libtensorrt       && require_cpp libtensorrt NvInfer.h nvinfer1::Dims2 -lnvinfer -lcudart
 enabled libtesseract      && require_pkg_config libtesseract tesseract tesseract/capi.h TessBaseAPICreate
 enabled libtheora         && require libtheora theora/theoraenc.h th_info_init -ltheoraenc -ltheoradec -logg
 enabled libtls            && require_pkg_config libtls libtls tls.h tls_configure
diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile
index 4cfbce0efc..f9ea7ca386 100644
--- a/libavfilter/dnn/Makefile
+++ b/libavfilter/dnn/Makefile
@@ -16,5 +16,6 @@  OBJS-$(CONFIG_DNN)                           += dnn/dnn_backend_native_layer_mat
 
 DNN-OBJS-$(CONFIG_LIBTENSORFLOW)             += dnn/dnn_backend_tf.o
 DNN-OBJS-$(CONFIG_LIBOPENVINO)               += dnn/dnn_backend_openvino.o
+DNN-OBJS-$(CONFIG_LIBTENSORRT)               += dnn/dnn_backend_tensorrt.o dnn/trt_class_wrapper.o dnn/dnn_io_proc_trt.ptx.o
 
 OBJS-$(CONFIG_DNN)                           += $(DNN-OBJS-yes)
diff --git a/libavfilter/dnn/dnn_backend_tensorrt.c b/libavfilter/dnn/dnn_backend_tensorrt.c
new file mode 100644
index 0000000000..b45b770a77
--- /dev/null
+++ b/libavfilter/dnn/dnn_backend_tensorrt.c
@@ -0,0 +1,77 @@ 
+/*
+* Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+ */
+
+/**
+ * @file
+ * DNN TensorRT backend implementation.
+ */
+
+#include "trt_class_wrapper.h"
+#include "dnn_backend_tensorrt.h"
+
+#include "libavutil/mem.h"
+#include "libavformat/avio.h"
+#include "libavutil/avassert.h"
+#include "libavutil/opt.h"
+#include "libavutil/avstring.h"
+#include "dnn_io_proc.h"
+#include "../internal.h"
+#include "libavutil/buffer.h"
+#include <stdint.h>
+
+#define OFFSET(x) offsetof(TRTContext, x)
+#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
+static const AVOption dnn_tensorrt_options[] = {
+    { "device", "index of the GPU to run model", OFFSET(options.device), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, INT_MAX, FLAGS },
+    { NULL }
+};
+AVFILTER_DEFINE_CLASS(dnn_tensorrt);
+
+DNNModel *ff_dnn_load_model_trt(const char *model_filename,DNNFunctionType func_type, 
+                                const char *options, AVFilterContext *filter_ctx)
+{
+    DNNModel *model = NULL;
+    model = (DNNModel*)av_mallocz(sizeof(DNNModel));
+    if (!model){
+        return NULL;
+    }
+
+    trt_load_model(model, model_filename, &dnn_tensorrt_class, options);
+
+    return model;
+}
+
+DNNReturnType ff_dnn_execute_model_trt(const DNNModel *model, DNNExecBaseParams *exec_params)
+{
+    execute_model_trt(model, exec_params->input_name, exec_params->in_frame, 
+                    exec_params->output_names, exec_params->nb_output, exec_params->out_frame);
+    return DNN_SUCCESS;
+}
+
+void ff_dnn_free_model_trt(DNNModel **model)
+{
+    if (*model)
+    {
+        free_model_trt(*model);
+        av_freep(model);
+    }
+}
\ No newline at end of file
diff --git a/libavfilter/dnn/dnn_backend_tensorrt.h b/libavfilter/dnn/dnn_backend_tensorrt.h
new file mode 100644
index 0000000000..d700cb247f
--- /dev/null
+++ b/libavfilter/dnn/dnn_backend_tensorrt.h
@@ -0,0 +1,72 @@ 
+/*
+* Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+ */
+/**
+ * @file
+ * DNN inference functions interface for TensorRT backend.
+ */
+
+
+#ifndef AVFILTER_DNN_DNN_BACKEND_TENSORRT_H
+#define AVFILTER_DNN_DNN_BACKEND_TENSORRT_H
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+    #include "../dnn_interface.h"
+    #include "libavutil/hwcontext.h"
+    #include "libavutil/hwcontext_cuda_internal.h"
+
+    typedef struct TRTOptions{
+        int device;
+    } TRTOptions;
+
+    typedef struct TRTContext{
+        const AVClass *av_class;
+        TRTOptions options;
+        AVBufferRef *hwdevice;
+        // Host memory pointer to input/output image data
+        void *host_in, *host_out;
+        // Device memory pointer to the fp32 CHW input/output of the model
+        // The device memory is only allocated once and reused during inference
+        // Multiple input/output is not supported
+        CUdeviceptr trt_in, trt_out;
+        // Device memory pointer to 8-bit image data
+        CUdeviceptr frame_in, frame_out;
+
+        CUmodule cu_module;
+        CUfunction cu_func_frame_to_dnn, cu_func_dnn_to_frame;
+
+        int channels;
+    } TRTContext;
+    
+    DNNModel *ff_dnn_load_model_trt(const char *model_filename,DNNFunctionType func_type, 
+                                    const char *options, AVFilterContext *filter_ctx);
+
+    DNNReturnType ff_dnn_execute_model_trt(const DNNModel *model, DNNExecBaseParams *exec_params);
+
+    void ff_dnn_free_model_trt(DNNModel **model);
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c
index 02e532fc1b..c4fdfb7e7b 100644
--- a/libavfilter/dnn/dnn_interface.c
+++ b/libavfilter/dnn/dnn_interface.c
@@ -27,6 +27,7 @@ 
 #include "dnn_backend_native.h"
 #include "dnn_backend_tf.h"
 #include "dnn_backend_openvino.h"
+#include "dnn_backend_tensorrt.h"
 #include "libavutil/mem.h"
 
 DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
@@ -65,6 +66,15 @@  DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
     #else
         av_freep(&dnn_module);
         return NULL;
+    #endif
+    case DNN_TRT:
+    #if (CONFIG_LIBTENSORRT == 1)
+        dnn_module->load_model = &ff_dnn_load_model_trt;
+        dnn_module->execute_model = &ff_dnn_execute_model_trt;
+        dnn_module->free_model = &ff_dnn_free_model_trt;
+    #else
+        av_freep(&dnn_module);
+        return NULL;
     #endif
         break;
     default:
diff --git a/libavfilter/dnn/dnn_io_proc_trt.cu b/libavfilter/dnn/dnn_io_proc_trt.cu
new file mode 100644
index 0000000000..030cfd2f60
--- /dev/null
+++ b/libavfilter/dnn/dnn_io_proc_trt.cu
@@ -0,0 +1,55 @@ 
+#include <bits/stdint-uintn.h>
+extern "C" {
+
+__global__ void frame_to_dnn(uint8_t *src, int src_linesize, float *dst, int dst_linesize, 
+                             int width, int height, int unpack_rgb)
+{
+    int x = blockIdx.x * blockDim.x + threadIdx.x;
+    int y = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (x >= width || y >= height)
+        return;
+    
+    if (unpack_rgb)
+    {
+        uchar3 rgb = *((uchar3 *)(src + y * src_linesize) + x);
+        dst[y * dst_linesize + x] = (float)rgb.x;
+        dst[y * dst_linesize + x + dst_linesize * height] = (float)rgb.y;
+        dst[y * dst_linesize + x + 2 * dst_linesize * height] = (float)rgb.z;
+    }
+    else
+    {
+        dst[y * dst_linesize + x] = (float)src[y * src_linesize + x];
+    }
+}
+
+__device__ static float clamp(float x, float lower, float upper) {
+    return x < lower ? lower : (x > upper ? upper : x);
+}
+
+__global__ void dnn_to_frame(float *src, int src_linesize, uint8_t *dst, int dst_linesize, 
+                            int width, int height, int pack_rgb)
+{
+    int x = blockIdx.x * blockDim.x + threadIdx.x;
+    int y = blockIdx.y * blockDim.y + threadIdx.y;
+
+    if (x >= width || y >= height)
+        return;
+
+    if (pack_rgb)
+    {
+        uint8_t r = (uint8_t)clamp(src[y * src_linesize + x], .0f, 255.0f);
+        uint8_t g = (uint8_t)clamp(src[y * src_linesize + x + src_linesize * height], .0f, 255.0f);
+        uint8_t b = (uint8_t)clamp(src[y * src_linesize + x + 2 * src_linesize * height], .0f, 255.0f);
+
+        uchar3 rgb = make_uchar3(r, g, b);
+
+        *((uchar3*)(dst + y * dst_linesize) + x) = rgb;
+    }
+    else
+    {
+        dst[y * dst_linesize + x] = (uint8_t)clamp(src[y * src_linesize + x], .0f, 255.0f);
+    }
+}
+
+}
\ No newline at end of file
diff --git a/libavfilter/dnn/trt_class_wrapper.cpp b/libavfilter/dnn/trt_class_wrapper.cpp
new file mode 100644
index 0000000000..dac433b690
--- /dev/null
+++ b/libavfilter/dnn/trt_class_wrapper.cpp
@@ -0,0 +1,731 @@ 
+/*
+* Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+ */
+ 
+/**
+ * @file
+ * DNN TensorRT backend C++ wrapper.
+ */
+
+#include "trt_class_wrapper.h"
+#include "dnn_backend_tensorrt.h"
+
+#include <vector>
+#include <map>
+#include <iostream>
+#include <fstream>
+#include <iomanip>
+#include <string>
+#include <chrono>
+#include <sstream>
+#include <mutex>
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+    
+    #include "libavutil/buffer.h"
+    #include "libavutil/hwcontext.h"
+    #include "libavutil/cuda_check.h"
+    #include "libavutil/log.h"
+    #include "libavutil/opt.h"
+    #include "libavformat/avio.h"
+    #include "dnn_io_proc.h"
+    #include "libavutil/frame.h"
+    #include "libavutil/pixdesc.h"
+    #include "libavutil/pixfmt.h"
+    #include "libavutil/mem.h"
+
+#ifdef __cplusplus
+}
+#endif
+
+#include <sys/stat.h>
+#include <time.h>
+#include <unistd.h>
+#include <sys/socket.h>
+#include <netinet/in.h>
+#include <arpa/inet.h>
+#define SOCKET int
+#define INVALID_SOCKET -1
+
+#include <cuda_runtime.h>
+#include <NvInfer.h>
+
+using namespace nvinfer1;
+using namespace std;
+
+#define DIV_UP(a, b) ( ((a) + (b) - 1) / (b) )
+#define BLOCKX 32
+#define BLOCKY 16
+
+// Self-defined CUDA check functions as cuda_check.h is not available for cpp due to void* function pointers
+inline bool check(CUresult e, TRTContext *ctx, CudaFunctions* cu, int iLine, const char *szFile) {
+    if (e != CUDA_SUCCESS) {
+        const char* pStr;
+        cu->cuGetErrorName(e, &pStr);
+        av_log(ctx, AV_LOG_ERROR, "CUDA driver API error: %s, at line %d in file %s\n",
+        pStr, iLine, szFile);
+        return false;
+    }
+    return true;
+}
+
+inline bool check(cudaError_t e, TRTContext *ctx, int iLine, const char *szFile) {
+    if (e != cudaSuccess) {
+        av_log(ctx, AV_LOG_ERROR, "CUDA runtime API error: %s, at line %d in file %s\n",
+            cudaGetErrorName(e), iLine, szFile);
+        return false;
+    }
+    return true;
+}
+
+inline bool check(bool bSuccess, TRTContext *ctx, int iLine, const char *szFile) {
+    if (!bSuccess) {
+        av_log(ctx, AV_LOG_ERROR, "Error at line %d in file %s\n", iLine, szFile);
+        return false;
+    }
+    return true;
+}
+
+#define ck(call, ctx) check(call, ctx, __LINE__, __FILE__)
+#define ck_cu(call) check(call, ctx, cu, __LINE__, __FILE__)
+
+inline std::string to_string(nvinfer1::Dims const &dim) {
+    std::ostringstream oss;
+    oss << "(";
+    for (int i = 0; i < dim.nbDims; i++) {
+        oss << dim.d[i] << ", ";
+    }
+    oss << ")";
+    return oss.str();
+}
+
+typedef ICudaEngine *(*BuildEngineProcType)(IBuilder *builder, void *pData);
+
+struct IOInfo {
+    string name;
+    bool bInput;
+    nvinfer1::Dims dim;
+    nvinfer1::DataType dataType;
+
+    string GetDimString() {
+        return ::to_string(dim);
+    }
+    string GetDataTypeString() {
+        static string aTypeName[] = {"float", "half", "int8", "int32", "bool"};
+        return aTypeName[(int)dataType];
+    }
+    size_t GetNumBytes() {
+        static int aSize[] = {4, 2, 1, 4, 1};
+        size_t nSize = aSize[(int)dataType];
+        for (int i = 0; i < dim.nbDims; i++) {
+            nSize *= dim.d[i];
+        }
+        return nSize;
+    }
+    string to_string() {
+        ostringstream oss;
+        oss << setw(6) << (bInput ? "input" : "output") 
+            << " | " << setw(5) << GetDataTypeString() 
+            << " | " << GetDimString() 
+            << " | " << "size=" << GetNumBytes()
+            << " | " << name;
+        return oss.str();
+    }
+};
+
+class TrtLogger : public nvinfer1::ILogger {
+public:
+    TrtLogger(TRTContext *ctx) : ctx(ctx) {}
+    void log(Severity severity, const char* msg) override {
+        int log_level = AV_LOG_INFO;
+        switch (severity){
+            case nvinfer1::ILogger::Severity::kERROR:
+            log_level = AV_LOG_ERROR;
+            break;
+            case nvinfer1::ILogger::Severity::kWARNING:
+            log_level = AV_LOG_WARNING;
+            break;
+            case nvinfer1::ILogger::Severity::kINFO:
+            log_level = AV_LOG_INFO;
+            break;
+            case nvinfer1::ILogger::Severity::kVERBOSE:
+            log_level = AV_LOG_DEBUG;
+            break;
+            case nvinfer1::ILogger::Severity::kINTERNAL_ERROR:
+            log_level = AV_LOG_FATAL;
+            break;
+        }
+        av_log(ctx, log_level, "%s\n", msg);
+    }
+private:
+    TRTContext *ctx = nullptr;
+};
+    
+class TrtLite {
+public:
+    TrtLite(const char *szEnginePath, TRTContext *trt_ctx) : ctx(trt_ctx) {
+        uint8_t *pBuf = nullptr;
+        uint32_t nSize = 0;
+
+        trt_logger = new TrtLogger(trt_ctx);
+        
+        read_engine(&pBuf, &nSize, szEnginePath);
+        IRuntime *runtime = createInferRuntime(*trt_logger);
+        engine = runtime->deserializeCudaEngine(pBuf, nSize);
+        runtime->destroy();
+        if (!engine) {
+            av_log(ctx, AV_LOG_ERROR, "No engine generated\n");
+            return;
+        }
+        av_freep(&pBuf);
+    }
+    virtual ~TrtLite() {
+        if (context) {
+            context->destroy();
+        }
+        if (engine) {
+            engine->destroy();
+        }
+    }
+    ICudaEngine *GetEngine() {
+        return engine;
+    }
+    void Execute(int nBatch, vector<void *> &vdpBuf, cudaStream_t stm = 0, cudaEvent_t* evtInputConsumed = nullptr) {
+        if (!engine) {
+            av_log(ctx, AV_LOG_ERROR, "No engine\n");
+            return;
+        }
+        if (!engine->hasImplicitBatchDimension() && nBatch > 1) {
+            av_log(ctx, AV_LOG_ERROR, 
+                "Engine was built with explicit batch but is executed with batch size != 1. Results may be incorrect.\n");
+            return;
+        }
+        if (engine->getNbBindings() != vdpBuf.size()) {
+            av_log(ctx, AV_LOG_ERROR, "Number of bindings conflicts with input and output\n");
+            return;
+        }
+        if (!context) {
+            context = engine->createExecutionContext();
+            if (!context) {
+                av_log(ctx, AV_LOG_ERROR, "createExecutionContext() failed\n");
+                return;
+            }
+        }
+        ck(context->enqueue(nBatch, vdpBuf.data(), stm, evtInputConsumed), ctx);
+    }
+    void Execute(map<int, Dims> i2shape, vector<void *> &vdpBuf, cudaStream_t stm = 0, cudaEvent_t* evtInputConsumed = nullptr) {
+        if (!engine) {
+            av_log(ctx, AV_LOG_ERROR, "No engine\n");
+            return;
+        }
+        if (engine->hasImplicitBatchDimension()) {
+            av_log(ctx, AV_LOG_ERROR, "Engine was built with static-shaped input\n");
+            return;
+        }
+        if (engine->getNbBindings() != vdpBuf.size()) {
+            av_log(ctx, AV_LOG_ERROR, "Number of bindings conflicts with input and output\n");
+            return;
+        }
+        if (!context) {
+            context = engine->createExecutionContext();
+            if (!context) {
+                av_log(ctx, AV_LOG_ERROR, "createExecutionContext() failed\n");
+                return;
+            }
+        }
+        for (auto &it : i2shape) {
+            context->setBindingDimensions(it.first, it.second);
+        }
+        ck(context->enqueueV2(vdpBuf.data(), stm, evtInputConsumed), ctx);
+    }
+
+    vector<IOInfo> ConfigIO(int nBatchSize) {
+        vector<IOInfo> vInfo;
+        if (!engine) {
+            av_log(ctx, AV_LOG_ERROR, "No engine\n");
+            return vInfo;
+        }
+        if (!engine->hasImplicitBatchDimension()) {
+            av_log(ctx, AV_LOG_ERROR, "Engine must be built with implicit batch size (and static shape)\n");
+            return vInfo;
+        }
+        for (int i = 0; i < engine->getNbBindings(); i++) {
+            vInfo.push_back({string(engine->getBindingName(i)), engine->bindingIsInput(i), 
+                MakeDim(nBatchSize, engine->getBindingDimensions(i)), engine->getBindingDataType(i)});
+        }
+        return vInfo;
+    }
+    vector<IOInfo> ConfigIO(map<int, Dims> i2shape) {
+        vector<IOInfo> vInfo;
+        if (!engine) {
+            av_log(ctx, AV_LOG_ERROR, "No engine\n");
+            return vInfo;
+        }
+        if (engine->hasImplicitBatchDimension()) {
+            av_log(ctx, AV_LOG_ERROR, "Engine must be built with explicit batch size (to enable dynamic shape)\n");
+            return vInfo;
+        }
+        if (!context) {
+            context = engine->createExecutionContext();
+            if (!context) {
+                av_log(ctx, AV_LOG_ERROR, "createExecutionContext() failed\n");
+                return vInfo;
+            }
+        }
+        for (auto &it : i2shape) {
+            context->setBindingDimensions(it.first, it.second);
+        }
+        if (!context->allInputDimensionsSpecified()) {
+            av_log(ctx, AV_LOG_ERROR, "Not all binding shape are specified\n");
+            return vInfo;
+        }
+        for (int i = 0; i < engine->getNbBindings(); i++) {
+            vInfo.push_back({string(engine->getBindingName(i)), engine->bindingIsInput(i), 
+                context->getBindingDimensions(i), engine->getBindingDataType(i)});
+        }
+        return vInfo;
+    }
+
+    void PrintInfo() {
+        if (!engine) {
+            av_log(ctx, AV_LOG_ERROR, "No engine\n");
+            return;
+        }
+        av_log(ctx, AV_LOG_INFO, "nbBindings: %d\n", engine->getNbBindings());
+        // Only contains engine-level IO information: if dynamic shape is used,
+        // dimension -1 will be printed
+        for (int i = 0; i < engine->getNbBindings(); i++) {
+            av_log(ctx, AV_LOG_INFO, "#%d: %s\n", i, IOInfo{string(engine->getBindingName(i)), engine->bindingIsInput(i),
+                engine->getBindingDimensions(i), engine->getBindingDataType(i)}.to_string().c_str());
+        }
+    }
+
+    TRTContext *ctx = nullptr;
+    
+private:
+    void read_engine(uint8_t **engine_buf, uint32_t *engine_size, const char *engine_filename) {
+        AVIOContext *engine_file_ctx;
+        *engine_buf = nullptr;
+
+        if (avio_open(&engine_file_ctx, engine_filename, AVIO_FLAG_READ) < 0){
+            av_log(ctx, AV_LOG_ERROR, "Error reading engine file from disk!\n");
+            return;
+        }
+
+        uint32_t size = avio_size(engine_file_ctx);
+        uint8_t *buffer = (uint8_t*)av_malloc(size);
+        if (!buffer){
+            avio_closep(&engine_file_ctx);
+            av_log(ctx, AV_LOG_ERROR, "Error allocating memory for TRT engine.\n");
+            return;
+        }
+        uint32_t bytes_read = avio_read(engine_file_ctx, buffer, size);
+        avio_closep(&engine_file_ctx);
+        if (bytes_read != size){
+            av_freep(&buffer);
+            av_log(ctx, AV_LOG_ERROR, "Engine file size (%d) does not equal to read size (%d)\n", size, bytes_read);
+            return;
+        }
+
+        *engine_buf = buffer;
+        *engine_size = size;
+
+        return;
+    }
+    static size_t GetBytesOfBinding(int iBinding, ICudaEngine *engine, IExecutionContext *context = nullptr) {
+        size_t aValueSize[] = {4, 2, 1, 4, 1};
+        size_t nSize = aValueSize[(int)engine->getBindingDataType(iBinding)];
+        const Dims &dims = context ? context->getBindingDimensions(iBinding) : engine->getBindingDimensions(iBinding);
+        for (int i = 0; i < dims.nbDims; i++) {
+            nSize *= dims.d[i];
+        }
+        return nSize;
+    }
+    static nvinfer1::Dims MakeDim(int nBatchSize, nvinfer1::Dims dim) {
+        nvinfer1::Dims ret(dim);
+        for (int i = ret.nbDims; i > 0; i--) {
+            ret.d[i] = ret.d[i - 1];
+        }
+        ret.d[0] = nBatchSize;
+        ret.nbDims++;
+        return ret;
+    }
+
+    ICudaEngine *engine = nullptr;
+    IExecutionContext *context = nullptr;
+    TrtLogger *trt_logger = nullptr;
+};
+
+#define BATCH 1
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+static DNNReturnType frame_to_dnn(AVFrame *inframe, TRTContext *ctx, int num_bytes)
+{
+    AVHWDeviceContext *hw_device = (AVHWDeviceContext*)ctx->hwdevice->data;
+    AVCUDADeviceContext *hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    CudaFunctions *cu = hw_ctx->internal->cuda_dl;
+
+    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get((enum AVPixelFormat)inframe->format);
+    int unpack = (desc->flags & AV_PIX_FMT_FLAG_PLANAR) ? 0 : 1;
+    void *frame_to_dnn_args[] = {&ctx->frame_in, inframe->linesize, &ctx->trt_in, &inframe->width,
+                                &inframe->width, &inframe->height, &unpack};
+
+    CUDA_MEMCPY2D copy_param;
+    memset(&copy_param, 0, sizeof(copy_param));
+    copy_param.dstMemoryType = CU_MEMORYTYPE_DEVICE;
+    copy_param.dstDevice = ctx->frame_in;
+    copy_param.dstPitch = inframe->linesize[0];
+    copy_param.srcMemoryType = CU_MEMORYTYPE_HOST;
+    copy_param.srcHost = inframe->data[0];
+    copy_param.srcPitch = inframe->linesize[0];
+    copy_param.WidthInBytes = inframe->linesize[0];
+    copy_param.Height = inframe->height;
+
+    ck_cu(cu->cuMemcpy2DAsync(&copy_param, hw_ctx->stream));
+    ck_cu(cu->cuLaunchKernel(ctx->cu_func_frame_to_dnn, 
+                        DIV_UP(inframe->width, BLOCKX), DIV_UP(inframe->height, BLOCKY), 
+                        1, BLOCKX, BLOCKY, 1, 0, hw_ctx->stream, frame_to_dnn_args, NULL));
+
+    return DNN_SUCCESS;
+}
+
+static DNNReturnType dnn_to_frame(AVFrame *outframe, TRTContext *ctx, int num_bytes)
+{
+    AVHWDeviceContext *hw_device = (AVHWDeviceContext*)ctx->hwdevice->data;
+    AVCUDADeviceContext *hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    CudaFunctions *cu = hw_ctx->internal->cuda_dl;
+
+    const AVPixFmtDescriptor *desc = av_pix_fmt_desc_get((enum AVPixelFormat)outframe->format);
+    int pack = (desc->flags & AV_PIX_FMT_FLAG_PLANAR) ? 0 : 1;
+    void *dnn_to_frame_args[] = {&ctx->trt_out, &outframe->width, &ctx->frame_out, &outframe->linesize[0],
+                                &outframe->width, &outframe->height, &pack};
+
+    CUDA_MEMCPY2D copy_param;
+    memset(&copy_param, 0, sizeof(copy_param));
+    copy_param.dstMemoryType = CU_MEMORYTYPE_HOST;
+    copy_param.dstHost = outframe->data[0];
+    copy_param.dstPitch = outframe->linesize[0];
+    copy_param.srcMemoryType = CU_MEMORYTYPE_DEVICE;
+    copy_param.srcDevice = ctx->frame_out;
+    copy_param.srcPitch = outframe->linesize[0];
+    copy_param.WidthInBytes = outframe->linesize[0];
+    copy_param.Height = outframe->height;
+
+    ck_cu(cu->cuLaunchKernel(ctx->cu_func_dnn_to_frame, 
+                        DIV_UP(outframe->width, BLOCKX), DIV_UP(outframe->height, BLOCKY), 
+                        1, BLOCKX, BLOCKY, 1, 0, hw_ctx->stream, dnn_to_frame_args, NULL));
+    ck_cu(cu->cuMemcpy2DAsync(&copy_param, hw_ctx->stream));
+
+    ck_cu(cu->cuStreamSynchronize(hw_ctx->stream));
+
+    return DNN_SUCCESS;
+}
+
+DNNReturnType trt_load_model(DNNModel *model, const char *model_filename, const AVClass *av_class, const char *options)
+{    
+    int ret = 0;
+    char id_buf[64] = { 0 };
+    AVBufferRef *device_ref = NULL;
+    TRTContext *ctx = (TRTContext*)av_mallocz(sizeof(TRTContext));
+    AVHWDeviceContext *hw_device;
+    AVCUDADeviceContext *hw_ctx;
+    CudaFunctions *cu;
+    CUcontext dummy, cuda_ctx;
+
+    ctx->av_class = av_class;
+    av_opt_set_defaults(ctx);
+    if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0)
+    {
+        av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options);
+        return DNN_ERROR;
+    }
+    snprintf(id_buf, sizeof(id_buf), "%d", ctx->options.device);
+    
+    // TODO: Add device index option
+    ret = av_hwdevice_ctx_create(&device_ref, AV_HWDEVICE_TYPE_CUDA, id_buf, NULL, 1);
+    if (ret < 0)
+    {
+        av_log(ctx, AV_LOG_ERROR, "Error creating device context\n");
+        return DNN_ERROR;
+    }
+
+    hw_device = (AVHWDeviceContext*)device_ref->data;
+    hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    cu = hw_ctx->internal->cuda_dl;
+    cuda_ctx = hw_ctx->cuda_ctx;
+
+    ck_cu(cu->cuCtxPushCurrent(cuda_ctx));
+
+    TrtLite *trt_model= new TrtLite{model_filename, ctx};
+    if (trt_model == nullptr)
+    {
+        return DNN_ERROR;
+    }
+
+    ctx->hwdevice = device_ref;
+
+    ck_cu(cu->cuCtxPopCurrent(&dummy));
+
+    trt_model->PrintInfo();
+
+
+    model->model = trt_model;
+    model->get_input = &get_input_trt;
+    model->get_output = &get_output_trt;
+    model->options = options;
+    av_log(ctx, AV_LOG_INFO, "Load trt engine\n");
+
+    return DNN_SUCCESS;
+}
+
+DNNReturnType get_input_trt(void *model, DNNData *input, const char *input_name)
+{
+    TrtLite* trt_model = (TrtLite*)model;
+    TRTContext *ctx = trt_model->ctx;
+    AVHWDeviceContext *hw_device = (AVHWDeviceContext*)ctx->hwdevice->data;
+    AVCUDADeviceContext *hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    CudaFunctions *cu = hw_ctx->internal->cuda_dl;
+
+    CUcontext dummy, cuda_ctx = hw_ctx->cuda_ctx;
+
+    av_log(ctx, AV_LOG_INFO, "Get TRT input\n");
+
+    // For dynamic shape, input dimensions are set to -1,
+    // trt input is initialized in get_output_trt() along with trt output
+    if (!trt_model->GetEngine()->hasImplicitBatchDimension())
+    {
+        av_log(ctx, AV_LOG_INFO, "Model supports dynamic shape\n");
+        for (int i = 0; i < trt_model->GetEngine()->getNbBindings(); i++) {
+            if (trt_model->GetEngine()->bindingIsInput(i))
+            {
+                ctx->channels = trt_model->GetEngine()->getBindingDimensions(i).d[1];
+                if (ctx->channels == -1)
+                {
+                    av_log(ctx, AV_LOG_ERROR, "Do not support dynamic channel size\n");
+                    return DNN_ERROR;
+                }
+                input->channels = ctx->channels;
+            }
+        }
+        input->height = -1;
+        input->width = -1;
+        input->dt = DNN_FLOAT;
+
+        return DNN_SUCCESS;
+    }
+
+    vector<IOInfo> v_info = trt_model->ConfigIO(BATCH);
+    for (auto info: v_info)
+    {
+        if (info.bInput)
+        {
+            input->channels = info.dim.d[1];
+            input->height = info.dim.d[2];
+            input->width = info.dim.d[3];
+            input->dt = DNN_FLOAT;
+
+            ctx->host_in = new uint8_t[info.GetNumBytes()];
+
+            ck_cu(cu->cuCtxPushCurrent(cuda_ctx));
+
+            ck_cu(cu->cuMemAlloc(&ctx->trt_in, info.GetNumBytes()));
+            ck_cu(cu->cuMemAlloc(&ctx->frame_in, info.GetNumBytes() / sizeof(float)));
+
+            ck_cu(cu->cuCtxPopCurrent(&dummy));
+            
+            return DNN_SUCCESS;
+        }
+    }
+    av_log(ctx, AV_LOG_ERROR, "No input found in the model\n");
+    return DNN_ERROR;
+}
+
+DNNReturnType get_output_trt(void *model, const char *input_name, int input_width, int input_height,
+                                const char *output_name, int *output_width, int *output_height)
+{
+    TrtLite* trt_model = (TrtLite*)model;
+    TRTContext *ctx = trt_model->ctx;
+    AVHWDeviceContext *hw_device = (AVHWDeviceContext*)ctx->hwdevice->data;
+    AVCUDADeviceContext *hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    CudaFunctions *cu = hw_ctx->internal->cuda_dl;
+
+    CUcontext dummy, cuda_ctx = hw_ctx->cuda_ctx;
+    extern char dnn_io_proc_trt_ptx[];
+
+    av_log(ctx, AV_LOG_INFO, "Get TRT output\n");
+
+    vector<IOInfo> v_info;
+    if (!trt_model->GetEngine()->hasImplicitBatchDimension())
+    {
+        map<int, Dims> i2shape;
+        i2shape.insert(make_pair(0, Dims{4, {BATCH, ctx->channels, input_height, input_width}}));
+        v_info = trt_model->ConfigIO(i2shape);
+    }
+    else
+    {
+        v_info = trt_model->ConfigIO(BATCH);
+    }
+    
+    ck_cu(cu->cuCtxPushCurrent(cuda_ctx));
+
+    for (auto info: v_info)
+    {
+        // For dynamic shape, inputs are initialized here
+        if (info.bInput && (!trt_model->GetEngine()->hasImplicitBatchDimension()))
+        {
+            ctx->host_in = new uint8_t[info.GetNumBytes()];
+            ck_cu(cu->cuMemAlloc(&ctx->trt_in, info.GetNumBytes()));
+            ck_cu(cu->cuMemAlloc(&ctx->frame_in, info.GetNumBytes() / sizeof(float)));
+        }
+        if (!info.bInput)
+        {
+            *output_height = info.dim.d[2];
+            *output_width = info.dim.d[3];
+
+            ctx->host_out = new uint8_t[info.GetNumBytes()];
+            ck_cu(cu->cuMemAlloc(&ctx->trt_out, info.GetNumBytes()));
+            ck_cu(cu->cuMemAlloc(&ctx->frame_out, info.GetNumBytes() / sizeof(float)));
+        }
+    }
+
+    ck_cu(cu->cuModuleLoadData(&ctx->cu_module, dnn_io_proc_trt_ptx));
+    ck_cu(cu->cuModuleGetFunction(&ctx->cu_func_frame_to_dnn, ctx->cu_module, "frame_to_dnn"));
+    ck_cu(cu->cuModuleGetFunction(&ctx->cu_func_dnn_to_frame, ctx->cu_module, "dnn_to_frame"));
+
+    ck_cu(cu->cuCtxPopCurrent(&dummy));
+
+    return DNN_SUCCESS;
+}
+
+DNNReturnType execute_model_trt(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                      const char **output_names, uint32_t nb_output, AVFrame *out_frame)
+{
+    TrtLite* trt_model = reinterpret_cast<TrtLite*>(model->model);
+    TRTContext *ctx = trt_model->ctx;
+    AVHWDeviceContext *hw_device = (AVHWDeviceContext*)ctx->hwdevice->data;
+    AVCUDADeviceContext *hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    CudaFunctions *cu = hw_ctx->internal->cuda_dl;
+
+    CUcontext dummy, cuda_ctx = hw_ctx->cuda_ctx;
+
+    DNNData input, output;
+    vector<void*> buf_vec, device_buf_vec;
+    int ret = 0;
+
+    int input_height = in_frame->height;
+    int input_width = in_frame->width;
+    int input_channels = ctx->channels;
+    vector<IOInfo> IO_info_vec;
+    map<int, Dims> i2shape;
+    if (!trt_model->GetEngine()->hasImplicitBatchDimension())
+    {
+        i2shape.insert(make_pair(0, Dims{4, {BATCH, input_channels, input_height, input_width}}));
+        IO_info_vec = trt_model->ConfigIO(i2shape);
+    }
+    else
+    {
+        IO_info_vec = trt_model->ConfigIO(BATCH);
+    }
+
+    ck_cu(cu->cuCtxPushCurrent(cuda_ctx));
+
+    for (auto info : IO_info_vec)
+    {
+
+        if (info.bInput)
+        {   
+            input.height = info.dim.d[2];
+            input.width = info.dim.d[3];
+            input.channels = info.dim.d[1];
+            input.data = ctx->host_in;
+            input.dt = DNN_FLOAT;
+            ret = frame_to_dnn(in_frame, ctx, info.GetNumBytes() / sizeof(float));
+            
+            if (ret < 0)
+                return DNN_ERROR;
+    
+            device_buf_vec.push_back((void*)ctx->trt_in);
+            continue;
+        }
+        else
+        {
+            device_buf_vec.push_back((void*)ctx->trt_out);
+        }
+    }
+
+    if (!trt_model->GetEngine()->hasImplicitBatchDimension())
+    {
+        trt_model->Execute(i2shape, device_buf_vec, hw_ctx->stream);
+    }
+    else
+    {
+        trt_model->Execute(BATCH, device_buf_vec, hw_ctx->stream);
+    }
+
+    for (uint32_t i = 0; i < IO_info_vec.size(); i++)
+    {
+        if (!IO_info_vec[i].bInput)
+        {
+            output.height = IO_info_vec[i].dim.d[2];
+            output.width = IO_info_vec[i].dim.d[3];
+            output.channels = IO_info_vec[i].dim.d[1];
+            output.data = ctx->host_out;
+            output.dt = DNN_FLOAT;
+            ret = dnn_to_frame(out_frame, ctx, IO_info_vec[i].GetNumBytes() / sizeof(float));
+        }
+    }
+
+    ck_cu(cu->cuCtxPopCurrent(&dummy));
+
+    return DNN_SUCCESS;
+}
+
+DNNReturnType free_model_trt(DNNModel *model)
+{
+    TrtLite* trt_model = reinterpret_cast<TrtLite*>(model->model);
+    TRTContext *ctx = trt_model->ctx;
+    AVHWDeviceContext *hw_device = (AVHWDeviceContext*)ctx->hwdevice->data;
+    AVCUDADeviceContext *hw_ctx = (AVCUDADeviceContext*)hw_device->hwctx;
+    CudaFunctions *cu = hw_ctx->internal->cuda_dl;
+
+    delete[]((uint8_t*)ctx->host_in);
+    delete[]((uint8_t*)ctx->host_out);
+    ck_cu(cu->cuMemFree(ctx->trt_in));
+    ck_cu(cu->cuMemFree(ctx->trt_out));
+    
+    delete(trt_model);
+    
+    av_buffer_unref(&ctx->hwdevice);
+    av_free(ctx);
+    model->model = NULL;
+    
+    return DNN_SUCCESS;
+}
+#ifdef __cplusplus
+}
+#endif
diff --git a/libavfilter/dnn/trt_class_wrapper.h b/libavfilter/dnn/trt_class_wrapper.h
new file mode 100644
index 0000000000..18815fadae
--- /dev/null
+++ b/libavfilter/dnn/trt_class_wrapper.h
@@ -0,0 +1,49 @@ 
+/*
+* Copyright (c) 2021 NVIDIA CORPORATION. All rights reserved.
+*
+* Permission is hereby granted, free of charge, to any person obtaining a
+* copy of this software and associated documentation files (the "Software"),
+* to deal in the Software without restriction, including without limitation
+* the rights to use, copy, modify, merge, publish, distribute, sublicense,
+* and/or sell copies of the Software, and to permit persons to whom the
+* Software is furnished to do so, subject to the following conditions:
+*
+* The above copyright notice and this permission notice shall be included in
+* all copies or substantial portions of the Software.
+*
+* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
+* THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+* DEALINGS IN THE SOFTWARE.
+ */
+ 
+/**
+ * @file
+ * TensorRT wrapper header for dnn_backend in ffmpeg.
+ */
+
+#ifndef TRT_CLASS_WRAPPER_H
+#define TRT_CLASS_WRAPPER_H
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+    #include "../dnn_interface.h"
+
+    DNNReturnType free_model_trt(DNNModel *model);
+    DNNReturnType execute_model_trt(const DNNModel *model, const char *input_name, AVFrame *in_frame,
+                                           const char **output_names, uint32_t nb_output, AVFrame *out_frame);
+    DNNReturnType get_output_trt(void *model, const char *input_name, int input_width, int input_height,
+                                        const char *output_name, int *output_width, int *output_height);
+    DNNReturnType get_input_trt(void *model, DNNData *input, const char *input_name);
+    DNNReturnType trt_load_model(DNNModel *model, const char *model_filename, const AVClass *av_class, const char *options);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 5e9ffeb077..13a3ea8fd8 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -32,7 +32,7 @@ 
 
 typedef enum {DNN_SUCCESS, DNN_ERROR} DNNReturnType;
 
-typedef enum {DNN_NATIVE, DNN_TF, DNN_OV} DNNBackendType;
+typedef enum {DNN_NATIVE, DNN_TF, DNN_OV, DNN_TRT} DNNBackendType;
 
 typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
 
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index e1d9d24683..3bc86a2534 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -52,6 +52,9 @@  static const AVOption dnn_processing_options[] = {
 #endif
 #if (CONFIG_LIBOPENVINO == 1)
     { "openvino",    "openvino backend flag",      0,                        AV_OPT_TYPE_CONST,     { .i64 = 2 },    0, 0, FLAGS, "backend" },
+#endif
+#if (CONFIG_LIBTENSORRT == 1)
+    { "tensorrt",    "tensorrt backend flag",      0,                        AV_OPT_TYPE_CONST,     { .i64 = 3 },    0, 0, FLAGS, "backend" },
 #endif
     DNN_COMMON_OPTIONS
     { NULL }