From patchwork Tue Feb 20 04:48:24 2024
X-Patchwork-Submitter: "Chen, Wenbin"
X-Patchwork-Id: 46386
From: wenbin.chen-at-intel.com@ffmpeg.org
To: ffmpeg-devel@ffmpeg.org
Date: Tue, 20 Feb 2024 12:48:24 +0800
Message-Id: <20240220044824.1439205-1-wenbin.chen@intel.com>
X-Mailer: git-send-email 2.34.1
Subject: [FFmpeg-devel] [PATCH v3] libavfilter/dnn: add LibTorch as one of the DNN backends

From: Wenbin Chen

PyTorch is an open source machine learning framework that accelerates
the path from research prototyping to production deployment. Official
website: https://pytorch.org/. We refer to the C++ library of PyTorch
as LibTorch below.

To build FFmpeg with LibTorch, take the following steps as reference:

1. Download the LibTorch C++ library from
   https://pytorch.org/get-started/locally/; select C++/Java as the
   language, and the other options as you need.
2. Unzip the file to a directory of your own:
   unzip libtorch-shared-with-deps-latest.zip -d your_dir
3. Add libtorch_root/libtorch/include and
   libtorch_root/libtorch/include/torch/csrc/api/include to $PATH, and
   libtorch_root/libtorch/lib/ to $LD_LIBRARY_PATH.
4. Configure FFmpeg with:
   ../configure --enable-libtorch \
       --extra-cflag=-I/libtorch_root/libtorch/include \
       --extra-cflag=-I/libtorch_root/libtorch/include/torch/csrc/api/include \
       --extra-ldflags=-L/libtorch_root/libtorch/lib/
5. make

To run FFmpeg DNN inference with the LibTorch backend:

./ffmpeg -i input.jpg -vf dnn_processing=dnn_backend=torch:model=LibTorch_model.pt -y output.jpg

The LibTorch_model.pt can be generated by Python with the
torch.jit.script() API. Please note that torch.jit.trace() is not
recommended, since it does not support variable input sizes.
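
For illustration, the export step can look roughly like this (a minimal
sketch; the module and file names are placeholders, not part of this
patch):

    # export_model.py -- illustrative only; any nn.Module compiled with
    # torch.jit.script() and saved like this loads via torch::jit::load()
    import torch
    import torch.nn as nn

    class Passthrough(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x

    scripted = torch.jit.script(Passthrough().eval())
    # unlike a traced module, a scripted one accepts varying input sizes:
    for h, w in ((64, 64), (128, 96)):
        scripted(torch.rand(1, 1, 3, h, w))
    scripted.save("LibTorch_model.pt")
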
Signed-off-by: Ting Fu
Signed-off-by: Wenbin Chen
---
 configure                             |   5 +-
 libavfilter/dnn/Makefile              |   1 +
 libavfilter/dnn/dnn_backend_torch.cpp | 597 ++++++++++++++++++++++++++
 libavfilter/dnn/dnn_interface.c       |   5 +
 libavfilter/dnn_filter_common.c       |  15 +-
 libavfilter/dnn_interface.h           |   2 +-
 libavfilter/vf_dnn_processing.c       |   3 +
 7 files changed, 624 insertions(+), 4 deletions(-)
 create mode 100644 libavfilter/dnn/dnn_backend_torch.cpp

diff --git a/configure b/configure
index 2c635043dd..450ef54a80 100755
--- a/configure
+++ b/configure
@@ -279,6 +279,7 @@ External library support:
   --enable-libtheora       enable Theora encoding via libtheora [no]
   --enable-libtls          enable LibreSSL (via libtls), needed for https support
                            if openssl, gnutls or mbedtls is not used [no]
+  --enable-libtorch        enable Torch as one DNN backend [no]
   --enable-libtwolame      enable MP2 encoding via libtwolame [no]
   --enable-libuavs3d       enable AVS3 decoding via libuavs3d [no]
   --enable-libv4l2         enable libv4l2/v4l-utils [no]
@@ -1901,6 +1902,7 @@ EXTERNAL_LIBRARY_LIST="
     libtensorflow
     libtesseract
     libtheora
+    libtorch
     libtwolame
     libuavs3d
     libv4l2
@@ -2781,7 +2783,7 @@ cbs_vp9_select="cbs"
 deflate_wrapper_deps="zlib"
 dirac_parse_select="golomb"
 dovi_rpu_select="golomb"
-dnn_suggest="libtensorflow libopenvino"
+dnn_suggest="libtensorflow libopenvino libtorch"
 dnn_deps="avformat swscale"
 error_resilience_select="me_cmp"
 evcparse_select="golomb"
@@ -6886,6 +6888,7 @@ enabled libtensorflow     && require libtensorflow tensorflow/c/c_api.h TF_Versi
 enabled libtesseract      && require_pkg_config libtesseract tesseract tesseract/capi.h TessBaseAPICreate
 enabled libtheora         && require libtheora theora/theoraenc.h th_info_init -ltheoraenc -ltheoradec -logg
 enabled libtls            && require_pkg_config libtls libtls tls.h tls_configure
+enabled libtorch          && check_cxxflags -std=c++14 && require_cpp libtorch torch/torch.h "torch::Tensor" -ltorch -lc10 -ltorch_cpu -lstdc++ -lpthread
 enabled libtwolame        && require libtwolame twolame.h twolame_init -ltwolame &&
                              { check_lib libtwolame twolame.h twolame_encode_buffer_float32_interleaved -ltwolame ||
                                die "ERROR: libtwolame must be installed and version must be >= 0.3.10"; }
diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile
index 5d5697ea42..3d09927c98 100644
--- a/libavfilter/dnn/Makefile
+++ b/libavfilter/dnn/Makefile
@@ -6,5 +6,6 @@ OBJS-$(CONFIG_DNN)                           += dnn/dnn_backend_common.o
 
 DNN-OBJS-$(CONFIG_LIBTENSORFLOW)             += dnn/dnn_backend_tf.o
 DNN-OBJS-$(CONFIG_LIBOPENVINO)               += dnn/dnn_backend_openvino.o
+DNN-OBJS-$(CONFIG_LIBTORCH)                  += dnn/dnn_backend_torch.o
 
 OBJS-$(CONFIG_DNN)                           += $(DNN-OBJS-yes)
diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp
new file mode 100644
index 0000000000..54d3b309a1
--- /dev/null
+++ b/libavfilter/dnn/dnn_backend_torch.cpp
@@ -0,0 +1,597 @@
+/*
+ * Copyright (c) 2024
+ *
+ * This file is part of FFmpeg.
+ *
+ * FFmpeg is free software; you can redistribute it and/or
+ * modify it under the terms of the GNU Lesser General Public
+ * License as published by the Free Software Foundation; either
+ * version 2.1 of the License, or (at your option) any later version.
+ *
+ * FFmpeg is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
+ * Lesser General Public License for more details.
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+/**
+ * @file
+ * DNN Torch backend implementation.
+ */
+
+#include <torch/torch.h>
+#include <torch/script.h>
+
+extern "C" {
+#include "../internal.h"
+#include "dnn_io_proc.h"
+#include "dnn_backend_common.h"
+#include "libavutil/opt.h"
+#include "queue.h"
+#include "safe_queue.h"
+}
+
+typedef struct THOptions {
+    char *device_name;
+    int optimize;
+} THOptions;
+
+typedef struct THContext {
+    const AVClass *c_class;
+    THOptions options;
+} THContext;
+
+typedef struct THModel {
+    THContext ctx;
+    DNNModel *model;
+    torch::jit::Module *jit_model;
+    SafeQueue *request_queue;
+    Queue *task_queue;
+    Queue *lltask_queue;
+} THModel;
+
+typedef struct THInferRequest {
+    torch::Tensor *output;
+    torch::Tensor *input_tensor;
+} THInferRequest;
+
+typedef struct THRequestItem {
+    THInferRequest *infer_request;
+    LastLevelTaskItem *lltask;
+    DNNAsyncExecModule exec_module;
+} THRequestItem;
+
+#define OFFSET(x) offsetof(THContext, x)
+#define FLAGS AV_OPT_FLAG_FILTERING_PARAM
+static const AVOption dnn_th_options[] = {
+    { "device", "device to run model", OFFSET(options.device_name), AV_OPT_TYPE_STRING, { .str = "cpu" }, 0, 0, FLAGS },
+    { "optimize", "turn on graph executor optimization", OFFSET(options.optimize), AV_OPT_TYPE_INT, { .i64 = 0 }, 0, 1, FLAGS },
+    { NULL }
+};
+
+AVFILTER_DEFINE_CLASS(dnn_th);
+
+static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue)
+{
+    THModel *th_model = (THModel *)task->model;
+    THContext *ctx = &th_model->ctx;
+    LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask));
+    if (!lltask) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n");
+        return AVERROR(ENOMEM);
+    }
+    task->inference_todo = 1;
+    task->inference_done = 0;
+    lltask->task = task;
+    if (ff_queue_push_back(lltask_queue, lltask) < 0) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n");
+        av_freep(&lltask);
+        return AVERROR(ENOMEM);
+    }
+    return 0;
+}
+
+static void th_free_request(THInferRequest *request)
+{
+    if (!request)
+        return;
+    if (request->output) {
+        delete(request->output);
+        request->output = NULL;
+    }
+    if (request->input_tensor) {
+        delete(request->input_tensor);
+        request->input_tensor = NULL;
+    }
+    return;
+}
+
+static inline void destroy_request_item(THRequestItem **arg)
+{
+    THRequestItem *item;
+    if (!arg || !*arg) {
+        return;
+    }
+    item = *arg;
+    th_free_request(item->infer_request);
+    av_freep(&item->infer_request);
+    av_freep(&item->lltask);
+    ff_dnn_async_module_cleanup(&item->exec_module);
+    av_freep(arg);
+}
+
+static void dnn_free_model_th(DNNModel **model)
+{
+    THModel *th_model;
+    if (!model || !*model)
+        return;
+
+    th_model = (THModel *) (*model)->model;
+    while (ff_safe_queue_size(th_model->request_queue) != 0) {
+        THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+        destroy_request_item(&item);
+    }
+    ff_safe_queue_destroy(th_model->request_queue);
+
+    while (ff_queue_size(th_model->lltask_queue) != 0) {
+        LastLevelTaskItem *item = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
+        av_freep(&item);
+    }
+    ff_queue_destroy(th_model->lltask_queue);
+
+    while (ff_queue_size(th_model->task_queue) != 0) {
+        TaskItem *item = (TaskItem *)ff_queue_pop_front(th_model->task_queue);
+        av_frame_free(&item->in_frame);
+        av_frame_free(&item->out_frame);
+        av_freep(&item);
+    }
+    ff_queue_destroy(th_model->task_queue);
+    delete th_model->jit_model;
+    av_opt_free(&th_model->ctx);
+    av_freep(&th_model);
+    av_freep(model);
+}
+
+static int get_input_th(void *model, DNNData *input, const char *input_name)
+{
+    input->dt = DNN_FLOAT;
+    input->order = DCO_RGB;
+    input->layout = DL_NCHW;
+    input->dims[0] = 1;
+    input->dims[1] = 3;
+    input->dims[2] = -1;
+    input->dims[3] = -1;
+    return 0;
+}
+
+static void deleter(void *arg)
+{
+    av_freep(&arg);
+}
+
+static int fill_model_input_th(THModel *th_model, THRequestItem *request)
+{
+    LastLevelTaskItem *lltask = NULL;
+    TaskItem *task = NULL;
+    THInferRequest *infer_request = NULL;
+    DNNData input = { 0 };
+    THContext *ctx = &th_model->ctx;
+    int ret, width_idx, height_idx, channel_idx;
+
+    lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue);
+    if (!lltask) {
+        ret = AVERROR(EINVAL);
+        goto err;
+    }
+    request->lltask = lltask;
+    task = lltask->task;
+    infer_request = request->infer_request;
+
+    ret = get_input_th(th_model, &input, NULL);
+    if (ret != 0) {
+        goto err;
+    }
+    width_idx = dnn_get_width_idx_by_layout(input.layout);
+    height_idx = dnn_get_height_idx_by_layout(input.layout);
+    channel_idx = dnn_get_channel_idx_by_layout(input.layout);
+    input.dims[height_idx] = task->in_frame->height;
+    input.dims[width_idx] = task->in_frame->width;
+    input.data = av_malloc(input.dims[height_idx] * input.dims[width_idx] *
+                           input.dims[channel_idx] * sizeof(float));
+    if (!input.data)
+        return AVERROR(ENOMEM);
+    infer_request->input_tensor = new torch::Tensor();
+    infer_request->output = new torch::Tensor();
+
+    switch (th_model->model->func_type) {
+    case DFT_PROCESS_FRAME:
+        input.scale = 255;
+        if (task->do_ioproc) {
+            if (th_model->model->frame_pre_proc != NULL) {
+                th_model->model->frame_pre_proc(task->in_frame, &input, th_model->model->filter_ctx);
+            } else {
+                ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx);
+            }
+        }
+        break;
+    default:
+        avpriv_report_missing_feature(NULL, "model function type %d", th_model->model->func_type);
+        break;
+    }
+    *infer_request->input_tensor = torch::from_blob(input.data,
+        {1, 1, input.dims[channel_idx], input.dims[height_idx], input.dims[width_idx]},
+        deleter, torch::kFloat32);
+    return 0;
+
+err:
+    th_free_request(infer_request);
+    return ret;
+}
+
+static int th_start_inference(void *args)
+{
+    THRequestItem *request = (THRequestItem *)args;
+    THInferRequest *infer_request = NULL;
+    LastLevelTaskItem *lltask = NULL;
+    TaskItem *task = NULL;
+    THModel *th_model = NULL;
+    THContext *ctx = NULL;
+    std::vector<torch::jit::IValue> inputs;
+    torch::NoGradGuard no_grad;
+
+    if (!request) {
+        av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n");
+        return AVERROR(EINVAL);
+    }
+    infer_request = request->infer_request;
+    lltask = request->lltask;
+    task = lltask->task;
+    th_model = (THModel *)task->model;
+    ctx = &th_model->ctx;
+
+    if (ctx->options.optimize)
+        torch::jit::setGraphExecutorOptimize(true);
+    else
+        torch::jit::setGraphExecutorOptimize(false);
+
+    if (!infer_request->input_tensor || !infer_request->output) {
+        av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n");
+        return DNN_GENERIC_ERROR;
+    }
+    inputs.push_back(*infer_request->input_tensor);
+
+    *infer_request->output = th_model->jit_model->forward(inputs).toTensor();
+
+    return 0;
+}
+
+static void infer_completion_callback(void *args)
+{
+    THRequestItem *request = (THRequestItem *)args;
+    LastLevelTaskItem *lltask = request->lltask;
+    TaskItem *task = lltask->task;
+    DNNData outputs = { 0 };
+    THInferRequest *infer_request = request->infer_request;
+    THModel *th_model = (THModel *)task->model;
+    torch::Tensor *output = infer_request->output;
+
+    c10::IntArrayRef sizes = output->sizes();
+    outputs.order = DCO_RGB;
+    outputs.layout = DL_NCHW;
+    outputs.dt = DNN_FLOAT;
+    if (sizes.size() == 5) {
+        // 5 dimensions: [batch_size, frame_number, channel, height, width]
+        // this format of data is normally used for video frame SR
+        outputs.dims[0] = sizes.at(0); // N
+        outputs.dims[1] = sizes.at(2); // C
+        outputs.dims[2] = sizes.at(3); // H
+        outputs.dims[3] = sizes.at(4); // W
+    } else {
+        avpriv_report_missing_feature(&th_model->ctx, "Support of this kind of model");
+        goto err;
+    }
+
+    switch (th_model->model->func_type) {
+    case DFT_PROCESS_FRAME:
+        if (task->do_ioproc) {
+            outputs.scale = 255;
+            outputs.data = output->data_ptr();
+            if (th_model->model->frame_post_proc != NULL) {
+                th_model->model->frame_post_proc(task->out_frame, &outputs, th_model->model->filter_ctx);
+            } else {
+                ff_proc_from_dnn_to_frame(task->out_frame, &outputs, &th_model->ctx);
+            }
+        } else {
+            task->out_frame->width  = outputs.dims[dnn_get_width_idx_by_layout(outputs.layout)];
+            task->out_frame->height = outputs.dims[dnn_get_height_idx_by_layout(outputs.layout)];
+        }
+        break;
+    default:
+        avpriv_report_missing_feature(&th_model->ctx, "model function type %d", th_model->model->func_type);
+        goto err;
+    }
+    task->inference_done++;
+    av_freep(&request->lltask);
+err:
+    th_free_request(infer_request);
+
+    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
+        destroy_request_item(&request);
+        av_log(&th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue when failed to start inference.\n");
+    }
+}
+
+static int execute_model_th(THRequestItem *request, Queue *lltask_queue)
+{
+    THModel *th_model = NULL;
+    LastLevelTaskItem *lltask;
+    TaskItem *task = NULL;
+    int ret = 0;
+
+    if (ff_queue_size(lltask_queue) == 0) {
+        destroy_request_item(&request);
+        return 0;
+    }
+
+    lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue);
+    if (lltask == NULL) {
+        av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n");
+        ret = AVERROR(EINVAL);
+        goto err;
+    }
+    task = lltask->task;
+    th_model = (THModel *)task->model;
+
+    ret = fill_model_input_th(th_model, request);
+    if (ret != 0) {
+        goto err;
+    }
+    if (task->async) {
+        avpriv_report_missing_feature(&th_model->ctx, "LibTorch async");
+    } else {
+        ret = th_start_inference((void *)(request));
+        if (ret != 0) {
+            goto err;
+        }
+        infer_completion_callback(request);
+        return (task->inference_done == task->inference_todo) ? 0 : DNN_GENERIC_ERROR;
+    }
+
+err:
+    th_free_request(request->infer_request);
+    if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) {
+        destroy_request_item(&request);
+    }
+    return ret;
+}
+
+static int get_output_th(void *model, const char *input_name, int input_width, int input_height,
+                         const char *output_name, int *output_width, int *output_height)
+{
+    int ret = 0;
+    THModel *th_model = (THModel *)model;
+    THContext *ctx = &th_model->ctx;
+    TaskItem task = { 0 };
+    THRequestItem *request = NULL;
+    DNNExecBaseParams exec_params = {
+        .input_name = input_name,
+        .output_names = &output_name,
+        .nb_output = 1,
+        .in_frame = NULL,
+        .out_frame = NULL,
+    };
+    ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, th_model, input_height, input_width, ctx);
+    if (ret != 0) {
+        goto err;
+    }
+
+    ret = extract_lltask_from_task(&task, th_model->lltask_queue);
+    if (ret != 0) {
+        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
+        goto err;
+    }
+
+    request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+    if (!request) {
+        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        ret = AVERROR(EINVAL);
+        goto err;
+    }
+
+    ret = execute_model_th(request, th_model->lltask_queue);
+    *output_width = task.out_frame->width;
+    *output_height = task.out_frame->height;
+
+err:
+    av_frame_free(&task.out_frame);
+    av_frame_free(&task.in_frame);
+    return ret;
+}
+
+static THInferRequest *th_create_inference_request(void)
+{
+    THInferRequest *request = (THInferRequest *)av_malloc(sizeof(THInferRequest));
+    if (!request) {
+        return NULL;
+    }
+    request->input_tensor = NULL;
+    request->output = NULL;
+    return request;
+}
+
+static DNNModel *dnn_load_model_th(const char *model_filename, DNNFunctionType func_type,
+                                   const char *options, AVFilterContext *filter_ctx)
+{
+    DNNModel *model = NULL;
+    THModel *th_model = NULL;
+    THRequestItem *item = NULL;
+    THContext *ctx;
+
+    model = (DNNModel *)av_mallocz(sizeof(DNNModel));
+    if (!model) {
+        return NULL;
+    }
+
+    th_model = (THModel *)av_mallocz(sizeof(THModel));
+    if (!th_model) {
+        av_freep(&model);
+        return NULL;
+    }
+    th_model->model = model;
+    model->model = th_model;
+    th_model->ctx.c_class = &dnn_th_class;
+    ctx = &th_model->ctx;
+    // parse options
+    av_opt_set_defaults(ctx);
+    if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options);
+        return NULL;
+    }
+
+    c10::Device device = c10::Device(ctx->options.device_name);
+    if (!device.is_cpu()) {
+        av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", ctx->options.device_name);
+        goto fail;
+    }
+
+    try {
+        th_model->jit_model = new torch::jit::Module;
+        (*th_model->jit_model) = torch::jit::load(model_filename);
+    } catch (const c10::Error& e) {
+        av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n");
+        goto fail;
+    }
+
+    th_model->request_queue = ff_safe_queue_create();
+    if (!th_model->request_queue) {
+        goto fail;
+    }
+
+    item = (THRequestItem *)av_mallocz(sizeof(THRequestItem));
+    if (!item) {
+        goto fail;
+    }
+    item->lltask = NULL;
+    item->infer_request = th_create_inference_request();
+    if (!item->infer_request) {
+        av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for Torch inference request\n");
+        goto fail;
+    }
+    item->exec_module.start_inference = &th_start_inference;
+    item->exec_module.callback = &infer_completion_callback;
+    item->exec_module.args = item;
+
+    if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) {
+        goto fail;
+    }
+    item = NULL;
+
+    th_model->task_queue = ff_queue_create();
+    if (!th_model->task_queue) {
+        goto fail;
+    }
+
+    th_model->lltask_queue = ff_queue_create();
+    if (!th_model->lltask_queue) {
+        goto fail;
+    }
+
+    model->get_input = &get_input_th;
+    model->get_output = &get_output_th;
+    model->options = NULL;
+    model->filter_ctx = filter_ctx;
+    model->func_type = func_type;
+    return model;
+
+fail:
+    if (item) {
+        destroy_request_item(&item);
+        av_freep(&item);
+    }
+    dnn_free_model_th(&model);
+    return NULL;
+}
+
+static int dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params)
+{
+    THModel *th_model = (THModel *)model->model;
+    THContext *ctx = &th_model->ctx;
+    TaskItem *task;
+    THRequestItem *request;
+    int ret = 0;
+
+    ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params);
+    if (ret != 0) {
+        av_log(ctx, AV_LOG_ERROR, "exec parameter checking fail.\n");
+        return ret;
+    }
+
+    task = (TaskItem *)av_malloc(sizeof(TaskItem));
+    if (!task) {
+        av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n");
+        return AVERROR(ENOMEM);
+    }
+
+    ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1);
+    if (ret != 0) {
+        av_freep(&task);
+        av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n");
+        return ret;
+    }
+
+    ret = ff_queue_push_back(th_model->task_queue, task);
+    if (ret < 0) {
+        av_freep(&task);
+        av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n");
+        return ret;
+    }
+
+    ret = extract_lltask_from_task(task, th_model->lltask_queue);
+    if (ret != 0) {
+        av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n");
+        return ret;
+    }
+
+    request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+    if (!request) {
+        av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        return AVERROR(EINVAL);
+    }
+
+    return execute_model_th(request, th_model->lltask_queue);
+}
+
+static DNNAsyncStatusType dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out)
+{
+    THModel *th_model = (THModel *)model->model;
+    return ff_dnn_get_result_common(th_model->task_queue, in, out);
+}
+
+static int dnn_flush_th(const DNNModel *model)
+{
+    THModel *th_model = (THModel *)model->model;
+    THRequestItem *request;
+
+    if (ff_queue_size(th_model->lltask_queue) == 0)
+        // no pending task to flush
+        return 0;
+
+    request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue);
+    if (!request) {
+        av_log(&th_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n");
+        return AVERROR(EINVAL);
+    }
+
+    return execute_model_th(request, th_model->lltask_queue);
+}
+
+extern const DNNModule ff_dnn_backend_torch = {
+    .load_model     = dnn_load_model_th,
+    .execute_model  = dnn_execute_model_th,
+    .get_result     = dnn_get_result_th,
+    .flush          = dnn_flush_th,
+    .free_model     = dnn_free_model_th,
+};
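
A note for model authors: fill_model_input_th() above feeds the model a
5-D float tensor of shape {1, 1, channel, height, width}, and
infer_completion_callback() only accepts a 5-D output laid out as
[batch_size, frame_number, channel, height, width]. A hypothetical
TorchScript module matching that contract could look like the following
(illustrative only, not part of this patch; the x2 scale and names are
placeholders):

    # sketch of a x2 super-resolution module with the 5-D I/O shape
    # described in the comments above
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class UpscaleX2(nn.Module):
        def forward(self, x: torch.Tensor) -> torch.Tensor:
            b = x.size(0)  # batch_size
            f = x.size(1)  # frame_number
            c = x.size(2)
            h = x.size(3)
            w = x.size(4)
            # fold frames into the batch so 2-D interpolation applies
            y = F.interpolate(x.view(b * f, c, h, w), scale_factor=2.0,
                              mode="bilinear", align_corners=False)
            return y.view(b, f, c, h * 2, w * 2)

    torch.jit.script(UpscaleX2()).save("LibTorch_model.pt")
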
diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c
index e843826aa6..b9f71aea53 100644
--- a/libavfilter/dnn/dnn_interface.c
+++ b/libavfilter/dnn/dnn_interface.c
@@ -28,6 +28,7 @@
 
 extern const DNNModule ff_dnn_backend_openvino;
 extern const DNNModule ff_dnn_backend_tf;
+extern const DNNModule ff_dnn_backend_torch;
 
 const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx)
 {
@@ -40,6 +41,10 @@ const DNNModule *ff_get_dnn_module(DNNBackendType backend_type, void *log_ctx)
     case DNN_OV:
         return &ff_dnn_backend_openvino;
     #endif
+    #if (CONFIG_LIBTORCH == 1)
+    case DNN_TH:
+        return &ff_dnn_backend_torch;
+    #endif
     default:
         av_log(log_ctx, AV_LOG_ERROR, "Module backend_type %d is not supported or enabled.\n",
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index f012d450a2..7d194c9ade 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -53,12 +53,22 @@ static char **separate_output_names(const char *expr, const char *val_sep, int *
 
 int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
+    DNNBackendType backend = ctx->backend_type;
+
     if (!ctx->model_filename) {
         av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
         return AVERROR(EINVAL);
     }
-    if (ctx->backend_type == DNN_TF) {
+    if (backend == DNN_TH) {
+        if (ctx->model_inputname)
+            av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend does not require inputname, "\
+                   "inputname will be ignored.\n");
+        if (ctx->model_outputnames)
+            av_log(filter_ctx, AV_LOG_WARNING, "LibTorch backend does not require outputname(s), "\
+                   "all outputname(s) will be ignored.\n");
+        ctx->nb_outputs = 1;
+    } else if (backend == DNN_TF) {
        if (!ctx->model_inputname) {
            av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
            return AVERROR(EINVAL);
@@ -115,7 +125,8 @@ int ff_dnn_get_input(DnnContext *ctx, DNNData *input)
 
 int ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
 {
-    char * output_name = ctx->model_outputnames ? ctx->model_outputnames[0] : NULL;
+    char * output_name = ctx->model_outputnames && ctx->backend_type != DNN_TH ?
+                         ctx->model_outputnames[0] : NULL;
     return ctx->model->get_output(ctx->model->model, ctx->model_inputname, input_width, input_height,
                                     (const char *)output_name, output_width, output_height);
 }
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index 852d88baa8..63f492e690 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -32,7 +32,7 @@
 
 #define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!')
 
-typedef enum {DNN_TF = 1, DNN_OV} DNNBackendType;
+typedef enum {DNN_TF = 1, DNN_OV, DNN_TH} DNNBackendType;
 
 typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index e7d21eef32..fdac31665e 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -50,6 +50,9 @@ static const AVOption dnn_processing_options[] = {
 #endif
 #if (CONFIG_LIBOPENVINO == 1)
     { "openvino", "openvino backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = DNN_OV }, 0, 0, FLAGS, .unit = "backend" },
+#endif
+#if (CONFIG_LIBTORCH == 1)
+    { "torch", "torch backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = DNN_TH }, 0, 0, FLAGS, .unit = "backend" },
 #endif
     DNN_COMMON_OPTIONS
     { NULL }