From patchwork Mon May 23 09:29:17 2022
X-Patchwork-Submitter: Ting Fu
X-Patchwork-Id: 35894
From: Ting Fu
To: ffmpeg-devel@ffmpeg.org
Date: Mon, 23 May 2022 17:29:17 +0800
Message-Id: <20220523092918.9548-1-ting.fu@intel.com>
Subject: [FFmpeg-devel] [PATCH 1/2] libavfilter/dnn: refine enum DNNColorOrder

Rename DCO_BGR and DCO_RGB in DNNColorOrder to DCO_BGR_PACKED and DCO_RGB_PACKED, in preparation for the upcoming RGB planar support.

Signed-off-by: Ting Fu
--- libavfilter/dnn/dnn_backend_openvino.c | 2 +- libavfilter/dnn/dnn_backend_tf.c | 2 +- libavfilter/dnn/dnn_io_proc.c | 4 ++-- libavfilter/dnn_interface.h | 4 ++-- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/libavfilter/dnn/dnn_backend_openvino.c b/libavfilter/dnn/dnn_backend_openvino.c index cf012aca4c..92e180a0eb 100644 --- a/libavfilter/dnn/dnn_backend_openvino.c +++ b/libavfilter/dnn/dnn_backend_openvino.c @@ -156,7 +156,7 @@ static int fill_model_input_ov(OVModel *ov_model, OVRequestItem *request) input.dt = precision_to_datatype(precision); // all models in openvino open model zoo use BGR as input, // change to be an option when necessary.
- input.order = DCO_BGR; + input.order = DCO_BGR_PACKED; for (int i = 0; i < ctx->options.batch_size; ++i) { lltask = ff_queue_pop_front(ov_model->lltask_queue); diff --git a/libavfilter/dnn/dnn_backend_tf.c b/libavfilter/dnn/dnn_backend_tf.c index 3b5084b67b..e639b3cecd 100644 --- a/libavfilter/dnn/dnn_backend_tf.c +++ b/libavfilter/dnn/dnn_backend_tf.c @@ -294,7 +294,7 @@ static int get_input_tf(void *model, DNNData *input, const char *input_name) tf_output.index = 0; input->dt = TF_OperationOutputType(tf_output); - input->order = DCO_RGB; + input->order = DCO_RGB_PACKED; status = TF_NewStatus(); TF_GraphGetTensorShape(tf_model->graph, tf_output, dims, 4, status); diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c index 7961bf6b95..532b089002 100644 --- a/libavfilter/dnn/dnn_io_proc.c +++ b/libavfilter/dnn/dnn_io_proc.c @@ -176,9 +176,9 @@ static enum AVPixelFormat get_pixel_format(DNNData *data) { if (data->dt == DNN_UINT8) { switch (data->order) { - case DCO_BGR: + case DCO_BGR_PACKED: return AV_PIX_FMT_BGR24; - case DCO_RGB: + case DCO_RGB_PACKED: return AV_PIX_FMT_RGB24; default: av_assert0(!"unsupported data pixel format.\n"); diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h index ef8d7ae66f..d94baa90c4 100644 --- a/libavfilter/dnn_interface.h +++ b/libavfilter/dnn_interface.h @@ -38,8 +38,8 @@ typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType; typedef enum { DCO_NONE, - DCO_BGR, - DCO_RGB, + DCO_BGR_PACKED, + DCO_RGB_PACKED, } DNNColorOrder; typedef enum {

From patchwork Mon May 23 09:29:18 2022
X-Patchwork-Submitter: Ting Fu
X-Patchwork-Id: 35893
From: Ting Fu
To: ffmpeg-devel@ffmpeg.org
Date: Mon, 23 May 2022 17:29:18 +0800
Message-Id: <20220523092918.9548-2-ting.fu@intel.com>
In-Reply-To: <20220523092918.9548-1-ting.fu@intel.com>
References: <20220523092918.9548-1-ting.fu@intel.com>
Subject: [FFmpeg-devel] [PATCH 2/2] libavfilter/dnn: add LibTorch as one of the DNN backends

PyTorch is an open source machine learning framework that accelerates the path from research prototyping to production deployment. Official website: https://pytorch.org/. The C++ distribution of PyTorch is referred to as LibTorch below.

To build FFmpeg with LibTorch, take the following steps as reference:
1. Download the LibTorch C++ library from https://pytorch.org/get-started/locally/; select C++/Java as the language, and the other options as needed.
2. Unzip the file to a directory of your choice: unzip libtorch-shared-with-deps-latest.zip -d your_dir
3. Export libtorch_root/libtorch/include and libtorch_root/libtorch/include/torch/csrc/api/include to $PATH, and libtorch_root/libtorch/lib/ to $LD_LIBRARY_PATH.
4. Configure FFmpeg with ../configure --enable-libtorch --extra-cflag=-I/libtorch_root/libtorch/include --extra-cflag=-I/libtorch_root/libtorch/include/torch/csrc/api/include --extra-ldflags=-L/libtorch_root/libtorch/lib/
5. make

To run FFmpeg DNN inference with the LibTorch backend:
./ffmpeg -i input.jpg -vf dnn_processing=dnn_backend=torch:model=LibTorch_model.pt -y output.jpg

LibTorch_model.pt can be generated in Python with the torch.jit.script() API. Please note that torch.jit.trace() is not recommended, since it does not support variable input sizes.

Signed-off-by: Ting Fu
--- configure | 7 +- libavfilter/dnn/Makefile | 1 + libavfilter/dnn/dnn_backend_torch.cpp | 567 ++++++++++++++++++++++++++ libavfilter/dnn/dnn_backend_torch.h | 47 +++ libavfilter/dnn/dnn_interface.c | 12 + libavfilter/dnn/dnn_io_proc.c | 117 +++++- libavfilter/dnn_filter_common.c | 31 +- libavfilter/dnn_interface.h | 3 +- libavfilter/vf_dnn_processing.c | 3 + 9 files changed, 774 insertions(+), 14 deletions(-) create mode 100644 libavfilter/dnn/dnn_backend_torch.cpp create mode 100644 libavfilter/dnn/dnn_backend_torch.h diff --git a/configure b/configure index f115b21064..85ce3e67a3 100755 --- a/configure +++ b/configure @@ -279,6 +279,7 @@ External library support: --enable-libtheora enable Theora encoding via libtheora [no] --enable-libtls enable LibreSSL (via libtls), needed for https support if openssl, gnutls or mbedtls is not used [no] + --enable-libtorch enable Torch as one DNN backend --enable-libtwolame enable MP2 encoding via libtwolame [no] --enable-libuavs3d enable AVS3 decoding via libuavs3d [no] --enable-libv4l2 enable libv4l2/v4l-utils [no] @@ -1850,6 +1851,7 @@ EXTERNAL_LIBRARY_LIST=" libopus libplacebo libpulse + libtorch librabbitmq librav1e librist @@ -2719,7 +2721,7 @@ dct_select="rdft" deflate_wrapper_deps="zlib" dirac_parse_select="golomb" dovi_rpu_select="golomb" -dnn_suggest="libtensorflow libopenvino" +dnn_suggest="libtensorflow libopenvino libtorch" dnn_deps="avformat swscale" error_resilience_select="me_cmp" faandct_deps="faan" @@ -6600,6 +6602,7 @@ enabled libopus && { } enabled libplacebo && require_pkg_config libplacebo "libplacebo >= 4.192.0" libplacebo/vulkan.h pl_vulkan_create enabled
libpulse && require_pkg_config libpulse libpulse pulse/pulseaudio.h pa_context_new +enabled libtorch && add_cppflags -D_GLIBCXX_USE_CXX11_ABI=0 && check_cxxflags -std=c++14 && require_cpp libtorch torch/torch.h "torch::Tensor" -ltorch -lc10 -ltorch_cpu -lstdc++ -lpthread enabled librabbitmq && require_pkg_config librabbitmq "librabbitmq >= 0.7.1" amqp.h amqp_new_connection enabled librav1e && require_pkg_config librav1e "rav1e >= 0.4.0" rav1e.h rav1e_context_new enabled librist && require_pkg_config librist "librist >= 0.2" librist/librist.h rist_receiver_create @@ -7025,6 +7028,8 @@ check_disable_warning -Wno-pointer-sign check_disable_warning -Wno-unused-const-variable check_disable_warning -Wno-bool-operation check_disable_warning -Wno-char-subscripts +# this option suppresses the redundant-decls warnings when compiling with libtorch +check_disable_warning -Wno-redundant-decls check_disable_warning_headers(){ warning_flag=-W${1#-Wno-} diff --git a/libavfilter/dnn/Makefile b/libavfilter/dnn/Makefile index 4cfbce0efc..d44dcb847e 100644 --- a/libavfilter/dnn/Makefile +++ b/libavfilter/dnn/Makefile @@ -16,5 +16,6 @@ OBJS-$(CONFIG_DNN) += dnn/dnn_backend_native_layer_mat DNN-OBJS-$(CONFIG_LIBTENSORFLOW) += dnn/dnn_backend_tf.o DNN-OBJS-$(CONFIG_LIBOPENVINO) += dnn/dnn_backend_openvino.o +DNN-OBJS-$(CONFIG_LIBTORCH) += dnn/dnn_backend_torch.o OBJS-$(CONFIG_DNN) += $(DNN-OBJS-yes) diff --git a/libavfilter/dnn/dnn_backend_torch.cpp b/libavfilter/dnn/dnn_backend_torch.cpp new file mode 100644 index 0000000000..86cc018fbc --- /dev/null +++ b/libavfilter/dnn/dnn_backend_torch.cpp @@ -0,0 +1,567 @@ +/* + * Copyright (c) 2022 + * + * This file is part of FFmpeg. + * + * FFmpeg is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version.
+ * + * FFmpeg is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with FFmpeg; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +/** + * @file + * DNN Torch backend implementation. + */ + +#include <torch/torch.h> +#include <torch/script.h> +#include "dnn_backend_torch.h" + +extern "C" { +#include "dnn_io_proc.h" +#include "../internal.h" +#include "dnn_backend_common.h" +#include "libavutil/opt.h" +#include "queue.h" +#include "safe_queue.h" +} + +typedef struct THOptions{ + char *device_name; + c10::DeviceType device_type; +} THOptions; + +typedef struct THContext { + const AVClass *c_class; + THOptions options; +} THContext; + +typedef struct THModel { + THContext ctx; + DNNModel *model; + torch::jit::Module jit_model; + SafeQueue *request_queue; + Queue *task_queue; + Queue *lltask_queue; +} THModel; + +typedef struct THInferRequest { + torch::Tensor *output; + torch::Tensor *input_tensor; +} THInferRequest; + +typedef struct THRequestItem { + THInferRequest *infer_request; + LastLevelTaskItem *lltask; + DNNAsyncExecModule exec_module; +} THRequestItem; + + +#define OFFSET(x) offsetof(THContext, x) +#define FLAGS AV_OPT_FLAG_FILTERING_PARAM +static const AVOption dnn_th_options[] = { + { "device", "device to run model", OFFSET(options.device_name), AV_OPT_TYPE_STRING, { .str = "cpu" }, 0, 0, FLAGS }, + { NULL } +}; + +AVFILTER_DEFINE_CLASS(dnn_th); + +static int execute_model_th(THRequestItem *request, Queue *lltask_queue); +static int th_start_inference(void *args); +static void infer_completion_callback(void *args); + +static int extract_lltask_from_task(TaskItem *task, Queue *lltask_queue) +{ + THModel *th_model = (THModel *)task->model; + 
THContext *ctx = &th_model->ctx; + LastLevelTaskItem *lltask = (LastLevelTaskItem *)av_malloc(sizeof(*lltask)); + if (!lltask) { + av_log(ctx, AV_LOG_ERROR, "Failed to allocate memory for LastLevelTaskItem\n"); + return AVERROR(ENOMEM); + } + task->inference_todo = 1; + task->inference_done = 0; + lltask->task = task; + if (ff_queue_push_back(lltask_queue, lltask) < 0) { + av_log(ctx, AV_LOG_ERROR, "Failed to push back lltask_queue.\n"); + av_freep(&lltask); + return AVERROR(ENOMEM); + } + return 0; +} + +static int get_input_th(void *model, DNNData *input, const char *input_name) +{ + input->dt = DNN_FLOAT; + input->order = DCO_RGB_PLANAR; + input->height = -1; + input->width = -1; + input->channels = 3; + return 0; +} + +static int get_output_th(void *model, const char *input_name, int input_width, int input_height, + const char *output_name, int *output_width, int *output_height) +{ + int ret = 0; + THModel *th_model = (THModel*) model; + THContext *ctx = &th_model->ctx; + TaskItem task; + THRequestItem *request; + DNNExecBaseParams exec_params = { + .input_name = input_name, + .output_names = &output_name, + .nb_output = 1, + .in_frame = NULL, + .out_frame = NULL, + }; + ret = ff_dnn_fill_gettingoutput_task(&task, &exec_params, th_model, input_height, input_width, ctx); + if ( ret != 0) { + goto err; + } + + ret = extract_lltask_from_task(&task, th_model->lltask_queue); + if ( ret != 0) { + av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n"); + goto err; + } + + request = (THRequestItem*) ff_safe_queue_pop_front(th_model->request_queue); + if (!request) { + av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); + ret = AVERROR(EINVAL); + goto err; + } + + ret = execute_model_th(request, th_model->lltask_queue); + *output_width = task.out_frame->width; + *output_height = task.out_frame->height; + +err: + av_frame_free(&task.out_frame); + av_frame_free(&task.in_frame); + return ret; +} + +static void th_free_request(THInferRequest 
*request) +{ + if (!request) + return; + if (request->output) { + delete(request->output); + request->output = NULL; + } + if (request->input_tensor) { + delete(request->input_tensor); + request->input_tensor = NULL; + } + return; +} + +static inline void destroy_request_item(THRequestItem **arg) +{ + THRequestItem *item; + if (!arg || !*arg) { + return; + } + item = *arg; + th_free_request(item->infer_request); + av_freep(&item->infer_request); + av_freep(&item->lltask); + ff_dnn_async_module_cleanup(&item->exec_module); + av_freep(arg); +} + +static THInferRequest *th_create_inference_request(void) +{ + THInferRequest *request = (THInferRequest *)av_malloc(sizeof(THInferRequest)); + if (!request) { + return NULL; + } + request->input_tensor = NULL; + request->output = NULL; + return request; +} + +DNNModel *ff_dnn_load_model_th(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx) +{ + DNNModel *model = NULL; + THModel *th_model = NULL; + THRequestItem *item = NULL; + THContext *ctx; + + model = (DNNModel *)av_mallocz(sizeof(DNNModel)); + if (!model) { + return NULL; + } + + th_model = (THModel *)av_mallocz(sizeof(THModel)); + if (!th_model) { + av_freep(&model); + return NULL; + } + + th_model->ctx.c_class = &dnn_th_class; + ctx = &th_model->ctx; + //parse options + av_opt_set_defaults(ctx); + if (av_opt_set_from_string(ctx, options, NULL, "=", "&") < 0) { + av_log(ctx, AV_LOG_ERROR, "Failed to parse options \"%s\"\n", options); + return NULL; + } + + c10::Device device = c10::Device(ctx->options.device_name); + if (device.is_cpu()) { + ctx->options.device_type = torch::kCPU; + } else { + av_log(ctx, AV_LOG_ERROR, "Not supported device:\"%s\"\n", ctx->options.device_name); + goto fail; + } + + try { + th_model->jit_model = torch::jit::load(model_filename, device); + } catch (const c10::Error& e) { + av_log(ctx, AV_LOG_ERROR, "Failed to load torch model\n"); + goto fail; + } + + th_model->request_queue = 
ff_safe_queue_create(); + if (!th_model->request_queue) { + goto fail; + } + + item = (THRequestItem *)av_mallocz(sizeof(THRequestItem)); + if (!item) { + goto fail; + } + item->lltask = NULL; + item->infer_request = th_create_inference_request(); + if (!item->infer_request) { + av_log(NULL, AV_LOG_ERROR, "Failed to allocate memory for Torch inference request\n"); + goto fail; + } + item->exec_module.start_inference = &th_start_inference; + item->exec_module.callback = &infer_completion_callback; + item->exec_module.args = item; + + if (ff_safe_queue_push_back(th_model->request_queue, item) < 0) { + goto fail; + } + + th_model->task_queue = ff_queue_create(); + if (!th_model->task_queue) { + goto fail; + } + + th_model->lltask_queue = ff_queue_create(); + if (!th_model->lltask_queue) { + goto fail; + } + + th_model->model = model; + model->model = th_model; + model->get_input = &get_input_th; + model->get_output = &get_output_th; + model->options = NULL; + model->filter_ctx = filter_ctx; + model->func_type = func_type; + return model; + +fail: + destroy_request_item(&item); + ff_queue_destroy(th_model->task_queue); + ff_queue_destroy(th_model->lltask_queue); + ff_safe_queue_destroy(th_model->request_queue); + av_freep(&th_model); + av_freep(&model); + av_freep(&item); + return NULL; +} + +static int fill_model_input_th(THModel *th_model, THRequestItem *request) +{ + LastLevelTaskItem *lltask = NULL; + TaskItem *task = NULL; + THInferRequest *infer_request = NULL; + DNNData input; + THContext *ctx = &th_model->ctx; + int ret; + + lltask = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue); + if (!lltask) { + ret = AVERROR(EINVAL); + goto err; + } + request->lltask = lltask; + task = lltask->task; + infer_request = request->infer_request; + + ret = get_input_th(th_model, &input, NULL); + if ( ret != 0) { + goto err; + } + + input.height = task->in_frame->height; + input.width = task->in_frame->width; + input.data = malloc(input.height * input.width * 3 * 
sizeof(float)); + if (!input.data) + return AVERROR(ENOMEM); + infer_request->input_tensor = new torch::Tensor(); + infer_request->output = new torch::Tensor(); + + switch (th_model->model->func_type) { + case DFT_PROCESS_FRAME: + if (task->do_ioproc) { + if (th_model->model->frame_pre_proc != NULL) { + th_model->model->frame_pre_proc(task->in_frame, &input, th_model->model->filter_ctx); + } else { + ff_proc_from_frame_to_dnn(task->in_frame, &input, ctx); + } + } + break; + default: + avpriv_report_missing_feature(NULL, "model function type %d", th_model->model->func_type); + break; + } + *infer_request->input_tensor = torch::from_blob(input.data, {1, 1, 3, input.height, input.width}, + torch::kFloat32); + return 0; + +err: + th_free_request(infer_request); + return ret; +} + +static int th_start_inference(void *args) +{ + THRequestItem *request = (THRequestItem *)args; + THInferRequest *infer_request = NULL; + LastLevelTaskItem *lltask = NULL; + TaskItem *task = NULL; + THModel *th_model = NULL; + THContext *ctx = NULL; + std::vector<torch::jit::IValue> inputs; + + if (!request) { + av_log(NULL, AV_LOG_ERROR, "THRequestItem is NULL\n"); + return AVERROR(EINVAL); + } + infer_request = request->infer_request; + lltask = request->lltask; + task = lltask->task; + th_model = (THModel *)task->model; + ctx = &th_model->ctx; + + if (!infer_request->input_tensor || !infer_request->output) { + av_log(ctx, AV_LOG_ERROR, "input or output tensor is NULL\n"); + return DNN_GENERIC_ERROR; + } + inputs.push_back(*infer_request->input_tensor); + + auto parameters = th_model->jit_model.parameters(); + auto para = *(parameters.begin()); + + *infer_request->output = th_model->jit_model.forward(inputs).toTensor(); + + return 0; +} + +static void infer_completion_callback(void *args) { + THRequestItem *request = (THRequestItem*)args; + LastLevelTaskItem *lltask = request->lltask; + TaskItem *task = lltask->task; + DNNData outputs; + THInferRequest *infer_request = request->infer_request; + THModel 
*th_model = (THModel *)task->model; + torch::Tensor *output = infer_request->output; + + c10::IntArrayRef sizes = output->sizes(); + assert(sizes.size() == 5); + outputs.order = DCO_RGB_PLANAR; + outputs.height = sizes.at(3); + outputs.width = sizes.at(4); + outputs.dt = DNN_FLOAT; + outputs.channels = 3; + + switch (th_model->model->func_type) { + case DFT_PROCESS_FRAME: + if (task->do_ioproc) { + outputs.data = output->data_ptr(); + if (th_model->model->frame_post_proc != NULL) { + th_model->model->frame_post_proc(task->out_frame, &outputs, th_model->model->filter_ctx); + } else { + ff_proc_from_dnn_to_frame(task->out_frame, &outputs, &th_model->ctx); + } + } else { + task->out_frame->width = outputs.width; + task->out_frame->height = outputs.height; + } + break; + default: + avpriv_report_missing_feature(&th_model->ctx, "model function type %d", th_model->model->func_type); + goto err; + } + task->inference_done++; +err: + th_free_request(infer_request); + + if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) { + destroy_request_item(&request); + av_log(&th_model->ctx, AV_LOG_ERROR, "Unable to push back request_queue when failed to start inference.\n"); + } +} + +static int execute_model_th(THRequestItem *request, Queue *lltask_queue) +{ + THModel *th_model = NULL; + LastLevelTaskItem *lltask; + TaskItem *task = NULL; + int ret = 0; + + if (ff_queue_size(lltask_queue) == 0) { + destroy_request_item(&request); + return 0; + } + + lltask = (LastLevelTaskItem *)ff_queue_peek_front(lltask_queue); + if (lltask == NULL) { + av_log(NULL, AV_LOG_ERROR, "Failed to get LastLevelTaskItem\n"); + ret = AVERROR(EINVAL); + goto err; + } + task = lltask->task; + th_model = (THModel *)task->model; + + ret = fill_model_input_th(th_model, request); + if ( ret != 0) { + goto err; + } + if (task->async) { + avpriv_report_missing_feature(&th_model->ctx, "LibTorch async"); + } else { + ret = th_start_inference((void *)(request)); + if (ret != 0) { + goto err; + } + 
infer_completion_callback(request); + return (task->inference_done == task->inference_todo) ? 0 : DNN_GENERIC_ERROR; + } + +err: + th_free_request(request->infer_request); + if (ff_safe_queue_push_back(th_model->request_queue, request) < 0) { + destroy_request_item(&request); + } + return ret; +} + +int ff_dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params) +{ + THModel *th_model = (THModel *)model->model; + THContext *ctx = &th_model->ctx; + TaskItem *task; + THRequestItem *request; + int ret = 0; + + ret = ff_check_exec_params(ctx, DNN_TH, model->func_type, exec_params); + if (ret != 0) { + return ret; + } + + task = (TaskItem *)av_malloc(sizeof(TaskItem)); + if (!task) { + av_log(ctx, AV_LOG_ERROR, "unable to alloc memory for task item.\n"); + return AVERROR(ENOMEM); + } + + ret = ff_dnn_fill_task(task, exec_params, th_model, 0, 1); + if (ret != 0) { + av_freep(&task); + av_log(ctx, AV_LOG_ERROR, "unable to fill task.\n"); + return ret; + } + + ret = ff_queue_push_back(th_model->task_queue, task); + if (ret < 0) { + av_freep(&task); + av_log(ctx, AV_LOG_ERROR, "unable to push back task_queue.\n"); + return ret; + } + + ret = extract_lltask_from_task(task, th_model->lltask_queue); + if (ret != 0) { + av_log(ctx, AV_LOG_ERROR, "unable to extract last level task from task.\n"); + return ret; + } + + request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue); + if (!request) { + av_log(ctx, AV_LOG_ERROR, "unable to get infer request.\n"); + return AVERROR(EINVAL); + } + + return execute_model_th(request, th_model->lltask_queue); +} + + +int ff_dnn_flush_th(const DNNModel *model) +{ + THModel *th_model = (THModel *)model->model; + THRequestItem *request; + + if (ff_queue_size(th_model->lltask_queue) == 0) { + // no pending task need to flush + return 0; + } + request = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue); + if (!request) { + av_log(&th_model->ctx, AV_LOG_ERROR, "unable to get infer request.\n"); + 
return AVERROR(EINVAL); + } + + return execute_model_th(request, th_model->lltask_queue); +} + +DNNAsyncStatusType ff_dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out) +{ + THModel *th_model = (THModel *)model->model; + return ff_dnn_get_result_common(th_model->task_queue, in, out); +} + +void ff_dnn_free_model_th(DNNModel **model) +{ + THModel *th_model; + if(*model) { + th_model = (THModel *) (*model)->model; + while (ff_safe_queue_size(th_model->request_queue) != 0) { + THRequestItem *item = (THRequestItem *)ff_safe_queue_pop_front(th_model->request_queue); + destroy_request_item(&item); + } + ff_safe_queue_destroy(th_model->request_queue); + + while (ff_queue_size(th_model->lltask_queue) != 0) { + LastLevelTaskItem *item = (LastLevelTaskItem *)ff_queue_pop_front(th_model->lltask_queue); + av_freep(&item); + } + ff_queue_destroy(th_model->lltask_queue); + + while (ff_queue_size(th_model->task_queue) != 0) { + TaskItem *item = (TaskItem *)ff_queue_pop_front(th_model->task_queue); + av_frame_free(&item->in_frame); + av_frame_free(&item->out_frame); + av_freep(&item); + } + } + av_freep(&th_model); + av_freep(model); +} diff --git a/libavfilter/dnn/dnn_backend_torch.h b/libavfilter/dnn/dnn_backend_torch.h new file mode 100644 index 0000000000..5d6a08f85f --- /dev/null +++ b/libavfilter/dnn/dnn_backend_torch.h @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022 + * + * This file is part of FFmpeg. + * + * FFmpeg is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * FFmpeg is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. 
+ *
+ * You should have received a copy of the GNU Lesser General Public
+ * License along with FFmpeg; if not, write to the Free Software
+ * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+ */
+
+/**
+ * @file
+ * DNN inference functions interface for Torch backend.
+ */
+
+#ifndef AVFILTER_DNN_DNN_BACKEND_TORCH_H
+#define AVFILTER_DNN_DNN_BACKEND_TORCH_H
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+#include "../dnn_interface.h"
+
+DNNModel *ff_dnn_load_model_th(const char *model_filename, DNNFunctionType func_type, const char *options, AVFilterContext *filter_ctx);
+
+int ff_dnn_execute_model_th(const DNNModel *model, DNNExecBaseParams *exec_params);
+DNNAsyncStatusType ff_dnn_get_result_th(const DNNModel *model, AVFrame **in, AVFrame **out);
+int ff_dnn_flush_th(const DNNModel *model);
+
+void ff_dnn_free_model_th(DNNModel **model);
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/libavfilter/dnn/dnn_interface.c b/libavfilter/dnn/dnn_interface.c
index 554a36b0dc..6f4e02b481 100644
--- a/libavfilter/dnn/dnn_interface.c
+++ b/libavfilter/dnn/dnn_interface.c
@@ -27,6 +27,7 @@
 #include "dnn_backend_native.h"
 #include "dnn_backend_tf.h"
 #include "dnn_backend_openvino.h"
+#include "dnn_backend_torch.h"
 #include "libavutil/mem.h"
 
 DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
@@ -70,6 +71,17 @@ DNNModule *ff_get_dnn_module(DNNBackendType backend_type)
         return NULL;
     #endif
         break;
+    case DNN_TH:
+    #if (CONFIG_LIBTORCH == 1)
+        dnn_module->load_model = &ff_dnn_load_model_th;
+        dnn_module->execute_model = &ff_dnn_execute_model_th;
+        dnn_module->get_result = &ff_dnn_get_result_th;
+        dnn_module->flush = &ff_dnn_flush_th;
+        dnn_module->free_model = &ff_dnn_free_model_th;
+    #else
+        av_freep(&dnn_module);
+    #endif
+        break;
     default:
         av_log(NULL, AV_LOG_ERROR, "Module backend_type is not native or tensorflow\n");
         av_freep(&dnn_module);
diff --git a/libavfilter/dnn/dnn_io_proc.c b/libavfilter/dnn/dnn_io_proc.c
index 532b089002..cbaa1e601f 100644
--- a/libavfilter/dnn/dnn_io_proc.c
+++ b/libavfilter/dnn/dnn_io_proc.c
@@ -24,10 +24,20 @@
 #include "libavutil/avassert.h"
 #include "libavutil/detection_bbox.h"
 
+static enum AVPixelFormat get_pixel_format(DNNData *data);
+
 int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
 {
     struct SwsContext *sws_ctx;
+    int frame_size = frame->height * frame->width;
+    int linesize[3];
+    void **dst_data, *middle_data = NULL;
+    enum AVPixelFormat fmt;
     int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+    linesize[0] = frame->linesize[0];
+    dst_data = (void **)frame->data;
+    fmt = get_pixel_format(output);
+
     if (bytewidth < 0) {
         return AVERROR(EINVAL);
     }
@@ -35,6 +45,18 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
         avpriv_report_missing_feature(log_ctx, "data type rather than DNN_FLOAT");
         return AVERROR(ENOSYS);
     }
+    if (fmt == AV_PIX_FMT_GBRP) {
+        middle_data = av_malloc(frame_size * 3 * sizeof(uint8_t));
+        if (!middle_data) {
+            av_log(log_ctx, AV_LOG_ERROR, "Failed to allocate memory for middle_data for "
+                   "the conversion fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+                   av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32), frame->width, frame->height,
+                   av_get_pix_fmt_name(AV_PIX_FMT_GBRP), frame->width, frame->height);
+            return AVERROR(ENOMEM);
+        }
+        dst_data = &middle_data;
+        linesize[0] = frame->width * 3;
+    }
 
     switch (frame->format) {
     case AV_PIX_FMT_RGB24:
@@ -51,12 +73,44 @@ int ff_proc_from_dnn_to_frame(AVFrame *frame, DNNData *output, void *log_ctx)
                    "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
                    av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32), frame->width * 3, frame->height,
                    av_get_pix_fmt_name(AV_PIX_FMT_GRAY8), frame->width * 3, frame->height);
+            av_freep(&middle_data);
             return AVERROR(EINVAL);
         }
         sws_scale(sws_ctx, (const uint8_t *[4]){(const uint8_t *)output->data, 0, 0, 0},
                   (const int[4]){frame->width * 3 * sizeof(float), 0, 0, 0}, 0, frame->height,
-                  (uint8_t * const*)frame->data, frame->linesize);
+                  (uint8_t * const*)dst_data, linesize);
         sws_freeContext(sws_ctx);
+        switch (fmt) {
+        case AV_PIX_FMT_GBRP:
+            sws_ctx = sws_getContext(frame->width,
+                                     frame->height,
+                                     AV_PIX_FMT_GBRP,
+                                     frame->width,
+                                     frame->height,
+                                     frame->format,
+                                     0, NULL, NULL, NULL);
+            if (!sws_ctx) {
+                av_log(log_ctx, AV_LOG_ERROR, "Impossible to create scale context for the conversion "
+                       "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+                       av_get_pix_fmt_name(AV_PIX_FMT_GBRP), frame->width, frame->height,
+                       av_get_pix_fmt_name(frame->format), frame->width, frame->height);
+                av_freep(&middle_data);
+                return AVERROR(EINVAL);
+            }
+            sws_scale(sws_ctx, (const uint8_t * const[4]){(uint8_t *)dst_data[0] + frame_size * sizeof(uint8_t),
+                                                          (uint8_t *)dst_data[0] + frame_size * sizeof(uint8_t) * 2,
+                                                          (uint8_t *)dst_data[0], 0},
+                      (const int[4]){frame->width * sizeof(uint8_t),
+                                     frame->width * sizeof(uint8_t),
+                                     frame->width * sizeof(uint8_t), 0},
+                      0, frame->height,
+                      (uint8_t * const*)frame->data, frame->linesize);
+            sws_freeContext(sws_ctx);
+            break;
+        default:
+            break;
+        }
+        av_freep(&middle_data);
         return 0;
     case AV_PIX_FMT_GRAYF32:
         av_image_copy_plane(frame->data[0], frame->linesize[0],
@@ -101,6 +154,14 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
 {
     struct SwsContext *sws_ctx;
     int bytewidth = av_image_get_linesize(frame->format, frame->width, 0);
+    int frame_size = frame->height * frame->width;
+    int linesize[3];
+    void **src_data, *middle_data = NULL;
+    enum AVPixelFormat fmt;
+    linesize[0] = frame->linesize[0];
+    src_data = (void **)frame->data;
+    fmt = get_pixel_format(input);
+
     if (bytewidth < 0) {
         return AVERROR(EINVAL);
     }
@@ -112,6 +173,46 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
     switch (frame->format) {
     case AV_PIX_FMT_RGB24:
     case AV_PIX_FMT_BGR24:
+        switch (fmt) {
+        case AV_PIX_FMT_GBRP:
+            middle_data = av_malloc(frame_size * 3 * sizeof(uint8_t));
+            if (!middle_data) {
+                av_log(log_ctx, AV_LOG_ERROR, "Failed to allocate memory for middle_data for "
+                       "the conversion fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+                       av_get_pix_fmt_name(frame->format), frame->width, frame->height,
+                       av_get_pix_fmt_name(AV_PIX_FMT_GBRP), frame->width, frame->height);
+                return AVERROR(ENOMEM);
+            }
+            sws_ctx = sws_getContext(frame->width,
+                                     frame->height,
+                                     frame->format,
+                                     frame->width,
+                                     frame->height,
+                                     AV_PIX_FMT_GBRP,
+                                     0, NULL, NULL, NULL);
+            if (!sws_ctx) {
+                av_log(log_ctx, AV_LOG_ERROR, "Impossible to create scale context for the conversion "
+                       "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
+                       av_get_pix_fmt_name(frame->format), frame->width, frame->height,
+                       av_get_pix_fmt_name(AV_PIX_FMT_GBRP), frame->width, frame->height);
+                av_freep(&middle_data);
+                return AVERROR(EINVAL);
+            }
+            sws_scale(sws_ctx, (const uint8_t **)frame->data,
+                      frame->linesize, 0, frame->height,
+                      (uint8_t * const[4]){(uint8_t *)middle_data + frame_size * sizeof(uint8_t),
+                                           (uint8_t *)middle_data + frame_size * sizeof(uint8_t) * 2,
+                                           (uint8_t *)middle_data, 0},
+                      (const int[4]){frame->width * sizeof(uint8_t),
+                                     frame->width * sizeof(uint8_t),
+                                     frame->width * sizeof(uint8_t), 0});
+            sws_freeContext(sws_ctx);
+            src_data = &middle_data;
+            linesize[0] = frame->width * 3;
+            break;
+        default:
+            break;
+        }
         sws_ctx = sws_getContext(frame->width * 3,
                                  frame->height,
                                  AV_PIX_FMT_GRAY8,
@@ -124,13 +225,15 @@ int ff_proc_from_frame_to_dnn(AVFrame *frame, DNNData *input, void *log_ctx)
                    "fmt:%s s:%dx%d -> fmt:%s s:%dx%d\n",
                    av_get_pix_fmt_name(AV_PIX_FMT_GRAY8), frame->width * 3, frame->height,
                    av_get_pix_fmt_name(AV_PIX_FMT_GRAYF32), frame->width * 3, frame->height);
+            av_freep(&middle_data);
             return AVERROR(EINVAL);
         }
-        sws_scale(sws_ctx, (const uint8_t **)frame->data,
-                  frame->linesize, 0, frame->height,
+        sws_scale(sws_ctx, (const uint8_t **)src_data,
+                  linesize, 0, frame->height,
                   (uint8_t * const [4]){input->data, 0, 0, 0},
                   (const int [4]){frame->width * 3 * sizeof(float), 0, 0, 0});
         sws_freeContext(sws_ctx);
+        av_freep(&middle_data);
         break;
     case AV_PIX_FMT_GRAYF32:
         av_image_copy_plane(input->data, bytewidth,
@@ -184,6 +287,14 @@ static enum AVPixelFormat get_pixel_format(DNNData *data)
             av_assert0(!"unsupported data pixel format.\n");
             return AV_PIX_FMT_BGR24;
         }
+    } else if (data->dt == DNN_FLOAT) {
+        switch (data->order) {
+        case DCO_RGB_PLANAR:
+            return AV_PIX_FMT_GBRP;
+        default:
+            av_assert0(!"unsupported data pixel format.\n");
+            return AV_PIX_FMT_GBRP;
+        }
     }
 
     av_assert0(!"unsupported data type.\n");
diff --git a/libavfilter/dnn_filter_common.c b/libavfilter/dnn_filter_common.c
index 5083e3de19..a4e1147fb9 100644
--- a/libavfilter/dnn_filter_common.c
+++ b/libavfilter/dnn_filter_common.c
@@ -53,19 +53,31 @@ static char **separate_output_names(const char *expr, const char *val_sep, int *
 
 int ff_dnn_init(DnnContext *ctx, DNNFunctionType func_type, AVFilterContext *filter_ctx)
 {
+    DNNBackendType backend = ctx->backend_type;
+
     if (!ctx->model_filename) {
         av_log(filter_ctx, AV_LOG_ERROR, "model file for network is not specified\n");
         return AVERROR(EINVAL);
     }
-    if (!ctx->model_inputname) {
-        av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
-        return AVERROR(EINVAL);
-    }
-    ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
-    if (!ctx->model_outputnames) {
-        av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
-        return AVERROR(EINVAL);
+    if (backend == DNN_TH) {
+        if (ctx->model_inputname)
+            av_log(filter_ctx, AV_LOG_WARNING, "The LibTorch backend does not require an input name, "
+                   "so it will be ignored.\n");
+        if (ctx->model_outputnames)
+            av_log(filter_ctx, AV_LOG_WARNING, "The LibTorch backend does not require output name(s), "
+                   "so they will be ignored.\n");
+        ctx->nb_outputs = 1;
+    } else {
+        if (!ctx->model_inputname) {
+            av_log(filter_ctx, AV_LOG_ERROR, "input name of the model network is not specified\n");
+            return AVERROR(EINVAL);
+        }
+        ctx->model_outputnames = separate_output_names(ctx->model_outputnames_string, "&", &ctx->nb_outputs);
+        if (!ctx->model_outputnames) {
+            av_log(filter_ctx, AV_LOG_ERROR, "could not parse model output names\n");
+            return AVERROR(EINVAL);
+        }
     }
 
     ctx->dnn_module = ff_get_dnn_module(ctx->backend_type);
@@ -113,8 +125,9 @@ int ff_dnn_get_input(DnnContext *ctx, DNNData *input)
 
 int ff_dnn_get_output(DnnContext *ctx, int input_width, int input_height, int *output_width, int *output_height)
 {
+    const char *model_outputname = ctx->backend_type == DNN_TH ? NULL : ctx->model_outputnames[0];
     return ctx->model->get_output(ctx->model->model, ctx->model_inputname, input_width, input_height,
-                                  (const char *)ctx->model_outputnames[0], output_width, output_height);
+                                  model_outputname, output_width, output_height);
 }
 
 int ff_dnn_execute_model(DnnContext *ctx, AVFrame *in_frame, AVFrame *out_frame)
diff --git a/libavfilter/dnn_interface.h b/libavfilter/dnn_interface.h
index d94baa90c4..32698f788b 100644
--- a/libavfilter/dnn_interface.h
+++ b/libavfilter/dnn_interface.h
@@ -32,7 +32,7 @@
 
 #define DNN_GENERIC_ERROR FFERRTAG('D','N','N','!')
 
-typedef enum {DNN_NATIVE, DNN_TF, DNN_OV} DNNBackendType;
+typedef enum {DNN_NATIVE, DNN_TF, DNN_OV, DNN_TH} DNNBackendType;
 
 typedef enum {DNN_FLOAT = 1, DNN_UINT8 = 4} DNNDataType;
 
@@ -40,6 +40,7 @@ typedef enum {
     DCO_NONE,
     DCO_BGR_PACKED,
     DCO_RGB_PACKED,
+    DCO_RGB_PLANAR,
 } DNNColorOrder;
 
 typedef enum {
diff --git a/libavfilter/vf_dnn_processing.c b/libavfilter/vf_dnn_processing.c
index cac096a19f..ac1dc6e1d9 100644
--- a/libavfilter/vf_dnn_processing.c
+++ b/libavfilter/vf_dnn_processing.c
@@ -52,6 +52,9 @@ static const AVOption dnn_processing_options[] = {
 #endif
 #if (CONFIG_LIBOPENVINO == 1)
     { "openvino", "openvino backend flag", 0, AV_OPT_TYPE_CONST, { .i64 = 2 }, 0, 0, FLAGS, "backend" },
+#endif
+#if (CONFIG_LIBTORCH == 1)
+    { "torch",    "torch backend flag",    0, AV_OPT_TYPE_CONST, { .i64 = 3 }, 0, 0, FLAGS, "backend" },
 #endif
     DNN_COMMON_OPTIONS
     { NULL }