Message ID | 20240912051619.133812-4-mozil.petryk@gmail.com |
---|---|
State | New |
Headers | show |
Series | [FFmpeg-devel,1/2] lavu: Move vulkan_spirv to libavutil | expand |
Context | Check | Description |
---|---|---|
yinshiyou/make_loongarch64 | success | Make finished |
yinshiyou/make_fate_loongarch64 | success | Make fate finished |
andriy/make_x86 | success | Make finished |
andriy/make_fate_x86 | success | Make fate finished |
On 12 Sep 2024, at 7:16, Petro Mozil wrote: > This patch contains the code for the VC2 vulkan hwaccel, > as well as changes to configure and makefiles needed to compile them. > Hi, I have absolutely no idea about Dirac, I was just looking through for some common mistakes, so this is not a review on any of the actual maths/shaders. Remarks follow inline: > Signed-off-by: Petro Mozil <mozil.petryk@gmail.com> > --- > configure | 2 + > libavcodec/Makefile | 1 + > libavcodec/diracdec.c | 336 +--- > libavcodec/diracdec.h | 263 +++ > libavcodec/hwaccels.h | 1 + > libavcodec/vulkan_dirac.c | 3817 +++++++++++++++++++++++++++++++++++++ > 6 files changed, 4172 insertions(+), 248 deletions(-) > create mode 100644 libavcodec/diracdec.h > create mode 100644 libavcodec/vulkan_dirac.c > > diff --git a/configure b/configure > index d3bd46f382..fd7e4ab6d8 100755 > --- a/configure > +++ b/configure > @@ -3172,6 +3172,8 @@ av1_vdpau_hwaccel_deps="vdpau VdpPictureInfoAV1" > av1_vdpau_hwaccel_select="av1_decoder" > av1_vulkan_hwaccel_deps="vulkan" > av1_vulkan_hwaccel_select="av1_decoder" > +dirac_vulkan_hwaccel_deps="vulkan spirv_compiler" > +dirac_vulkan_hwaccel_select="dirac_decoder" > h263_vaapi_hwaccel_deps="vaapi" > h263_vaapi_hwaccel_select="h263_decoder" > h263_videotoolbox_hwaccel_deps="videotoolbox" > diff --git a/libavcodec/Makefile b/libavcodec/Makefile > index b6243bbc82..90548ea2d5 100644 > --- a/libavcodec/Makefile > +++ b/libavcodec/Makefile > @@ -1006,6 +1006,7 @@ OBJS-$(CONFIG_AV1_NVDEC_HWACCEL) += nvdec_av1.o > OBJS-$(CONFIG_AV1_VAAPI_HWACCEL) += vaapi_av1.o > OBJS-$(CONFIG_AV1_VDPAU_HWACCEL) += vdpau_av1.o > OBJS-$(CONFIG_AV1_VULKAN_HWACCEL) += vulkan_decode.o vulkan_av1.o > +OBJS-$(CONFIG_DIRAC_VULKAN_HWACCEL) += vulkan_dirac.o > OBJS-$(CONFIG_H263_VAAPI_HWACCEL) += vaapi_mpeg4.o > OBJS-$(CONFIG_H263_VIDEOTOOLBOX_HWACCEL) += videotoolbox.o > OBJS-$(CONFIG_H264_D3D11VA_HWACCEL) += dxva2_h264.o > diff --git a/libavcodec/diracdec.c b/libavcodec/diracdec.c > index 76209aebba..542824f6e1 100644 > --- a/libavcodec/diracdec.c > +++ b/libavcodec/diracdec.c > @@ -26,228 +26,11 @@ > * @author Marco Gerards <marco@gnu.org>, David Conrad, Jordi Ortiz <nenjordi@gmail.com> > */ > > -#include "libavutil/mem.h" > -#include "libavutil/mem_internal.h" > -#include "libavutil/pixdesc.h" > -#include "libavutil/thread.h" > -#include "avcodec.h" > -#include "get_bits.h" > -#include "codec_internal.h" > -#include "decode.h" > -#include "golomb.h" > -#include "dirac_arith.h" > -#include "dirac_vlc.h" > -#include "mpegvideoencdsp.h" > -#include "dirac_dwt.h" > -#include "dirac.h" > -#include "diractab.h" > -#include "diracdsp.h" > -#include "videodsp.h" > - > -#define EDGE_WIDTH 16 > - > -/** > - * The spec limits this to 3 for frame coding, but in practice can be as high as 6 > - */ > -#define MAX_REFERENCE_FRAMES 8 > -#define MAX_DELAY 5 /* limit for main profile for frame coding (TODO: field coding) */ > -#define MAX_FRAMES (MAX_REFERENCE_FRAMES + MAX_DELAY + 1) > -#define MAX_QUANT 255 /* max quant for VC-2 */ > -#define MAX_BLOCKSIZE 32 /* maximum xblen/yblen we support */ > - > -/** > - * DiracBlock->ref flags, if set then the block does MC from the given ref > - */ > -#define DIRAC_REF_MASK_REF1 1 > -#define DIRAC_REF_MASK_REF2 2 > -#define DIRAC_REF_MASK_GLOBAL 4 > - > -/** > - * Value of Picture.reference when Picture is not a reference picture, but > - * is held for delayed output. > - */ > -#define DELAYED_PIC_REF 4 > - > -#define CALC_PADDING(size, depth) \ > - (((size + (1 << depth) - 1) >> depth) << depth) > - > -#define DIVRNDUP(a, b) (((a) + (b) - 1) / (b)) > - > -typedef struct { > - AVFrame *avframe; > - int interpolated[3]; /* 1 if hpel[] is valid */ > - uint8_t *hpel[3][4]; > - uint8_t *hpel_base[3][4]; > - int reference; > - unsigned picture_number; > -} DiracFrame; > - > -typedef struct { > - union { > - int16_t mv[2][2]; > - int16_t dc[3]; > - } u; /* anonymous unions aren't in C99 :( */ > - uint8_t ref; > -} DiracBlock; > - > -typedef struct SubBand { > - int level; > - int orientation; > - int stride; /* in bytes */ > - int width; > - int height; > - int pshift; > - int quant; > - uint8_t *ibuf; > - struct SubBand *parent; > - > - /* for low delay */ > - unsigned length; > - const uint8_t *coeff_data; > -} SubBand; > - > -typedef struct Plane { > - DWTPlane idwt; > - > - int width; > - int height; > - ptrdiff_t stride; > - > - /* block length */ > - uint8_t xblen; > - uint8_t yblen; > - /* block separation (block n+1 starts after this many pixels in block n) */ > - uint8_t xbsep; > - uint8_t ybsep; > - /* amount of overspill on each edge (half of the overlap between blocks) */ > - uint8_t xoffset; > - uint8_t yoffset; > - > - SubBand band[MAX_DWT_LEVELS][4]; > -} Plane; > - > -/* Used by Low Delay and High Quality profiles */ > -typedef struct DiracSlice { > - GetBitContext gb; > - int slice_x; > - int slice_y; > - int bytes; > -} DiracSlice; > - > -typedef struct DiracContext { > - AVCodecContext *avctx; > - MpegvideoEncDSPContext mpvencdsp; > - VideoDSPContext vdsp; > - DiracDSPContext diracdsp; > - DiracVersionInfo version; > - GetBitContext gb; > - AVDiracSeqHeader seq; > - int seen_sequence_header; > - int64_t frame_number; /* number of the next frame to display */ > - Plane plane[3]; > - int chroma_x_shift; > - int chroma_y_shift; > - > - int bit_depth; /* bit depth */ > - int pshift; /* pixel shift = bit_depth > 8 */ > - > - int zero_res; /* zero residue flag */ > - int is_arith; /* whether coeffs use arith or golomb coding */ > - int core_syntax; /* use core syntax only */ > - int low_delay; /* use the low delay syntax */ > - int hq_picture; /* high quality picture, enables low_delay */ > - int ld_picture; /* use low delay picture, turns on low_delay */ > - int dc_prediction; /* has dc prediction */ > - int globalmc_flag; /* use global motion compensation */ > - int num_refs; /* number of reference pictures */ > - > - /* wavelet decoding */ > - unsigned wavelet_depth; /* depth of the IDWT */ > - unsigned wavelet_idx; > - > - /** > - * schroedinger older than 1.0.8 doesn't store > - * quant delta if only one codebook exists in a band > - */ > - unsigned old_delta_quant; > - unsigned codeblock_mode; > - > - unsigned num_x; /* number of horizontal slices */ > - unsigned num_y; /* number of vertical slices */ > - > - uint8_t *thread_buf; /* Per-thread buffer for coefficient storage */ > - int threads_num_buf; /* Current # of buffers allocated */ > - int thread_buf_size; /* Each thread has a buffer this size */ > - > - DiracSlice *slice_params_buf; > - int slice_params_num_buf; > - > - struct { > - unsigned width; > - unsigned height; > - } codeblock[MAX_DWT_LEVELS+1]; > - > - struct { > - AVRational bytes; /* average bytes per slice */ > - uint8_t quant[MAX_DWT_LEVELS][4]; /* [DIRAC_STD] E.1 */ > - } lowdelay; > - > - struct { > - unsigned prefix_bytes; > - uint64_t size_scaler; > - } highquality; > - > - struct { > - int pan_tilt[2]; /* pan/tilt vector */ > - int zrs[2][2]; /* zoom/rotate/shear matrix */ > - int perspective[2]; /* perspective vector */ > - unsigned zrs_exp; > - unsigned perspective_exp; > - } globalmc[2]; > - > - /* motion compensation */ > - uint8_t mv_precision; /* [DIRAC_STD] REFS_WT_PRECISION */ > - int16_t weight[2]; /* [DIRAC_STD] REF1_WT and REF2_WT */ > - unsigned weight_log2denom; /* [DIRAC_STD] REFS_WT_PRECISION */ > - > - int blwidth; /* number of blocks (horizontally) */ > - int blheight; /* number of blocks (vertically) */ > - int sbwidth; /* number of superblocks (horizontally) */ > - int sbheight; /* number of superblocks (vertically) */ > - > - uint8_t *sbsplit; > - DiracBlock *blmotion; > - > - uint8_t *edge_emu_buffer[4]; > - uint8_t *edge_emu_buffer_base; > - > - uint16_t *mctmp; /* buffer holding the MC data multiplied by OBMC weights */ > - uint8_t *mcscratch; > - int buffer_stride; > - > - DECLARE_ALIGNED(16, uint8_t, obmc_weight)[3][MAX_BLOCKSIZE*MAX_BLOCKSIZE]; > - > - void (*put_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); > - void (*avg_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); > - void (*add_obmc)(uint16_t *dst, const uint8_t *src, int stride, const uint8_t *obmc_weight, int yblen); > - dirac_weight_func weight_func; > - dirac_biweight_func biweight_func; > - > - DiracFrame *current_picture; > - DiracFrame *ref_pics[2]; > - > - DiracFrame *ref_frames[MAX_REFERENCE_FRAMES+1]; > - DiracFrame *delay_frames[MAX_DELAY+1]; > - DiracFrame all_frames[MAX_FRAMES]; > -} DiracContext; > - > -enum dirac_subband { > - subband_ll = 0, > - subband_hl = 1, > - subband_lh = 2, > - subband_hh = 3, > - subband_nb, > -}; > +#include "diracdec.h" > +#include "hwaccels.h" > +#include "hwconfig.h" > +#include "libavutil/imgutils.c" > +#include "config_components.h" > > /* magic number division by 3 from schroedinger */ > static inline int divide3(int x) > @@ -351,7 +134,7 @@ static int alloc_buffers(DiracContext *s, int stride) > return 0; > } > > -static av_cold void free_sequence_buffers(DiracContext *s) > +static void free_sequence_buffers(DiracContext *s) > { > int i, j, k; > > @@ -403,8 +186,11 @@ static av_cold int dirac_decode_init(AVCodecContext *avctx) > > for (i = 0; i < MAX_FRAMES; i++) { > s->all_frames[i].avframe = av_frame_alloc(); > - if (!s->all_frames[i].avframe) > + if (!s->all_frames[i].avframe) { > + while (i > 0) > + av_frame_free(&s->all_frames[--i].avframe); > return AVERROR(ENOMEM); > + } > } > ret = ff_thread_once(&dirac_arith_init, ff_dirac_init_arith_tables); > if (ret != 0) > @@ -413,7 +199,7 @@ static av_cold int dirac_decode_init(AVCodecContext *avctx) > return 0; > } > > -static av_cold void dirac_decode_flush(AVCodecContext *avctx) > +static void dirac_decode_flush(AVCodecContext *avctx) > { > DiracContext *s = avctx->priv_data; > free_sequence_buffers(s); > @@ -426,9 +212,7 @@ static av_cold int dirac_decode_end(AVCodecContext *avctx) > DiracContext *s = avctx->priv_data; > int i; > > - // Necessary in case dirac_decode_init() failed > - if (s->all_frames[MAX_FRAMES - 1].avframe) > - free_sequence_buffers(s); > + dirac_decode_flush(avctx); > for (i = 0; i < MAX_FRAMES; i++) > av_frame_free(&s->all_frames[i].avframe); > > @@ -812,14 +596,6 @@ static int decode_lowdelay_slice(AVCodecContext *avctx, void *arg) > return 0; > } > > -typedef struct SliceCoeffs { > - int left; > - int top; > - int tot_h; > - int tot_v; > - int tot; > -} SliceCoeffs; > - > static int subband_coeffs(const DiracContext *s, int x, int y, int p, > SliceCoeffs c[MAX_DWT_LEVELS]) > { > @@ -1006,7 +782,10 @@ static int decode_lowdelay(DiracContext *s) > return AVERROR_INVALIDDATA; > } > > - avctx->execute2(avctx, decode_hq_slice_row, slices, NULL, s->num_y); > + if (avctx->hwaccel) > + FF_HW_CALL(avctx, decode_slice, NULL, 0); > + else > + avctx->execute2(avctx, decode_hq_slice_row, slices, NULL, s->num_y); > } else { > for (slice_y = 0; bufsize > 0 && slice_y < s->num_y; slice_y++) { > for (slice_x = 0; bufsize > 0 && slice_x < s->num_x; slice_x++) { > @@ -1873,7 +1652,13 @@ static int dirac_decode_frame_internal(DiracContext *s) > { > DWTContext d; > int y, i, comp, dsty; > - int ret; > + int ret = -1; > + > + if (s->avctx->hwaccel) { > + ret = FF_HW_CALL(s->avctx, start_frame, NULL, 0); > + if (ret < 0) > + return ret; > + } > > if (s->low_delay) { > /* [DIRAC_STD] 13.5.1 low_delay_transform_data() */ > @@ -1889,6 +1674,14 @@ static int dirac_decode_frame_internal(DiracContext *s) > } > } > > + if (s->avctx->hwaccel) { > + ret = ffhwaccel(s->avctx->hwaccel)->end_frame(s->avctx); Can’t you use FF_HW_SIMPLE_CALL here? > + if (ret == 0) { > + /* Hwaccel failed - fall back on software decoder */ > + } > + return ret; This error handling looks not correct? If I understand the code correctly, this returns 0 on success. And also the return is outside the block so you would always return here no matter what anyway? > + } > + > for (comp = 0; comp < 3; comp++) { > Plane *p = &s->plane[comp]; > uint8_t *frame = s->current_picture->avframe->data[comp]; > @@ -1904,6 +1697,7 @@ static int dirac_decode_frame_internal(DiracContext *s) > if (ret < 0) > return ret; > } > + > ret = ff_spatial_idwt_init(&d, &p->idwt, s->wavelet_idx+2, > s->wavelet_depth, s->bit_depth); > if (ret < 0) > @@ -1970,15 +1764,23 @@ static int get_buffer_with_edge(AVCodecContext *avctx, AVFrame *f, int flags) > { > int ret, i; > int chroma_x_shift, chroma_y_shift; > - ret = av_pix_fmt_get_chroma_sub_sample(avctx->pix_fmt, &chroma_x_shift, > + DiracContext *s = avctx->priv_data; > + ret = av_pix_fmt_get_chroma_sub_sample(s->sof_pix_fmt, &chroma_x_shift, > &chroma_y_shift); > if (ret < 0) > return ret; > > + /*if (avctx->hwaccel) {*/ > + /* f->width = s->plane[0].width;*/ > + /* f->height = s->plane[0].height;*/ > + /* ret = ff_get_buffer(avctx, f, flags);*/ > + /* return ret;*/ > + /*}*/ Forgotten to remove? > + > f->width = avctx->width + 2 * EDGE_WIDTH; > f->height = avctx->height + 2 * EDGE_WIDTH + 2; > ret = ff_get_buffer(avctx, f, flags); > - if (ret < 0) > + if (ret < 0 || avctx->hwaccel) > return ret; > > for (i = 0; f->data[i]; i++) { > @@ -2136,6 +1938,7 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int > init_get_bits(&s->gb, &buf[13], 8*(size - DATA_UNIT_HEADER_SIZE)); > > if (parse_code == DIRAC_PCODE_SEQ_HEADER) { > + enum AVPixelFormat *pix_fmts; > if (s->seen_sequence_header) > return 0; > > @@ -2156,6 +1959,7 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int > } > > ff_set_sar(avctx, dsh->sample_aspect_ratio); > + s->sof_pix_fmt = dsh->pix_fmt; > avctx->pix_fmt = dsh->pix_fmt; > avctx->color_range = dsh->color_range; > avctx->color_trc = dsh->color_trc; > @@ -2172,7 +1976,20 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int > > s->pshift = s->bit_depth > 8; > > - ret = av_pix_fmt_get_chroma_sub_sample(avctx->pix_fmt, > + /*if (s->pshift) {*/ > + /* avctx->pix_fmt = s->sof_pix_fmt;*/ > + /*} else {*/ Same remark as above? > + pix_fmts = (enum AVPixelFormat[]){ > +#if CONFIG_DIRAC_VULKAN_HWACCEL > + AV_PIX_FMT_VULKAN, > +#endif > + s->sof_pix_fmt, > + AV_PIX_FMT_NONE, > + }; > + avctx->pix_fmt = ff_get_format(s->avctx, pix_fmts); > + /*}*/ Same here, also a few more instances further down but I think it is a bit pointless to point them out all here. Is there a reason why this was done or was it just forgotten to be cleaned up? > + > + ret = av_pix_fmt_get_chroma_sub_sample(s->sof_pix_fmt, > &s->chroma_x_shift, > &s->chroma_y_shift); > if (ret < 0) > @@ -2202,9 +2019,10 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int > } > > /* find an unused frame */ > - for (i = 0; i < MAX_FRAMES; i++) > + for (i = 0; i < MAX_FRAMES; i++) > if (s->all_frames[i].avframe->data[0] == NULL) > pic = &s->all_frames[i]; > + > if (!pic) { > av_log(avctx, AV_LOG_ERROR, "framelist full\n"); > return AVERROR_INVALIDDATA; > @@ -2244,12 +2062,28 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int > if ((ret = get_buffer_with_edge(avctx, pic->avframe, (parse_code & 0x0C) == 0x0C ? AV_GET_BUFFER_FLAG_REF : 0)) < 0) > return ret; > s->current_picture = pic; > - s->plane[0].stride = pic->avframe->linesize[0]; > - s->plane[1].stride = pic->avframe->linesize[1]; > - s->plane[2].stride = pic->avframe->linesize[2]; > > - if (alloc_buffers(s, FFMAX3(FFABS(s->plane[0].stride), FFABS(s->plane[1].stride), FFABS(s->plane[2].stride))) < 0) > - return AVERROR(ENOMEM); > + if (s->avctx->hwaccel) { > + if (!(s->low_delay && s->hq_picture)) { > + av_log(avctx, AV_LOG_ERROR, "The HWaccel only supports VC-2\n"); > + return AVERROR_INVALIDDATA; > + } > + > + if (!s->hwaccel_picture_private) { > + const FFHWAccel *hwaccel = ffhwaccel(s->avctx->hwaccel); > + s->hwaccel_picture_private = > + av_mallocz(hwaccel->frame_priv_data_size); > + if (!s->hwaccel_picture_private) > + return AVERROR(ENOMEM); > + } > + } else { > + s->plane[0].stride = pic->avframe->linesize[0]; > + s->plane[1].stride = pic->avframe->linesize[1]; > + s->plane[2].stride = pic->avframe->linesize[2]; > + > + if (alloc_buffers(s, FFMAX3(FFABS(s->plane[0].stride), FFABS(s->plane[1].stride), FFABS(s->plane[2].stride))) < 0) > + return AVERROR(ENOMEM); > + } > > /* [DIRAC_STD] 11.1 Picture parse. picture_parse() */ > ret = dirac_decode_picture_header(s); > @@ -2359,6 +2193,7 @@ static int dirac_decode_frame(AVCodecContext *avctx, AVFrame *picture, > return buf_idx; > } > > + > const FFCodec ff_dirac_decoder = { > .p.name = "dirac", > CODEC_LONG_NAME("BBC Dirac VC-2"), > @@ -2370,5 +2205,10 @@ const FFCodec ff_dirac_decoder = { > FF_CODEC_DECODE_CB(dirac_decode_frame), > .p.capabilities = AV_CODEC_CAP_DELAY | AV_CODEC_CAP_SLICE_THREADS | AV_CODEC_CAP_DR1, > .flush = dirac_decode_flush, > - .caps_internal = FF_CODEC_CAP_INIT_CLEANUP, > + .hw_configs = (const AVCodecHWConfigInternal *const []) { > +#if CONFIG_DIRAC_VULKAN_HWACCEL > + HWACCEL_VULKAN(dirac), > +#endif > + NULL > + }, > }; > diff --git a/libavcodec/diracdec.h b/libavcodec/diracdec.h > new file mode 100644 > index 0000000000..4ca07342ac > --- /dev/null > +++ b/libavcodec/diracdec.h > @@ -0,0 +1,263 @@ > +/* > + * This file is part of FFmpeg. > + * > + * FFmpeg is free software; you can redistribute it and/or > + * modify it under the terms of the GNU Lesser General Public > + * License as published by the Free Software Foundation; either > + * version 2.1 of the License, or (at your option) any later version. > + * > + * FFmpeg is distributed in the hope that it will be useful, > + * but WITHOUT ANY WARRANTY; without even the implied warranty of > + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU > + * Lesser General Public License for more details. > + * > + * You should have received a copy of the GNU Lesser General Public > + * License along with FFmpeg; if not, write to the Free Software > + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA > + */ > + > +/** > + * @file > + * Dirac Decoder Header > + * @author Marco Gerards <marco@gnu.org>, David Conrad, Jordi Ortiz <nenjordi@gmail.com> > + */ > + > + > +#ifndef AVCODEC_DIRACDEC_H > +#define AVCODEC_DIRACDEC_H > + > +#include "libavutil/mem.h" > +#include "libavutil/mem_internal.h" > +#include "libavutil/pixdesc.h" > +#include "libavutil/thread.h" > +#include "avcodec.h" > +#include "get_bits.h" > +#include "codec_internal.h" > +#include "decode.h" > +#include "golomb.h" > +#include "dirac_arith.h" > +#include "dirac_vlc.h" > +#include "mpegvideoencdsp.h" > +#include "dirac_dwt.h" > +#include "dirac.h" > +#include "diractab.h" > +#include "diracdsp.h" > +#include "videodsp.h" > +#include "hwaccel_internal.h" > + > +#define EDGE_WIDTH 16 > + > +/** > + * The spec limits this to 3 for frame coding, but in practice can be as high as 6 > + */ > +#define MAX_REFERENCE_FRAMES 8 > +#define MAX_DELAY 5 /* limit for main profile for frame coding (TODO: field coding) */ > +#define MAX_FRAMES (MAX_REFERENCE_FRAMES + MAX_DELAY + 1) > +#define MAX_QUANT 255 /* max quant for VC-2 */ > +#define MAX_BLOCKSIZE 32 /* maximum xblen/yblen we support */ > + > +/** > + * DiracBlock->ref flags, if set then the block does MC from the given ref > + */ > +#define DIRAC_REF_MASK_REF1 1 > +#define DIRAC_REF_MASK_REF2 2 > +#define DIRAC_REF_MASK_GLOBAL 4 > + > +/** > + * Value of Picture.reference when Picture is not a reference picture, but > + * is held for delayed output. > + */ > +#define DELAYED_PIC_REF 4 > + > +#define CALC_PADDING(size, depth) \ > + (((size + (1 << depth) - 1) >> depth) << depth) > + > +#define DIVRNDUP(a, b) (((a) + (b) - 1) / (b)) > + > +typedef struct { > + AVFrame *avframe; > + int interpolated[3]; /* 1 if hpel[] is valid */ > + uint8_t *hpel[3][4]; > + uint8_t *hpel_base[3][4]; > + int reference; > + unsigned picture_number; > +} DiracFrame; > + > +typedef struct { > + union { > + int16_t mv[2][2]; > + int16_t dc[3]; > + } u; /* anonymous unions aren't in C99 :( */ > + uint8_t ref; > +} DiracBlock; > + > +typedef struct SubBand { > + int level; > + int orientation; > + int stride; /* in bytes */ > + int width; > + int height; > + int pshift; > + int quant; > + uint8_t *ibuf; > + struct SubBand *parent; > + > + /* for low delay */ > + unsigned length; > + const uint8_t *coeff_data; > +} SubBand; > + > +typedef struct Plane { > + DWTPlane idwt; > + > + int width; > + int height; > + ptrdiff_t stride; > + > + /* block length */ > + uint8_t xblen; > + uint8_t yblen; > + /* block separation (block n+1 starts after this many pixels in block n) */ > + uint8_t xbsep; > + uint8_t ybsep; > + /* amount of overspill on each edge (half of the overlap between blocks) */ > + uint8_t xoffset; > + uint8_t yoffset; > + > + SubBand band[MAX_DWT_LEVELS][4]; > +} Plane; > + > +/* Used by Low Delay and High Quality profiles */ > +typedef struct DiracSlice { > + GetBitContext gb; > + int slice_x; > + int slice_y; > + int bytes; > +} DiracSlice; > + > +typedef struct DiracContext { > + AVCodecContext *avctx; > + MpegvideoEncDSPContext mpvencdsp; > + VideoDSPContext vdsp; > + DiracDSPContext diracdsp; > + DiracVersionInfo version; > + GetBitContext gb; > + AVDiracSeqHeader seq; > + enum AVPixelFormat sof_pix_fmt; > + void *hwaccel_picture_private; > + int seen_sequence_header; > + int64_t frame_number; /* number of the next frame to display */ > + Plane plane[3]; > + int chroma_x_shift; > + int chroma_y_shift; > + > + int bit_depth; /* bit depth */ > + int pshift; /* pixel shift = bit_depth > 8 */ > + > + int zero_res; /* zero residue flag */ > + int is_arith; /* whether coeffs use arith or golomb coding */ > + int core_syntax; /* use core syntax only */ > + int low_delay; /* use the low delay syntax */ > + int hq_picture; /* high quality picture, enables low_delay */ > + int ld_picture; /* use low delay picture, turns on low_delay */ > + int dc_prediction; /* has dc prediction */ > + int globalmc_flag; /* use global motion compensation */ > + int num_refs; /* number of reference pictures */ > + > + /* wavelet decoding */ > + unsigned wavelet_depth; /* depth of the IDWT */ > + unsigned wavelet_idx; > + > + /** > + * schroedinger older than 1.0.8 doesn't store > + * quant delta if only one codebook exists in a band > + */ > + unsigned old_delta_quant; > + unsigned codeblock_mode; > + > + unsigned num_x; /* number of horizontal slices */ > + unsigned num_y; /* number of vertical slices */ > + > + uint8_t *thread_buf; /* Per-thread buffer for coefficient storage */ > + int threads_num_buf; /* Current # of buffers allocated */ > + int thread_buf_size; /* Each thread has a buffer this size */ > + > + DiracSlice *slice_params_buf; > + int slice_params_num_buf; > + > + struct { > + unsigned width; > + unsigned height; > + } codeblock[MAX_DWT_LEVELS+1]; > + > + struct { > + AVRational bytes; /* average bytes per slice */ > + uint8_t quant[MAX_DWT_LEVELS][4]; /* [DIRAC_STD] E.1 */ > + } lowdelay; > + > + struct { > + unsigned prefix_bytes; > + uint64_t size_scaler; > + } highquality; > + > + struct { > + int pan_tilt[2]; /* pan/tilt vector */ > + int zrs[2][2]; /* zoom/rotate/shear matrix */ > + int perspective[2]; /* perspective vector */ > + unsigned zrs_exp; > + unsigned perspective_exp; > + } globalmc[2]; > + > + /* motion compensation */ > + uint8_t mv_precision; /* [DIRAC_STD] REFS_WT_PRECISION */ > + int16_t weight[2]; /* [DIRAC_STD] REF1_WT and REF2_WT */ > + unsigned weight_log2denom; /* [DIRAC_STD] REFS_WT_PRECISION */ > + > + int blwidth; /* number of blocks (horizontally) */ > + int blheight; /* number of blocks (vertically) */ > + int sbwidth; /* number of superblocks (horizontally) */ > + int sbheight; /* number of superblocks (vertically) */ > + > + uint8_t *sbsplit; > + DiracBlock *blmotion; > + > + uint8_t *edge_emu_buffer[4]; > + uint8_t *edge_emu_buffer_base; > + > + uint16_t *mctmp; /* buffer holding the MC data multiplied by OBMC weights */ > + uint8_t *mcscratch; > + int buffer_stride; > + > + DECLARE_ALIGNED(16, uint8_t, obmc_weight)[3][MAX_BLOCKSIZE*MAX_BLOCKSIZE]; > + > + void (*put_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); > + void (*avg_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); > + void (*add_obmc)(uint16_t *dst, const uint8_t *src, int stride, const uint8_t *obmc_weight, int yblen); > + dirac_weight_func weight_func; > + dirac_biweight_func biweight_func; > + > + DiracFrame *current_picture; > + DiracFrame *ref_pics[2]; > + > + DiracFrame *ref_frames[MAX_REFERENCE_FRAMES+1]; > + DiracFrame *delay_frames[MAX_DELAY+1]; > + DiracFrame all_frames[MAX_FRAMES]; > +} DiracContext; > + > +enum dirac_subband { > + subband_ll = 0, > + subband_hl = 1, > + subband_lh = 2, > + subband_hh = 3, > + subband_nb, > +}; > + > +typedef struct SliceCoeffs { > + int left; > + int top; > + int tot_h; > + int tot_v; > + int tot; > +} SliceCoeffs; > + > +#endif > diff --git a/libavcodec/hwaccels.h b/libavcodec/hwaccels.h > index 5171e4c7d7..f6d148b169 100644 > --- a/libavcodec/hwaccels.h > +++ b/libavcodec/hwaccels.h > @@ -27,6 +27,7 @@ extern const struct FFHWAccel ff_av1_nvdec_hwaccel; > extern const struct FFHWAccel ff_av1_vaapi_hwaccel; > extern const struct FFHWAccel ff_av1_vdpau_hwaccel; > extern const struct FFHWAccel ff_av1_vulkan_hwaccel; > +extern const struct FFHWAccel ff_dirac_vulkan_hwaccel; > extern const struct FFHWAccel ff_h263_vaapi_hwaccel; > extern const struct FFHWAccel ff_h263_videotoolbox_hwaccel; > extern const struct FFHWAccel ff_h264_d3d11va_hwaccel; > diff --git a/libavcodec/vulkan_dirac.c b/libavcodec/vulkan_dirac.c > new file mode 100644 > index 0000000000..7f30e4f0fe > --- /dev/null > +++ b/libavcodec/vulkan_dirac.c > @@ -0,0 +1,3817 @@ > +/* > + * This file is part of FFmpeg. > + * > + * FFmpeg is free software; you can redistribute it and/or > + * modify it under the terms of the GNU Lesser General Public > + * License as published by the Free Software Foundation; either > + * version 2.1 of the License, or (at your option) any later version. > + * > + * FFmpeg is distributed in the hope that it will be useful, > + * but WITHOUT ANY WARRANTY; without even the implied warranty of > + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU > + * Lesser General Public License for more details. > + * > + * You should have received a copy of the GNU Lesser General Public > + * License along with FFmpeg; if not, write to the Free Software > + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA > + */ > + > +#include "diracdec.h" > +#include "libavcodec/dirac_vlc.h" > +#include "libavcodec/pthread_internal.h" > +#include "libavfilter/vulkan_spirv.h" > +#include "libavutil/vulkan_loader.h" > +#include "vulkan.h" > +#include "vulkan_decode.h" > + > +typedef struct SubbandOffset { > + int base_off; > + int stride; > + int pad0; > + int pad1; > +} SubbandOffset; > + > +typedef struct SliceCoeffVk { > + int left; > + int top; > + int tot_h; > + int tot_v; > + int tot; > + int offs; > + int pad0; > + int pad1; > +} SliceCoeffVk; > + > +typedef struct WaveletPushConst { > + int real_plane_dims[6]; > + int plane_offs[3]; > + int plane_strides[3]; > + int dw[3]; > + int wavelet_depth; > +} WaveletPushConst; > + > +typedef struct DiracVulkanDecodeContext { > + FFVulkanContext vkctx; > + VkSamplerYcbcrConversion yuv_sampler; > + VkSampler sampler; > + > + FFVulkanPipeline vert_wavelet_pl[9]; > + FFVkSPIRVShader vert_wavelet_shd[9]; > + > + FFVulkanPipeline horiz_wavelet_pl[9]; > + FFVkSPIRVShader horiz_wavelet_shd[9]; > + > + FFVulkanPipeline cpy_to_image_pl[3]; > + FFVkSPIRVShader cpy_to_image_shd[3]; > + > + FFVulkanPipeline quant_pl; > + FFVkSPIRVShader quant_shd; > + > + FFVkQueueFamilyCtx qf; > + FFVkExecPool exec_pool; > + > + int quant_val_buf_size; > + int thread_buf_size; > + int32_t *quant_val_buf_vk_ptr; > + FFVkBuffer *quant_val_buf; > + AVBufferRef *av_quant_val_buf; > + size_t quant_val_buf_offs; > + > + int n_slice_bufs; > + int slice_buf_size; > + SliceCoeffVk *slice_buf_vk_ptr; > + FFVkBuffer *quant_buf; > + AVBufferRef *av_quant_buf; > + size_t quant_buf_offs; > + > + int32_t *quant_buf_vk_ptr; > + int quant_buf_size; > + FFVkBuffer *slice_buf; > + AVBufferRef *av_slice_buf; > + size_t slice_buf_offs; > + > + FFVkBuffer tmp_buf; > + FFVkBuffer tmp_interleave_buf; > + > + FFVkBuffer subband_info; > + SubbandOffset *subband_info_ptr; > + > + int slice_vals_size; > + > + WaveletPushConst pConst; > +} DiracVulkanDecodeContext; > + > +typedef struct DiracVulkanDecodePicture { > + DiracFrame *frame; > +} DiracVulkanDecodePicture; > + > +static void free_common(AVCodecContext *avctx) { > + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; > + DiracContext *ctx = avctx->priv_data; > + FFVulkanContext *s = &dec->vkctx; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + if (ctx->hwaccel_picture_private) { > + av_free(ctx->hwaccel_picture_private); > + } > + > + /* Wait on and free execution pool */ > + if (dec->exec_pool.cmd_bufs) { > + ff_vk_exec_pool_free(s, &dec->exec_pool); > + } > + > + ff_vk_pipeline_free(s, &dec->quant_pl); > + ff_vk_shader_free(s, &dec->quant_shd); > + > + for (int i = 0; i < 3; i++) { > + ff_vk_pipeline_free(s, &dec->cpy_to_image_pl[i]); > + ff_vk_shader_free(s, &dec->cpy_to_image_shd[i]); > + } > + > + for (int i = 0; i < 9; i++) { > + ff_vk_pipeline_free(s, &dec->vert_wavelet_pl[i]); > + ff_vk_shader_free(s, &dec->vert_wavelet_shd[i]); > + > + ff_vk_pipeline_free(s, &dec->horiz_wavelet_pl[i]); > + ff_vk_shader_free(s, &dec->horiz_wavelet_shd[i]); > + } > + // TODO: Add freeing all pipelines and shaders for wavelets > + // > + > + // if (dec->yuv_sampler) > + // vk->DestroySamplerYcbcrConversion(s->hwctx->act_dev, > + // dec->yuv_sampler, > + // s->hwctx->alloc); > + if (dec->sampler) > + vk->DestroySampler(s->hwctx->act_dev, dec->sampler, s->hwctx->alloc); > + > + av_buffer_unref(&dec->av_quant_val_buf); > + av_buffer_unref(&dec->av_quant_buf); > + av_buffer_unref(&dec->av_slice_buf); > + av_buffer_unref(&dec->av_slice_buf); > + > + ff_vk_free_buf(&dec->vkctx, &dec->subband_info); > + > + ff_vk_free_buf(&dec->vkctx, &dec->tmp_buf); > + ff_vk_free_buf(&dec->vkctx, &dec->tmp_interleave_buf); > + > + ff_vk_uninit(s); > +} > + > +static av_always_inline inline void bar_read(VkBufferMemoryBarrier2 *buf_bar, > + int *nb_buf_bar, FFVkBuffer *buf) { > + buf_bar[(*nb_buf_bar)++] = (VkBufferMemoryBarrier2){ > + .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2, > + .srcStageMask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, > + .dstStageMask = VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, > + .srcAccessMask = VK_ACCESS_2_MEMORY_READ_BIT, > + .dstAccessMask = VK_ACCESS_2_MEMORY_READ_BIT, > + .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, > + .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, > + .buffer = buf->buf, > + .size = buf->size, > + .offset = 0, > + }; > +} > + > +static av_always_inline inline void > +bar_write(VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar, FFVkBuffer *buf) { > + buf_bar[(*nb_buf_bar)++] = (VkBufferMemoryBarrier2){ > + .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2, > + .srcStageMask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, > + .dstStageMask = VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, > + .srcAccessMask = VK_ACCESS_2_MEMORY_READ_BIT, > + .dstAccessMask = VK_ACCESS_2_MEMORY_WRITE_BIT, > + .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, > + .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, > + .buffer = buf->buf, > + .size = buf->size, > + .offset = 0, > + }; > +} > + > +static inline int alloc_tmp_bufs(DiracContext *ctx, > + DiracVulkanDecodeContext *dec) { > + int err, plane_size; > + > + plane_size = sizeof(int32_t) * > + (ctx->plane[0].idwt.width * ctx->plane[0].idwt.height + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height + > + ctx->plane[2].idwt.width * ctx->plane[2].idwt.height); > + > + if (dec->tmp_buf.buf != NULL) { > + ff_vk_free_buf(&dec->vkctx, &dec->tmp_buf); > + ff_vk_free_buf(&dec->vkctx, &dec->tmp_interleave_buf); > + } > + > + err = ff_vk_create_buf(&dec->vkctx, &dec->tmp_buf, plane_size, NULL, NULL, > + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | > + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, > + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); > + if (err < 0) > + return err; > + > + err = ff_vk_create_buf(&dec->vkctx, &dec->tmp_interleave_buf, plane_size, > + NULL, NULL, > + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | > + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, > + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); > + if (err < 0) > + return err; > + > + return 0; > +} > + > +static inline int alloc_host_mapped_buf(DiracVulkanDecodeContext *dec, > + size_t req_size, void **mem, > + AVBufferRef **avbuf, FFVkBuffer **buf) { > + // FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + // VkResult ret; > + int err; > + > + err = ff_vk_create_avbuf(&dec->vkctx, avbuf, req_size, NULL, NULL, > + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | > + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, > + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); > + if (err < 0) > + return err; > + > + *buf = (FFVkBuffer *)(*avbuf)->data; > + err = ff_vk_map_buffer(&dec->vkctx, *buf, (uint8_t **)mem, 0); > + if (err < 0) > + return err; > + > + return 0; > +} > + > +static int alloc_slices_buf(DiracContext *ctx, DiracVulkanDecodeContext *dec) { > + int err, length = ctx->num_y * ctx->num_x; > + > + dec->n_slice_bufs = length; > + > + if (dec->slice_buf_vk_ptr) { > + av_buffer_unref(&dec->av_slice_buf); > + } > + > + dec->slice_buf_size = sizeof(SliceCoeffVk) * length * 3 * MAX_DWT_LEVELS; > + err = alloc_host_mapped_buf(dec, dec->slice_buf_size, > + (void **)&dec->slice_buf_vk_ptr, > + &dec->av_slice_buf, &dec->slice_buf); > + if (err < 0) > + return err; > + > + return 0; > +} > + > +static int alloc_dequant_buf(DiracContext *ctx, DiracVulkanDecodeContext *dec) { > + int err, length = ctx->num_y * ctx->num_x; > + > + if (dec->quant_buf_vk_ptr) { > + av_buffer_unref(&dec->av_quant_buf); > + } > + > + dec->n_slice_bufs = length; > + > + dec->quant_buf_size = sizeof(int32_t) * MAX_DWT_LEVELS * 8 * length; > + err = alloc_host_mapped_buf(dec, dec->quant_buf_size, > + (void **)&dec->quant_buf_vk_ptr, > + &dec->av_quant_buf, &dec->quant_buf); > + if (err < 0) > + return err; > + > + return 0; > +} > + > +static int subband_coeffs(const DiracContext *s, int x, int y, int p, int off, > + SliceCoeffVk *c) { > + int level, coef = 0; > + for (level = 0; level <= s->wavelet_depth; level++) { > + SliceCoeffVk *o = &c[level]; > + const SubBand *b = > + &s->plane[p].band[level][3]; /* orientation doens't matter */ > + o->top = b->height * y / s->num_y; > + o->left = b->width * x / s->num_x; > + o->tot_h = ((b->width * (x + 1)) / s->num_x) - o->left; > + o->tot_v = ((b->height * (y + 1)) / s->num_y) - o->top; > + o->tot = o->tot_h * o->tot_v; > + o->offs = off + coef; > + coef += o->tot * (4 - !!level); > + } > + return coef; > +} > + > +static int alloc_quant_buf(DiracContext *ctx, DiracVulkanDecodeContext *dec) { > + int err, length = ctx->num_y * ctx->num_x, coef_buf_size; > + SliceCoeffVk tmp[MAX_DWT_LEVELS]; > + coef_buf_size = > + subband_coeffs(ctx, ctx->num_x - 1, ctx->num_y - 1, 0, 0, tmp) + 8; > + coef_buf_size = coef_buf_size + 512; > + dec->slice_vals_size = coef_buf_size / sizeof(int32_t); > + // coef_buf_size *= sizeof(int32_t); > + > + if (dec->quant_val_buf_vk_ptr) { > + av_buffer_unref(&dec->av_quant_val_buf); > + } > + > + dec->thread_buf_size = coef_buf_size; > + > + dec->quant_val_buf_size = dec->thread_buf_size * 3 * length; > + err = alloc_host_mapped_buf(dec, dec->quant_val_buf_size, > + (void **)&dec->quant_val_buf_vk_ptr, > + &dec->av_quant_val_buf, &dec->quant_val_buf); > + if (err < 0) > + return err; > + > + return 0; > +} > + > +/* ----- Copy Shader init and pipeline pass ----- */ > + > +static int init_cpy_shd(DiracVulkanDecodeContext *s, FFVkSPIRVCompiler *spv, > + int idx) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->cpy_to_image_shd[idx]; > + FFVulkanPipeline *pl = &s->cpy_to_image_pl[idx]; > + FFVkExecPool *exec = &s->exec_pool; > + const int planes = av_pix_fmt_count_planes(s->vkctx.output_format); > + > + RET(ff_vk_shader_init(pl, shd, "cpy_to_image", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->cpy_to_image_shd[idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 1); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_img", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, > + .mem_quali = "writeonly", > + // .mem_layout = ff_vk_shader_rep_fmt(vkctx->output_format), > + .mem_layout = "rgba16f", > + .dimensions = 2, > + .elems = planes, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLC( > + 0, void main() {); > + GLSLC(1, int x = int(gl_GlobalInvocationID.x);); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC(1, int plane = int(gl_GlobalInvocationID.z);); > + GLSLC(1, if (!IS_WITHIN(ivec2(x, y), > + imageSize(out_img[plane]))) return;); > + GLSLC(1, > + int idx = plane_offs[plane] + y * plane_strides[plane] + x;); > + if (idx == 2) { > + GLSLC(1, int32_t ival = inBuf[idx] + 2048;); > + GLSLC(1, float val = float(clamp(ival, 0, 4096)) / 65535.0;); > + } else if (idx == 1) { > + GLSLC(1, int32_t ival = inBuf[idx] + 512;); > + GLSLC(1, float val = float(clamp(ival, 0, 1024)) / 65535.0;); > + } else { > + GLSLC(1, int32_t ival = inBuf[idx] + 128;); > + GLSLC(1, float val = float(clamp(ival, 0, 256)) / 255.0;); > + } > + GLSLC(1, imageStore(out_img[plane], ivec2(x, y), vec4(val));); > + GLSLC(1, memoryBarrier();); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline cpy_to_image_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkImageView *views, VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar, > + VkImageMemoryBarrier2 *img_bar, int *nb_img_bar, int idx) { > + int err, prev_nb_bar = *nb_buf_bar, prev_nb_img_bar = *nb_img_bar; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + DiracVulkanDecodePicture *pic = ctx->hwaccel_picture_private; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, &dec->cpy_to_image_pl[idx], > + exec, 0, 0, 0, dec->tmp_buf.address, > + dec->tmp_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + ff_vk_update_descriptor_img_array(&dec->vkctx, &dec->cpy_to_image_pl[idx], > + exec, pic->frame->avframe, views, 0, 1, > + VK_IMAGE_LAYOUT_GENERAL, dec->sampler); > + > + dec->pConst.real_plane_dims[0] = ctx->plane[0].idwt.width; > + dec->pConst.real_plane_dims[1] = ctx->plane[0].idwt.height; > + dec->pConst.real_plane_dims[2] = ctx->plane[1].idwt.width; > + dec->pConst.real_plane_dims[3] = ctx->plane[1].idwt.height; > + dec->pConst.real_plane_dims[4] = ctx->plane[2].idwt.width; > + dec->pConst.real_plane_dims[5] = ctx->plane[2].idwt.height; > + > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + ff_vk_update_push_exec(&dec->vkctx, exec, &dec->cpy_to_image_pl[idx], > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + > + ff_vk_frame_barrier(&dec->vkctx, exec, pic->frame->avframe, img_bar, > + nb_img_bar, VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, > + VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, > + VK_ACCESS_SHADER_READ_BIT, VK_IMAGE_LAYOUT_GENERAL, > + VK_QUEUE_FAMILY_IGNORED); > + > + vk->CmdPipelineBarrier2( > + exec->buf, &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + prev_nb_bar, > + .bufferMemoryBarrierCount = *nb_buf_bar - prev_nb_bar, > + .pImageMemoryBarriers = img_bar + prev_nb_img_bar, > + .imageMemoryBarrierCount = *nb_img_bar - prev_nb_img_bar, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, &dec->cpy_to_image_pl[idx]); > + > + vk->CmdDispatch(exec->buf, ctx->plane[0].width >> 3, > + ctx->plane[0].height >> 3, 3); > + > + return 0; > +} > + > +/* ----- LeGall Wavelet init and pipeline pass ----- */ > + > +static const char get_idx[] = {C( > + 0, int getIdx(int plane, int x, int y) { ) > + C(1, return plane_offs[plane] + plane_strides[plane] * y + x; ) > + C(0, > + })}; > + > +static const char legall_low_y[] = {C( > + 0, int32_t legall_low_y(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int y_1 = ((y - 1) > 0) ? (y - 1) : 1; ) > + C(1, const int32_t val_1 = inBuf[getIdx(plane, x, y_1)]; ) > + C(1, const int y0 = y; ) > + C(1, const int32_t val0 = inBuf[getIdx(plane, x, y0)]; ) > + C(1, const int y1 = y + 1; ) > + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y1)]; ) > + C(1, return val0 - ((val1 + val_1 + 2) >> 2); ) > + C(0, > + })}; > + > +static const char legall_high[] = {C( > + 0, int32_t legall_high(int32_t v1, int32_t v2, int32_t v3) { ) > + C(1, return v1 + ((v2 + v3 + 1) >> 1); ) > + C(0, > + })}; > + > +static const char legall_vert[] = {C( > + 0, void idwt_vert(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int32_t out0 = legall_low_y(plane, x, y); ) > + C(1, const int32_t yy = ((y + 2) < h) ? (y + 2) : (h - 2); ) > + C(1, const int32_t tmp1 = legall_low_y(plane, x, yy); ) > + C(1, ) > + C(1, const int y1 = y + 1; ) > + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y1)]; ) > + C(1, ) > + C(1, const int32_t out1 = legall_high(val1, out0, tmp1); ) > + C(1, ) > + C(1, outBuf[getIdx(plane, x, y)] = out0; ) > + C(1, outBuf[getIdx(plane, x, y + 1)] = out1; ) > + C(0, > + })}; > + > +static const char legall_low_x[] = {C( > + 0, int32_t legall_low_x(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2; ) > + C(1, ) > + C(1, const int x_1 = (x > 0) ? x : 0; ) > + C(1, const int32_t val_1 = inBuf[getIdx(plane, x_1, y)]; ) > + C(1, ) > + C(1, const int x1 = (x > 0) ? (x + dw) : dw; ) > + C(1, const int32_t val1 = inBuf[getIdx(plane, x1, y)]; ) > + C(1, ) > + C(1, const int x0 = (x > 0) ? (x + dw - 1) : dw; ) > + C(1, const int32_t val0 = inBuf[getIdx(plane, x0, y)]; ) > + C(1, ) > + C(1, return val_1 - ((val0 + val1 + 2) >> 2); ) > + C(0, > + })}; > + > +static const char legall_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2 - 1; ) > + C(1, ) > + C(1, const int32_t out0 = legall_low_x(plane, x, y); ) > + C(1, const int32_t tmp1 = (x == dw) ? out0 : legall_low_x(plane, x + 1, y); ) > + C(1, ) > + C(1, const int x1 = x + dw + 1; ) > + C(1, const int32_t val1 = inBuf[getIdx(plane, x1, y)]; ) > + C(1, ) > + C(1, const int32_t out1 = legall_high(val1, out0, tmp1); ) > + C(1, ) > + C(1, outBuf[getIdx(plane, 2 * x, y)] = (out0 + 1) >> 1; ) > + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = (out1 + 1) >> 1; ) > + C(0, > + })}; > + > +static int init_wavelet_shd_legall_vert(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_LEGALL5_3; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "legall_vert", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->vert_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(legall_low_y); > + GLSLD(legall_high); > + GLSLD(legall_vert); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; 2 * y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; x < w; x += off_x) { ); > + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static int init_wavelet_shd_legall_horiz(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_LEGALL5_3; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "legall_horiz", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->horiz_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(legall_low_x); > + GLSLD(legall_high); > + GLSLD(legall_horiz); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; 2 * x < w; x += off_x) { ); > + GLSLC(3, idwt_horiz(pic_z, x, y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline wavelet_legall_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { > + int err; > + int barrier_num = *nb_buf_bar; > + int wavelet_idx = DWT_DIRAC_LEGALL5_3; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; > + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); > + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); > + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); > + > + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; > + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; > + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; > + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; > + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; > + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; > + > + /* Vertical wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, > + dec->pConst.real_plane_dims[1] >> 4, 1); > + > + /* Horizontal wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, > + dec->pConst.real_plane_dims[1] >> 3, 1); > + } > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + return 0; > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +/* ----- Fidelity init and pipeline pass ----- */ > + > +static const char fidelity_low[] = {C( > + 0, int32_t fidelity_low(int32_t v0, int32_t v1, int32_t v2, int32_t v3, > + int32_t v4, int32_t v5, int32_t v6, int32_t v7) {) > + C(1, return (-2 * v0 + 10 * v1 - 25 * v2 + 81 * v3 + 81 * v4 - 25 * v5 + 10 * v6 - 2 * v7 + 128) >> 8;) > + C(0, > + })}; > + > +static const char fidelity_high[] = {C( > + 0, int32_t fidelity_high(int32_t v0, int32_t v1, int32_t v2, int32_t v3, > + int32_t v4, int32_t v5, int32_t v6, int32_t v7) {) > + C(1, return (-8 * v0 + 21 * v1 - 46 * v2 + 161 * v3 + 161 * v4 - 46 * v5 + 21 * v6 - 8 * v7 + 128) >> 8;) > + C(0, > + })}; > + > +static const char fidelity_low_y[] = {C( > + 0, int32_t fidelity_low_y(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int32_t v1 = inBuf[getIdx(plane, x, y + 1)]; ) > + C(1, ) > + C(1, const int y_6 = ((y - 6) > 0) ? (y - 6) : 0; ) > + C(1, const int32_t v_6 = inBuf[getIdx(plane, x, y_6)]; ) > + C(1, ) > + C(1, const int y_4 = ((y - 4) > 0) ? (y - 4) : 0; ) > + C(1, const int32_t v_4 = inBuf[getIdx(plane, x, y_4)]; ) > + C(1, ) > + C(1, const int y_2 = ((y - 2) > 0) ? (y - 2) : 0; ) > + C(1, const int32_t v_2 = inBuf[getIdx(plane, x, y_2)]; ) > + C(1, ) > + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y)]; ) > + C(1, ) > + C(1, const int y2 = ((y + 2) < h) ? (y + 2) : (h - 2); ) > + C(1, const int32_t v2 = inBuf[getIdx(plane, x, y2)]; ) > + C(1, ) > + C(1, const int y4 = ((y + 4) < h) ? (y + 4) : (h - 2); ) > + C(1, const int32_t v4 = inBuf[getIdx(plane, x, y4)]; ) > + C(1, ) > + C(1, const int y6 = ((y + 6) < h) ? (y + 6) : (h - 2); ) > + C(1, const int32_t v6 = inBuf[getIdx(plane, x, y6)]; ) > + C(1, ) > + C(1, const int y8 = ((y + 8) < h) ? (y + 8) : (h - 2); ) > + C(1, const int32_t v8 = inBuf[getIdx(plane, x, y8)]; ) > + C(1, ) > + C(1, return v1 + fidelity_low(v_6, v_4, v_2, v0, v2, v4, v6, v8); ) > + C(0, > + })}; > + > +static const char fidelity_vert[] = {C( > + 0, void idwt_vert(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y)]; ) > + C(1, const int32_t v1 = fidelity_low_y(plane, x, y); ) > + C(1, const int32_t v_7 = (y - 8 > 0) ? fidelity_low_y(plane, x, y - 8) : v1; ) > + C(1, const int32_t v_5 = (y - 6 > 0) ? fidelity_low_y(plane, x, y - 6) : v1; ) > + C(1, const int32_t v_3 = (y - 4 > 0) ? fidelity_low_y(plane, x, y - 4) : v1; ) > + C(1, const int32_t v_1 = (y - 2 > 0) ? fidelity_low_y(plane, x, y - 2) : v1; ) > + C(1, const int32_t v3 = (y + 2 < h) ? fidelity_low_y(plane, x, y + 2) : ) > + C(1, fidelity_low_y(plane, x, h - 2); ) > + C(1, const int32_t v5 = (y + 4 < h) ? fidelity_low_y(plane, x, y + 4) : ) > + C(1, fidelity_low_y(plane, x, h - 2); ) > + C(1, const int32_t v7 = (y + 6 < h) ? fidelity_low_y(plane, x, y + 6) : ) > + C(1, fidelity_low_y(plane, x, h - 2); ) > + C(1, outBuf[getIdx(plane, x, y)] = v0 - fidelity_high(v_7, v_5, v_3, v_1, v1, v3, v5, v7);) > + C(1, outBuf[getIdx(plane, x, y + 1)] = v1; ) > + C(0, > + })}; > + > +static const char fidelity_low_x[] = {C( > + 0, int32_t fidelity_low_x(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2 - 1; ) > + C(1, ) > + C(1, const int x_3 = clamp(x - 3, 0, dw); ) > + C(1, const int32_t v_3 = inBuf[getIdx(plane, x_3, y)]; ) > + C(1, ) > + C(1, const int x_2 = clamp(x - 2, 0, dw); ) > + C(1, const int32_t v_2 = inBuf[getIdx(plane, x_2, y)]; ) > + C(1, ) > + C(1, const int x_1 = clamp(x - 1, 0, dw); ) > + C(1, const int32_t v_1 = inBuf[getIdx(plane, x_1, y)]; ) > + C(1, ) > + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y)]; ) > + C(1, ) > + C(1, const int x_w = x + dw + 1; ) > + C(1, const int32_t v_w = inBuf[getIdx(plane, x_w, y)]; ) > + C(1, ) > + C(1, const int x1 = clamp(x + 1, 0, dw); ) > + C(1, const int32_t v1 = inBuf[getIdx(plane, x1, y)]; ) > + C(1, ) > + C(1, const int x2 = clamp(x + 2, 0, dw); ) > + C(1, const int32_t v2 = inBuf[getIdx(plane, x2, y)]; ) > + C(1, ) > + C(1, const int x3 = clamp(x + 3, 0, dw); ) > + C(1, const int32_t v3 = inBuf[getIdx(plane, x3, y)]; ) > + C(1, ) > + C(1, const int x4 = clamp(x + 4, 0, dw); ) > + C(1, const int32_t v4 = inBuf[getIdx(plane, x4, y)]; ) > + C(1, ) > + C(1, return v_w + fidelity_low(v_3, v_2, v_1, v0, v1, v2, v3, v4); ) > + C(0, > + })}; > + > +static const char fidelity_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2 - 1; ) > + C(1, ) > + C(1, const int32_t vo0 = inBuf[getIdx(plane, x, y)]; ) > + C(1, ) > + C(1, const int x_4 = clamp(x - 4, 0, dw); ) > + C(1, const int32_t v_4 = fidelity_low_x(plane, x_4, y); ) > + C(1, const int x_3 = clamp(x - 3, 0, dw); ) > + C(1, const int32_t v_3 = fidelity_low_x(plane, x_3, y); ) > + C(1, const int x_2 = clamp(x - 2, 0, dw); ) > + C(1, const int32_t v_2 = fidelity_low_x(plane, x_2, y); ) > + C(1, const int x_1 = clamp(x - 1, 0, dw); ) > + C(1, const int32_t v_1 = fidelity_low_x(plane, x_1, y); ) > + C(1, const int x0 = clamp(x, 0, dw); ) > + C(1, const int32_t v0 = fidelity_low_x(plane, x0, y); ) > + C(1, const int x1 = clamp(x + 1, 0, dw); ) > + C(1, const int32_t v1 = fidelity_low_x(plane, x1, y); ) > + C(1, const int x2 = clamp(x + 2, 0, dw); ) > + C(1, const int32_t v2 = fidelity_low_x(plane, x2, y); ) > + C(1, const int x3 = clamp(x + 3, 0, dw); ) > + C(1, const int32_t v3 = fidelity_low_x(plane, x3, y); ) > + C(1, ) > + C(1, outBuf[getIdx(plane, 2 * x, y)] = vo0 - fidelity_high(v_4, v_3, v_2, v_1, v0, v1, v2, v3);) > + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = v0; ) > + C(0, > + })}; > + > +static int init_wavelet_shd_fidelity_vert(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_FIDELITY; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "fidelity_vert", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->vert_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(fidelity_low); > + GLSLD(fidelity_high); > + GLSLD(fidelity_low_y); > + GLSLD(fidelity_vert); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; 2 * y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; x < w; x += off_x) { ); > + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static int init_wavelet_shd_fidelity_horiz(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_FIDELITY; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "fidelity_horiz", > + VK_SHADER_STAGE_COMPUTE_BIT, 0)); > + > + shd = &s->horiz_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(fidelity_low); > + GLSLD(fidelity_high); > + GLSLD(fidelity_low_x); > + GLSLD(fidelity_horiz); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; 2 * x < w; x += off_x) { ); > + GLSLC(3, idwt_horiz(pic_z, x, y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline wavelet_fidelity_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { > + int err; > + int barrier_num = *nb_buf_bar; > + int wavelet_idx = DWT_DIRAC_FIDELITY; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; > + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); > + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); > + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); > + > + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; > + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; > + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; > + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; > + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; > + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; > + > + /* Vertical wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, > + dec->pConst.real_plane_dims[1] >> 4, 1); > + > + /* Horizontal wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, > + dec->pConst.real_plane_dims[1] >> 3, 1); > + } > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + return 0; > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +/* ----- Daubechies(9, 7) init and pipeline pass ----- */ > + > +static const char daub97_low1[] = {C( > + 0, int32_t daub97_low1(int32_t v1, int32_t v2, int32_t v3) { ) > + C(1, return v2 - ((1817 * (v1 + v2) + 2048) >> 12); ) > + C(0, > + })}; > + > +static const char daub97_high1[] = {C( > + 0, int32_t daub97_high1(int32_t v1, int32_t v2, int32_t v3) { ) > + C(1, return v2 - ((113 * (v1 + v2) + 64) >> 7); ) > + C(0, > + })}; > + > +static const char daub97_low0[] = {C( > + 0, int32_t daub97_low0(int32_t v1, int32_t v2, int32_t v3) { ) > + C(1, return v2 - ((217 * (v1 + v2) + 2048) >> 12); ) > + C(0, > + })}; > + > +static const char daub97_high0[] = {C( > + 0, int32_t daub97_high0(int32_t v1, int32_t v2, int32_t v3) { ) > + C(1, return v2 - ((6947 * (v1 + v2) + 2048) >> 12); ) > + C(0, > + })}; > + > +static const char daub97_low_x0[] = {C( > + 0, int32_t daub97_low_x0(int plane, int x, int y) { ) > + C(1, int w = plane_sizes[plane].x; ) > + C(1, int dw = plane_sizes[plane].x / 2; ) > + C(1, ) > + C(1, int x0 = (x == 0) ? dw : x + dw; ) > + C(1, int32_t v0 = inBuf[getIdx(plane, x0, y)]; ) > + C(1, ) > + C(1, int32_t v1 = inBuf[getIdx(plane, x, y)]; ) > + C(1, ) > + C(1, int x2 = x + dw; ) > + C(1, int32_t v2 = inBuf[getIdx(plane, x0, y)]; ) > + C(1, ) > + C(1, return daub97_low1(v0, v1, v2); ) > + C(0, > + })}; > + > +static const char daub97_high_x0[] = {C( > + 0, int32_t daub97_high_x0(int plane, int x, int y) { ) > + C(1, int w = plane_sizes[plane].x; ) > + C(1, int dw = plane_sizes[plane].x / 2; ) > + C(1, ) > + C(1, int x0 = (x == dw - 1) ? (dw - 1) : (x - 1); ) > + C(1, int32_t v0 = daub97_low_x0(plane, x0, y); ) > + C(1, ) > + C(1, int32_t v1 = inBuf[getIdx(plane, x + dw - 1, y)]; ) > + C(1, ) > + C(1, int32_t v2 = daub97_low_x0(plane, x, y); ) > + C(1, ) > + C(1, return daub97_high1(v0, v1, v2); ) > + C(0, > + })}; > + > +static const char daub97_low_x1[] = {C( > + 0, int32_t daub97_low_x1(int plane, int x, int y) { ) > + C(1, int w = plane_sizes[plane].x; ) > + C(1, int dw = plane_sizes[plane].x / 2; ) > + C(1, ) > + C(1, int32_t v0 = daub97_high_x0(plane, x, y); ) > + C(1, ) > + C(1, int32_t v1 = daub97_low_x0(plane, x, y); ) > + C(1, ) > + C(1, int32_t v2 = daub97_high_x0(plane, x + 1, y); ) > + C(1, ) > + C(1, return daub97_low0(v0, v1, v2); ) > + C(0, > + })}; > + > +static const char daub97_high_x1[] = {C( > + 0, int32_t daub97_high_x1(int plane, int x, int y) { ) > + C(1, int w = plane_sizes[plane].x; ) > + C(1, int dw = plane_sizes[plane].x / 2; ) > + C(1, ) > + C(1, int x0 = clamp(x - 1, 0, dw); ) > + C(1, int32_t v0 = daub97_low_x1(plane, x0, y); ) > + C(1, ) > + C(1, int32_t v1 = daub97_high_x0(plane, x + 1, y); ) > + C(1, ) > + C(1, int x2 = clamp(x, 0, dw); ) > + C(1, int32_t v2 = daub97_low_x1(plane, x2, y); ) > + C(1, ) > + C(1, return daub97_high0(v0, v1, v2); ) > + C(0, > + })}; > + > +static const char daub97_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, int w = plane_sizes[plane].x; ) > + C(1, int dw = plane_sizes[plane].x / 2; ) > + C(1, ) > + C(1, int32_t v0 = daub97_low_x1(plane, x, y); ) > + C(1, int32_t v1 = daub97_high_x1(plane, x, y); ) > + C(1, ) > + C(1, outBuf[getIdx(plane, 2 * x, y)] = ~((~v0) >> 1); ) > + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = ~((~v1) >> 1); ) > + C(0, > + })}; > + > +static int init_wavelet_shd_daub97_vert(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_DAUB9_7; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "daub97_vert", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->vert_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 1, 3); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + /*.mem_quali = "readonly",*/ > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(1, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 1, for (; x < w; x += off_x) { ); > + GLSLC( > + 2, for (int y = 0; y < h; y += 2) { ); > + GLSLC(3, int32_t v0 = inBuf[getIdx( > + pic_z, x, int(clamp(y - 1, 0, h)))];); > + GLSLC(3, > + int32_t v1 = inBuf[getIdx(pic_z, x, y + 1)];); > + GLSLC(3, inBuf[getIdx(pic_z, x, y)] -= > + (1817 * (v0 + v1 + 2048)) >> 12;); > + GLSLC(2, > + }); > + GLSLC( > + 2, for (int y = 0; y < h; y += 2) { ); > + GLSLC(3, int32_t v0 = inBuf[getIdx(pic_z, x, y)];); > + GLSLC(3, > + int32_t v1 = inBuf[getIdx( > + pic_z, x, int(clamp(y + 2, 0, h - 2)))];); > + GLSLC(3, inBuf[getIdx(pic_z, x, y + 1)] -= > + (3616 * (v0 + v1 + 2048)) >> 12;); > + GLSLC(2, > + }); > + GLSLC( > + 2, for (int y = 0; y < h; y += 2) { ); > + GLSLC(3, int32_t v0 = inBuf[getIdx( > + pic_z, x, int(clamp(y - 1, 0, h)))];); > + GLSLC(3, > + int32_t v1 = inBuf[getIdx(pic_z, x, y + 1)];); > + GLSLC(3, int32_t v2 = inBuf[getIdx(pic_z, x, y)];); > + GLSLC(3, outBuf[getIdx(pic_z, x, y)] = > + v2 + (217 * (v0 + v1 + 2048)) >> 12;); > + GLSLC(2, > + }); > + GLSLC( > + 2, for (int y = 0; y < h; y += 2) { ); > + GLSLC(3, int32_t v0 = inBuf[getIdx(pic_z, x, y)];); > + GLSLC(3, > + int32_t v1 = inBuf[getIdx( > + pic_z, x, int(clamp(y + 2, 0, h - 2)))];); > + GLSLC(3, > + int32_t v2 = inBuf[getIdx(pic_z, x, y + 1)];); > + GLSLC(3, outBuf[getIdx(pic_z, x, y + 1)] = > + v2 + (6497 * (v0 + v1 + 2048)) >> 12;); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static int init_wavelet_shd_daub97_horiz(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_DAUB9_7; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "daub97_horiz", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->horiz_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(daub97_low1); > + GLSLD(daub97_low0); > + GLSLD(daub97_high1); > + GLSLD(daub97_high0); > + GLSLD(daub97_low_x0); > + GLSLD(daub97_high_x0); > + GLSLD(daub97_low_x1); > + GLSLD(daub97_high_x1); > + GLSLD(daub97_horiz); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; 2 * x < w; x += off_x) { ); > + GLSLC(3, idwt_horiz(pic_z, x, y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline wavelet_daub97_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { > + int err; > + int barrier_num = *nb_buf_bar; > + int wavelet_idx = DWT_DIRAC_DAUB9_7; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; > + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); > + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); > + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); > + > + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; > + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; > + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; > + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; > + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; > + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; > + > + /* Vertical wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0], 1, 1); > + > + /* Horizontal wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, > + dec->pConst.real_plane_dims[1] >> 3, 1); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + } > + > + return 0; > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +/* ----- Deslauriers-Dubuc(9, 7) init and pipeline pass ----- */ > + > +static const char dd97_high[] = {C( > + 0, int32_t dd97_high(int32_t v1, int32_t v2, int32_t v3, int32_t v4, > + int32_t v5) { ) > + C(1, return v3 + ((9 * v4 + 9 * v2 - v5 - v1 + 8) >> 4); ) > + C(0, > + })}; > + > +static const char dd97_vert[] = {C( > + 0, void idwt_vert(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int32_t out0 = legall_low_y(plane, x, y); ) > + C(1, const int32_t out_2 = (y - 2 > 0) ? legall_low_y(plane, x, y - 2) : ) > + C(1, legall_low_y(plane, x, 0); ) > + C(1, const int32_t out2 = (y + 2 < h) ? legall_low_y(plane, x, y + 2) : ) > + C(1, legall_low_y(plane, x, h - 2); ) > + C(1, const int32_t out4 = (y + 4 < h) ? legall_low_y(plane, x, y + 4) : ) > + C(1, legall_low_y(plane, x, h - 2); ) > + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y + 1)]; ) > + C(1, ) > + C(1, outBuf[getIdx(plane, x, y)] = out0; ) > + C(1, outBuf[getIdx(plane, x, y + 1)] = dd97_high(out_2, out0, val1, out2, out4); ) > + C(1, > + })}; > + > +static const char dd97_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2 - 1; ) > + C(1, ) > + C(1, const int32_t out0 = legall_low_x(plane, x, y); ) > + C(1, ) > + C(1, const int32_t out_1 = ((x - 1) > 0) ? legall_low_x(plane, x - 1, y) : out0; ) > + C(1, const int32_t val3 = inBuf[getIdx(plane, x + dw + 1, y)]; ) > + C(1, const int32_t out1 = ((x + 1) <= dw) ? legall_low_x(plane, x + 1, y) : ) > + C(1, legall_low_x(plane, dw, y); ) > + C(1, const int32_t out2 = ((x + 2) <= dw) ? legall_low_x(plane, x + 2, y) : ) > + C(1, legall_low_x(plane, dw, y); ) > + C(1, const int32_t res = dd97_high(out_1, out0, val3, out1, out2); ) > + C(1, ) > + C(1, outBuf[getIdx(plane, 2 * x, y)] = (out0 + 1) >> 1; ) > + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = (res + 1) >> 1; ) > + C(0, > + })}; > + > +static int init_wavelet_shd_dd97_vert(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_DD9_7; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "dd97_vert", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->vert_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(legall_low_y); > + GLSLD(dd97_high); > + GLSLD(dd97_vert); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; 2 * y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; x < w; x += off_x) { ); > + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static int init_wavelet_shd_dd97_horiz(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_DD9_7; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "dd97_horiz", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->horiz_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(legall_low_x); > + GLSLD(dd97_high); > + GLSLD(dd97_horiz); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; 2 * x < w; x += off_x) { ); > + GLSLC(3, idwt_horiz(pic_z, x, y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline wavelet_dd97_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { > + int err; > + int barrier_num = *nb_buf_bar; > + int wavelet_idx = DWT_DIRAC_DD9_7; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; > + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); > + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); > + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); > + > + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; > + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; > + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; > + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; > + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; > + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; > + > + /* Vertical wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, > + dec->pConst.real_plane_dims[1] >> 4, 1); > + > + /* Horizontal wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, > + dec->pConst.real_plane_dims[1] >> 3, 1); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + } > + > + return 0; > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +/* ----- Deslauriers-Dubuc(13, 7) init and pipeline pass ----- */ > +static const char dd137_low[] = {C( > + 0, int32_t dd137_low(int32_t v0, int32_t v1, int32_t v2, int32_t v3, > + int32_t v4) { ) > + C(0, return v2 - ((9 * v1 + 9 * v3 - v4 - v0 + 16) >> 5); ) > + C(0, > + })}; > + > +static const char dd137_low_y[] = {C( > + 0, int32_t dd137_low_y(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int y0 = (x > 3) ? (y - 3) : 1; ) > + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y0)]; ) > + C(1, ) > + C(1, const int y1 = (y > 1) ? (y - 1) : 1; ) > + C(1, const int32_t v1 = inBuf[getIdx(plane, x, y1)]; ) > + C(1, ) > + C(1, const int y2 = y; ) > + C(1, const int32_t v2 = inBuf[getIdx(plane, x, y2)]; ) > + C(1, ) > + C(1, const int y3 = y + 1; ) > + C(1, const int32_t v3 = inBuf[getIdx(plane, x, y3)]; ) > + C(1, ) > + C(1, const int y4 = (y + 3 < h) ? (y + 3) : (h - 1); ) > + C(1, const int32_t v4 = inBuf[getIdx(plane, x, y4)]; ) > + C(1, ) > + C(1, return dd137_low(v0, v1, v2, v3, v4); ) > + C(0, > + })}; > + > +static const char dd137_vert[] = {C( > + 0, void idwt_vert(int plane, int x, int y) { ) > + C(1, const int h = plane_sizes[plane].y; ) > + C(1, ) > + C(1, const int32_t out0 = dd137_low_y(plane, x, y); ) > + C(1, const int32_t out_2 = (y - 2 > 0) ? dd137_low_y(plane, x, y - 2) : ) > + C(1, dd137_low_y(plane, x, 0); ) > + C(1, const int32_t out2 = (y + 2 < h) ? dd137_low_y(plane, x, y + 2) : ) > + C(1, dd137_low_y(plane, x, h - 2); ) > + C(1, const int32_t out4 = (y + 4 < h) ? dd137_low_y(plane, x, y + 4) : ) > + C(1, dd137_low_y(plane, x, h - 2); ) > + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y + 1)]; ) > + C(1, ) > + C(1, outBuf[getIdx(plane, x, y)] = out0; ) > + C(1, outBuf[getIdx(plane, x, y + 1)] = dd97_high(out_2, out0, val1, out2, out4); ) > + C(1, > + })}; > + > +static const char dd137_low_x[] = {C( > + 0, int32_t dd137_low_x(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2; ) > + C(1, ) > + C(1, const int x0 = (x > 1) ? x : dw; ) > + C(1, const int32_t v0 = inBuf[getIdx(plane, x0, y)]; ) > + C(1, ) > + C(1, const int x1 = (x > 1) ? (x + dw - 2) : dw; ) > + C(1, const int32_t v1 = inBuf[getIdx(plane, x1, y)]; ) > + C(1, ) > + C(1, const int x2 = x; ) > + C(1, const int32_t v2 = inBuf[getIdx(plane, x2, y)]; ) > + C(1, ) > + C(1, const int x3 = x + dw; ) > + C(1, const int32_t v3 = inBuf[getIdx(plane, x3, y)]; ) > + C(1, ) > + C(1, const int x4 = (x != dw - 1) ? (x + dw + 1) : (dw - 1); ) > + C(1, const int32_t v4 = inBuf[getIdx(plane, x4, y)]; ) > + C(1, ) > + C(1, return dd137_low(v0, v1, v2, v3, v4); ) > + C(0, > + })}; > + > +static const char dd137_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, const int w = plane_sizes[plane].x; ) > + C(1, const int dw = w / 2 - 1; ) > + C(1, ) > + C(1, const int32_t out0 = dd137_low_x(plane, x, y); ) > + C(1, ) > + C(1, const int32_t out_1 = ((x - 1) > 0) ? dd137_low_x(plane, x - 1, y) : out0; ) > + C(1, const int32_t val3 = inBuf[getIdx(plane, x + dw + 1, y)]; ) > + C(1, const int32_t out1 = ((x + 1) <= dw) ? dd137_low_x(plane, x + 1, y) : ) > + C(1, dd137_low_x(plane, dw, y); ) > + C(1, const int32_t out2 = ((x + 2) <= dw) ? dd137_low_x(plane, x + 2, y) : ) > + C(1, dd137_low_x(plane, dw, y); ) > + C(1, const int32_t res = dd97_high(out_1, out0, val3, out1, out2); ) > + C(1, ) > + C(1, outBuf[getIdx(plane, 2 * x, y)] = (out0 + 1) >> 1; ) > + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = (res + 1) >> 1; ) > + C(0, > + })}; > + > +static int init_wavelet_shd_dd137_vert(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_DD13_7; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "dd137_vert", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->vert_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(dd97_high); > + GLSLD(dd137_low); > + GLSLD(dd137_low_y); > + GLSLD(dd137_vert); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; 2 * y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; x < w; x += off_x) { ); > + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static int init_wavelet_shd_dd137_horiz(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_DD13_7; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "dd137_horiz", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->horiz_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(get_idx); > + GLSLD(dd97_high); > + GLSLD(dd137_low); > + GLSLD(dd137_low_x); > + GLSLD(dd137_horiz); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; 2 * x < w; x += off_x) { ); > + GLSLC(3, idwt_horiz(pic_z, x, y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline wavelet_dd137_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { > + int err; > + int barrier_num = *nb_buf_bar; > + int wavelet_idx = DWT_DIRAC_DD13_7; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; > + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); > + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); > + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); > + > + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; > + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; > + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; > + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; > + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; > + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; > + > + /* Vertical wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, > + dec->pConst.real_plane_dims[1] >> 4, 1); > + > + /* Horizontal wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, > + dec->pConst.real_plane_dims[1] >> 3, 1); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + } > + > + return 0; > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +/* ----- Haar Wavelet init and pipeline pass ----- */ > + > +static const char haari_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, int offs0 = plane_offs[plane] + plane_strides[plane] * y + x; ) > + C(1, int offs1 = offs0 + plane_sizes[plane].x / 2; ) > + C(1, int outIdx = plane_offs[plane] + plane_strides[plane] * y + x * 2; ) > + C(1, int32_t val_orig0 = inBuf[offs0]; ) > + C(1, int32_t val_orig1 = inBuf[offs1]; ) > + C(1, int32_t val_new0 = val_orig0 - ((val_orig1 + 1) >> 1); ) > + C(1, int32_t val_new1 = val_orig1 + val_new0; ) > + C(1, outBuf[outIdx] = val_new0; ) > + C(1, outBuf[outIdx + 1] = val_new1; ) > + C(0, > + })}; > + > +static const char haari_shift_horiz[] = {C( > + 0, void idwt_horiz(int plane, int x, int y) { ) > + C(1, int offs0 = plane_offs[plane] + plane_strides[plane] * y + x; ) > + C(1, int offs1 = offs0 + plane_sizes[plane].x / 2; ) > + C(1, int outIdx = plane_offs[plane] + plane_strides[plane] * y + x * 2; ) > + C(1, int32_t val_orig0 = inBuf[offs0]; ) > + C(1, int32_t val_orig1 = inBuf[offs1]; ) > + C(1, int32_t val_new0 = val_orig0 - ((val_orig1 + 1) >> 1); ) > + C(1, int32_t val_new1 = val_orig1 + val_new0; ) > + C(1, outBuf[outIdx] = (val_new0 + 1) >> 1; ) > + C(1, outBuf[outIdx + 1] = (val_new1 + 1) >> 1; ) > + C(0, > + })}; > + > +static const char haari_vert[] = {C( > + 0, void idwt_vert(int plane, int x, int y) { ) > + C(1, int offs0 = plane_offs[plane] + plane_strides[plane] * y + x; ) > + C(1, int offs1 = plane_offs[plane] + plane_strides[plane] * (y + 1) + x; ) > + C(2, int32_t val_orig0 = inBuf[offs0]; ) > + C(1, int32_t val_orig1 = inBuf[offs1]; ) > + C(1, int32_t val_new0 = val_orig0 - ((val_orig1 + 1) >> 1); ) > + C(1, int32_t val_new1 = val_orig1 + val_new0; ) > + C(1, outBuf[offs0] = val_new0; ) > + C(1, outBuf[offs1] = val_new1; ) > + C(0, > + })}; > + > +static int init_wavelet_shd_haari_vert(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv, int shift) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_HAAR0 + shift; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "haari_vert", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->vert_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(haari_vert); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; 2 * y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; x < w; x += off_x) { ); > + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static int init_wavelet_shd_haari_horiz(DiracVulkanDecodeContext *s, > + FFVkSPIRVCompiler *spv, int shift) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + int wavelet_idx = DWT_DIRAC_HAAR0 + shift; > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; > + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "haari_horiz", VK_SHADER_STAGE_COMPUTE_BIT, > + 0)); > + > + shd = &s->horiz_wavelet_shd[wavelet_idx]; > + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); > + > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "in_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t inBuf[];", > + .mem_quali = "readonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLD(shift ? haari_shift_horiz : haari_horiz); > + > + GLSLC( > + 0, void main() { ); > + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); > + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); > + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); > + GLSLC(1, ); > + GLSLC(1, uint w = int(plane_sizes[pic_z].x);); > + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); > + GLSLC(1, ); > + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); > + GLSLC( > + 1, for (; y < h; y += off_y) { ); > + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); > + GLSLC( > + 2, for (; 2 * x < w; x += off_x) { ); > + GLSLC(3, idwt_horiz(pic_z, x, y);); > + GLSLC(2, > + }); > + GLSLC(1, > + }); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline wavelet_haari_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar, int shift) { > + int err; > + int barrier_num = *nb_buf_bar; > + > + const int wavelet_idx = DWT_DIRAC_HAAR0 + shift; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; > + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, > + dec->tmp_buf.address, dec->tmp_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + goto fail; > + > + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); > + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); > + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); > + > + dec->pConst.real_plane_dims[0] = ctx->plane[0].idwt.width >> i; > + dec->pConst.real_plane_dims[1] = ctx->plane[0].idwt.height >> i; > + dec->pConst.real_plane_dims[2] = ctx->plane[1].idwt.width >> i; > + dec->pConst.real_plane_dims[3] = ctx->plane[1].idwt.height >> i; > + dec->pConst.real_plane_dims[4] = ctx->plane[2].idwt.width >> i; > + dec->pConst.real_plane_dims[5] = ctx->plane[2].idwt.height >> i; > + > + dec->pConst.wavelet_depth = ctx->wavelet_depth; > + > + /* Vertical wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, > + dec->pConst.real_plane_dims[1] >> 4, 1); > + > + /* Horizontal wavelet pass */ > + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); > + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, > + dec->pConst.real_plane_dims[1] >> 3, 1); > + } > + > + barrier_num = *nb_buf_bar; > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + vk->CmdPipelineBarrier2( > + exec->buf, &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + barrier_num, > + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, > + }); > + > + return 0; > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +/* ----- Dequant Shader init and pipeline pass ----- */ > + > +static const char dequant[] = {C( > + 0, void dequant(int outIdx, int idx, int qf, int qs) { ) > + C(1, int32_t val = inBuffer[idx]; ) > + C(1, val = sign(val) * ((abs(val) * qf + qs) >> 2); ) > + C(1, outBuf0[outIdx] = outBuf1[outIdx] = val; ) > + C(0, > + })}; > + > +static const char proc_slice[] = {C( > + 0, void proc_slice(int slice_idx) { ) > + C(1, const int plane = int(gl_GlobalInvocationID.y); ) > + C(1, const int level = int(gl_GlobalInvocationID.z); ) > + C(1, if (level >= wavelet_depth) return; ) > + C(1, const int base_idx = slice_idx * DWT_LEVELS * 8; ) > + C(1, const int base_slice_idx = slice_idx * DWT_LEVELS * 3 + plane * DWT_LEVELS; ) > + C(1, ) > + C(1, const Slice s = slices[base_slice_idx + level]; ) > + C(1, int offs = s.offs; ) > + C(1, ) > + C(1, for(int orient = int(bool(level)); orient < 4; orient++) { ) > + C(2, int32_t qf = quantMatrix[base_idx + level * 8 + orient]; ) > + C(2, int32_t qs = quantMatrix[base_idx + level * 8 + orient + 4]; ) > + C(2, ) > + C(2, const int subband_idx = plane * DWT_LEVELS * 4 ) > + C(2, + 4 * level + orient; ) > + C(2, ) > + C(2, const SubbandOffset sub_off = subband_offs[subband_idx]; ) > + C(2, int img_idx = plane_offs[plane] + sub_off.base_off ) > + C(2, + s.top * sub_off.stride + s.left; ) > + C(2, ) > + C(2, for(int y = 0; y < s.tot_v; y++) { ) > + C(3, int img_x = img_idx; ) > + C(3, for(int x = 0; x < s.tot_h; x++) { ) > + C(4, dequant(img_x, offs, qf, qs); ) > + C(4, img_x++; ) > + C(4, offs++; ) > + C(3, } ) > + C(3, img_idx += sub_off.stride; ) > + C(2, } ) > + C(1, } ) > + C(0, > + })}; > + > +static int init_quant_shd(DiracVulkanDecodeContext *s, FFVkSPIRVCompiler *spv) { > + int err = 0; > + uint8_t *spv_data; > + size_t spv_len; > + void *spv_opaque = NULL; > + // const int planes = av_pix_fmt_count_planes(s->vkctx.output_format); > + FFVulkanContext *vkctx = &s->vkctx; > + FFVulkanDescriptorSetBinding *desc; > + FFVkSPIRVShader *shd = &s->quant_shd; > + FFVulkanPipeline *pl = &s->quant_pl; > + FFVkExecPool *exec = &s->exec_pool; > + > + RET(ff_vk_shader_init(pl, shd, "dequant", VK_SHADER_STAGE_COMPUTE_BIT, 0)); > + > + shd = &s->quant_shd; > + ff_vk_shader_set_compute_sizes(shd, 1, 1, 1); > + > + GLSLC(0, #extension GL_EXT_debug_printf : enable); > + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); > + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); > + > + GLSLC( > + 0, struct Slice { ); > + GLSLC(1, int left;); > + GLSLC(1, int top;); > + GLSLC(1, int tot_h;); > + GLSLC(1, int tot_v;); > + GLSLC(1, int tot;); > + GLSLC(1, int offs;); > + GLSLC(1, int pad0;); > + GLSLC(1, int pad1;); > + GLSLC(0, > + };); > + > + GLSLC( > + 0, struct SubbandOffset { ); > + GLSLC(1, int base_off;); > + GLSLC(1, int stride;); > + GLSLC(1, int pad0;); > + GLSLC(1, int pad1;); > + GLSLC(0, > + };); > + > + desc = (FFVulkanDescriptorSetBinding[]){ > + { > + .name = "out_buf_0", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf0[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + { > + .name = "out_buf_1", > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .buf_content = "int32_t outBuf1[];", > + .mem_quali = "writeonly", > + .dimensions = 1, > + }, > + { > + .name = "quant_in_buf", > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .buf_content = "int32_t inBuffer[];", > + .mem_quali = "readonly", > + }, > + { > + .name = "quant_vals_buf", > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .buf_content = "int32_t quantMatrix[];", > + .mem_quali = "readonly", > + }, > + { > + .name = "slices_buf", > + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .buf_content = "Slice slices[];", > + .mem_quali = "readonly", > + .mem_layout = "std430", > + }, > + { > + .name = "subband_buf", > + .type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, > + .stages = VK_SHADER_STAGE_COMPUTE_BIT, > + .buf_content = "SubbandOffset subband_offs[60];", > + .mem_quali = "readonly", > + .mem_layout = "std430", > + }, > + }; > + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 6, 0, 0)); > + > + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), > + VK_SHADER_STAGE_COMPUTE_BIT); > + > + GLSLC( > + 0, layout(push_constant, std430) uniform pushConstants { ); > + GLSLC(1, ivec2 plane_sizes[3];); > + GLSLC(1, int plane_offs[3];); > + GLSLC(1, int plane_strides[3];); > + GLSLC(1, int dw[3];); > + GLSLC(1, int wavelet_depth;); > + GLSLC(0, > + };); > + GLSLC(0, ); > + > + GLSLF(0, #define DWT_LEVELS % i, MAX_DWT_LEVELS); > + > + GLSLD(dequant); > + GLSLD(proc_slice); > + GLSLC(0, void main()); > + GLSLC(0, { ); > + GLSLC(1, int idx = int(gl_GlobalInvocationID.x);); > + GLSLC(1, proc_slice(idx);); > + GLSLC(0, > + }); > + > + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", > + &spv_opaque)); > + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); > + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); > + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); > + > +fail: > + if (spv_opaque) > + spv->free_shader(spv, &spv_opaque); > + > + return err; > +} > + > +static av_always_inline int inline quant_pl_pass( > + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, > + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { > + int err, nb_bar; > + FFVulkanFunctions *vk = &dec->vkctx.vkfn; > + > + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, &dec->quant_pl); > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, &dec->quant_pl, exec, 0, 0, > + 0, dec->tmp_buf.address, > + dec->tmp_buf.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + err = ff_vk_set_descriptor_buffer(&dec->vkctx, &dec->quant_pl, exec, 0, 1, > + 0, dec->tmp_interleave_buf.address, > + dec->tmp_interleave_buf.size, > + VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, &dec->quant_pl, exec, 0, 2, 0, dec->quant_val_buf->address, > + dec->quant_val_buf->size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, &dec->quant_pl, exec, 0, 3, 0, dec->quant_buf->address, > + dec->quant_buf->size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, &dec->quant_pl, exec, 0, 4, 0, dec->slice_buf->address, > + dec->slice_buf->size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + err = ff_vk_set_descriptor_buffer( > + &dec->vkctx, &dec->quant_pl, exec, 0, 5, 0, dec->subband_info.address, > + dec->subband_info.size, VK_FORMAT_UNDEFINED); > + if (err < 0) > + return err; > + > + dec->pConst.real_plane_dims[0] = ctx->plane[0].idwt.width; > + dec->pConst.real_plane_dims[1] = ctx->plane[0].idwt.height; > + dec->pConst.real_plane_dims[2] = ctx->plane[1].idwt.width; > + dec->pConst.real_plane_dims[3] = ctx->plane[1].idwt.height; > + dec->pConst.real_plane_dims[4] = ctx->plane[2].idwt.width; > + dec->pConst.real_plane_dims[5] = ctx->plane[2].idwt.height; > + > + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width; > + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width; > + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width; > + > + dec->pConst.plane_offs[0] = 0; > + dec->pConst.plane_offs[1] = > + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; > + dec->pConst.plane_offs[2] = > + dec->pConst.plane_offs[1] + > + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; > + > + dec->pConst.wavelet_depth = ctx->wavelet_depth; > + > + ff_vk_update_push_exec(&dec->vkctx, exec, &dec->quant_pl, > + VK_SHADER_STAGE_COMPUTE_BIT, 0, > + sizeof(WaveletPushConst), &dec->pConst); > + > + bar_read(buf_bar, nb_buf_bar, dec->quant_val_buf); > + bar_read(buf_bar, nb_buf_bar, dec->slice_buf); > + bar_read(buf_bar, nb_buf_bar, dec->quant_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->subband_info); > + > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2(exec->buf, > + &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar, > + .bufferMemoryBarrierCount = *nb_buf_bar, > + }); > + > + vk->CmdDispatch(exec->buf, ctx->num_x * ctx->num_y, 3, ctx->wavelet_depth); > + > + nb_bar = *nb_buf_bar; > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); > + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); > + > + vk->CmdPipelineBarrier2( > + exec->buf, &(VkDependencyInfo){ > + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, > + .pBufferMemoryBarriers = buf_bar + nb_bar, > + .bufferMemoryBarrierCount = *nb_buf_bar - nb_bar, > + }); > + > + return 0; > +} > + > +static int vulkan_dirac_uninit(AVCodecContext *avctx) { > + // DiracContext *d = avctx->priv_data; > + // if (d->hwaccel_picture_private) { > + // av_freep(d->hwaccel_picture_private); > + // } > + > + free_common(avctx); > + > + return 0; > +} > + > +static inline int wavelet_init(DiracVulkanDecodeContext *dec, > + FFVkSPIRVCompiler *spv) { > + int err; > + > + err = init_wavelet_shd_daub97_horiz(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_daub97_vert(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_haari_vert(dec, spv, 0); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_haari_horiz(dec, spv, 0); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_haari_vert(dec, spv, 1); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_haari_horiz(dec, spv, 1); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_legall_vert(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_legall_horiz(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_dd97_vert(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_dd97_horiz(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_fidelity_vert(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_fidelity_horiz(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_dd137_vert(dec, spv); > + if (err < 0) { > + return err; > + } > + > + err = init_wavelet_shd_dd137_horiz(dec, spv); > + if (err < 0) { > + return err; > + } > + > + return 0; > +} > + > +static int vulkan_dirac_init(AVCodecContext *avctx) { > + int err = 0; > + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; > + FFVulkanContext *s; > + FFVkSPIRVCompiler *spv; > + > + spv = ff_vk_spirv_init(); > + if (!spv) { > + av_log(avctx, AV_LOG_ERROR, "Unable to initialize SPIR-V compiler!\n"); > + return AVERROR_EXTERNAL; > + } > + > + err = ff_decode_get_hw_frames_ctx(avctx, AV_HWDEVICE_TYPE_VULKAN); > + if (err < 0) > + goto fail; > + > + /* Initialize contexts */ > + s = &dec->vkctx; > + > + err = ff_vk_init(s, avctx, NULL, avctx->hw_frames_ctx); > + if (err < 0) > + return err; Shouldn’t this goto fail? > + > + /* Create queue context */ > + ff_vk_qf_init(s, &dec->qf, VK_QUEUE_COMPUTE_BIT); > + > + err = ff_vk_exec_pool_init(s, &dec->qf, &dec->exec_pool, 4, 0, 0, 0, NULL); > + > + err = ff_vk_init_sampler(&dec->vkctx, &dec->sampler, 1, VK_FILTER_NEAREST); > + if (err < 0) { > + goto fail; > + } > + > + av_log(avctx, AV_LOG_VERBOSE, "Vulkan decoder initialization sucessful\n"); > + > + err = init_quant_shd(dec, spv); > + if (err < 0) { > + goto fail; > + } > + > + err = init_cpy_shd(dec, spv, 0); > + if (err < 0) { > + goto fail; > + } > + > + err = init_cpy_shd(dec, spv, 1); > + if (err < 0) { > + goto fail; > + } > + > + err = init_cpy_shd(dec, spv, 2); > + if (err < 0) { > + goto fail; > + } > + > + err = wavelet_init(dec, spv); > + if (err < 0) { > + goto fail; > + } > + > + dec->quant_val_buf_vk_ptr = NULL; > + dec->slice_buf_vk_ptr = NULL; > + dec->quant_buf_vk_ptr = NULL; > + > + dec->av_quant_val_buf = NULL; > + dec->av_quant_buf = NULL; > + dec->av_slice_buf = NULL; > + > + dec->thread_buf_size = 0; > + dec->n_slice_bufs = 0; > + > + err = ff_vk_create_buf(&dec->vkctx, &dec->subband_info, > + sizeof(SubbandOffset) * MAX_DWT_LEVELS * 12, NULL, > + NULL, > + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | > + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, > + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | > + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); > + if (err < 0) > + return err; Shouldn’t this goto fail? > + > + err = ff_vk_map_buffer(&dec->vkctx, &dec->subband_info, > + (uint8_t **)&dec->subband_info_ptr, 0); > + if (err < 0) > + return err; Same? > + > + return 0; > + > +fail: > + if (spv) { > + spv->uninit(&spv); > + } > + vulkan_dirac_uninit(avctx); > + > + return err; > +} > + > +static int vulkan_decode_bootstrap(AVCodecContext *avctx, > + AVBufferRef *frames_ref) { > + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; > + AVHWFramesContext *frames = (AVHWFramesContext *)frames_ref->data; > + AVHWDeviceContext *device = (AVHWDeviceContext *)frames->device_ref->data; > + AVVulkanDeviceContext *hwctx = device->hwctx; > + > + dec->vkctx.extensions = ff_vk_extensions_to_mask( > + hwctx->enabled_dev_extensions, hwctx->nb_enabled_dev_extensions); > + > + return 0; > +} > + > +static int vulkan_dirac_frame_params(AVCodecContext *avctx, > + AVBufferRef *hw_frames_ctx) { > + int err; > + AVHWFramesContext *frames_ctx = (AVHWFramesContext *)hw_frames_ctx->data; > + AVVulkanFramesContext *hwfc = frames_ctx->hwctx; > + DiracContext *s = avctx->priv_data; > + > + frames_ctx->sw_format = s->sof_pix_fmt; > + > + err = vulkan_decode_bootstrap(avctx, hw_frames_ctx); > + if (err < 0) > + return err; > + > + frames_ctx->width = avctx->coded_width; > + frames_ctx->height = avctx->coded_height; > + frames_ctx->format = AV_PIX_FMT_VULKAN; > + > + for (int i = 0; i < AV_NUM_DATA_POINTERS; i++) { > + hwfc->format[i] = av_vkfmt_from_pixfmt(frames_ctx->sw_format)[i]; > + } > + > + hwfc->tiling = VK_IMAGE_TILING_OPTIMAL; > + hwfc->usage = VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT; > + > + return err; > +} > + > +static void vulkan_dirac_free_frame_priv(FFRefStructOpaque _hwctx, void *data) { > + // AVHWDeviceContext *hwctx = _hwctx.nc; > + DiracVulkanDecodePicture *dp = data; > + > + /* Free frame resources */ > + av_free(dp); > +} > + > +static void setup_subbands(DiracContext *ctx, DiracVulkanDecodeContext *dec) { > + SubbandOffset *offs = dec->subband_info_ptr; > + memset(offs, 0, dec->subband_info.size); > + > + for (int plane = 0; plane < 3; plane++) { > + Plane *p = &ctx->plane[plane]; > + int w = p->idwt.width; > + int s = FFALIGN(p->idwt.width, 8); > + > + for (int level = ctx->wavelet_depth - 1; level >= 0; level--) { > + w >>= 1; > + s <<= 1; > + for (int orient = 0; orient < 4; orient++) { > + const int idx = plane * MAX_DWT_LEVELS * 4 + level * 4 + orient; > + SubbandOffset *off = &offs[idx]; > + off->stride = s; > + off->base_off = 0; > + > + if (orient & 1) > + off->base_off += w; > + if (orient > 1) > + off->base_off += (s >> 1); > + > + /*SubBand *b = &p->band[level][orient];*/ > + /*int w = (b->ibuf - p->idwt.buf) >> (1 + b->pshift);*/ > + /*off->stride = b->stride >> (1 + b->pshift);*/ > + /*off->base_off = w;*/ > + } > + } > + } > +} > + > +static int vulkan_dirac_start_frame(AVCodecContext *avctx, > + av_unused const uint8_t *buffer, > + av_unused uint32_t size) { > + int err; > + DiracVulkanDecodeContext *s = avctx->internal->hwaccel_priv_data; > + DiracContext *c = avctx->priv_data; > + DiracVulkanDecodePicture *pic = c->hwaccel_picture_private; > + WaveletPushConst *pConst = &s->pConst; > + > + pic->frame = c->current_picture; > + setup_subbands(c, s); > + > + pConst->real_plane_dims[0] = c->plane[0].idwt.width; > + pConst->real_plane_dims[1] = c->plane[0].idwt.height; > + pConst->real_plane_dims[2] = c->plane[1].idwt.width; > + pConst->real_plane_dims[3] = c->plane[1].idwt.height; > + pConst->real_plane_dims[4] = c->plane[2].idwt.width; > + pConst->real_plane_dims[5] = c->plane[2].idwt.height; > + > + pConst->plane_strides[0] = c->plane[0].idwt.width; > + pConst->plane_strides[1] = c->plane[1].idwt.width; > + pConst->plane_strides[0] = c->plane[0].idwt.width; > + > + pConst->plane_offs[0] = 0; > + pConst->plane_offs[1] = c->plane[0].idwt.width * c->plane[0].idwt.height; > + pConst->plane_offs[2] = pConst->plane_offs[1] + > + c->plane[1].idwt.width * c->plane[1].idwt.height; > + > + pConst->wavelet_depth = c->wavelet_depth; > + > + if (s->quant_buf_vk_ptr == NULL || s->slice_buf_vk_ptr == NULL || > + s->quant_val_buf_vk_ptr == NULL || > + c->num_x * c->num_y != s->n_slice_bufs) { > + err = alloc_quant_buf(c, s); > + if (err < 0) > + return err; > + err = alloc_dequant_buf(c, s); > + if (err < 0) > + return err; > + err = alloc_slices_buf(c, s); > + if (err < 0) > + return err; > + err = alloc_tmp_bufs(c, s); > + if (err < 0) > + return err; > + } > + > + return 0; > +} > + > +static int vulkan_dirac_end_frame(AVCodecContext *avctx) { > + int err, nb_img_bar = 0, nb_buf_bar = 0; > + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; > + DiracContext *ctx = avctx->priv_data; > + VkImageView views[AV_NUM_DATA_POINTERS]; > + VkBufferMemoryBarrier2 buf_bar[80]; > + VkImageMemoryBarrier2 img_bar[80]; > + DiracVulkanDecodePicture *pic = ctx->hwaccel_picture_private; > + FFVkExecContext *exec = ff_vk_exec_get(&dec->exec_pool); > + enum dwt_type wavelet_idx = ctx->wavelet_idx + 2; > + > + ff_vk_exec_start(&dec->vkctx, exec); > + > + err = > + ff_vk_exec_add_dep_buf(&dec->vkctx, exec, &dec->av_quant_val_buf, 1, 1); > + if (err < 0) > + goto fail; > + > + err = ff_vk_exec_add_dep_buf(&dec->vkctx, exec, &dec->av_quant_buf, 1, 1); > + if (err < 0) > + goto fail; > + > + err = ff_vk_exec_add_dep_buf(&dec->vkctx, exec, &dec->av_slice_buf, 1, 1); > + if (err < 0) > + goto fail; > + > + err = quant_pl_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); > + if (err < 0) > + goto fail; > + > + err = ff_vk_exec_add_dep_frame(&dec->vkctx, exec, pic->frame->avframe, > + VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, > + VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT); > + if (err < 0) > + goto fail; > + > + err = > + ff_vk_create_imageviews(&dec->vkctx, exec, views, pic->frame->avframe); > + if (err < 0) > + goto fail; > + > + switch (wavelet_idx) { > + case DWT_DIRAC_DAUB9_7: > + err = wavelet_daub97_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); > + break; > + > + case DWT_DIRAC_FIDELITY: > + err = wavelet_fidelity_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); > + break; > + > + case DWT_DIRAC_DD9_7: > + err = wavelet_dd97_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); > + break; > + > + case DWT_DIRAC_DD13_7: > + err = wavelet_dd137_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); > + break; > + > + case DWT_DIRAC_LEGALL5_3: > + err = wavelet_legall_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); > + break; > + > + case DWT_DIRAC_HAAR0: > + err = wavelet_haari_pass(dec, ctx, exec, buf_bar, &nb_buf_bar, 0); > + break; > + > + case DWT_DIRAC_HAAR1: > + err = wavelet_haari_pass(dec, ctx, exec, buf_bar, &nb_buf_bar, 1); > + break; > + > + default: > + err = AVERROR_PATCHWELCOME; > + break; > + } It seems there is a missing error check here, as the err is immediately overwritten in the next line. > + > + err = cpy_to_image_pass(dec, ctx, exec, views, buf_bar, &nb_buf_bar, > + img_bar, &nb_img_bar, (ctx->bit_depth - 8) >> 1); > + if (err < 0) > + goto fail; > + > + err = ff_vk_exec_submit(&dec->vkctx, exec); > + if (err < 0) > + goto fail; > + > + ff_vk_exec_wait(&dec->vkctx, exec); > + > + return 0; > + > +fail: > + ff_vk_exec_discard_deps(&dec->vkctx, exec); > + return err; > +} > + > +static int vulkan_dirac_update_thread_context(AVCodecContext *dst, > + const AVCodecContext *src) { > + int err; > + DiracVulkanDecodeContext *src_ctx = src->internal->hwaccel_priv_data; > + DiracVulkanDecodeContext *dst_ctx = dst->internal->hwaccel_priv_data; > + FFVkSPIRVCompiler *spv; > + > + spv = ff_vk_spirv_init(); > + if (!spv) { > + av_log(dst, AV_LOG_ERROR, "Unable to initialize SPIR-V compiler!\n"); > + return AVERROR_EXTERNAL; > + } > + > + memset(dst_ctx, 0, sizeof(DiracVulkanDecodeContext)); > + > + dst_ctx->vkctx = src_ctx->vkctx; > + dst_ctx->sampler = src_ctx->sampler; > + dst_ctx->qf = src_ctx->qf; > + dst_ctx->exec_pool = src_ctx->exec_pool; > + dst_ctx->quant_pl = src_ctx->quant_pl; > + > + err = init_quant_shd(dst_ctx, spv); > + if (err < 0) { > + goto fail; > + } > + > + err = init_cpy_shd(dst_ctx, spv, 0); > + if (err < 0) { > + goto fail; > + } > + > + err = init_cpy_shd(dst_ctx, spv, 1); > + if (err < 0) { > + goto fail; > + } > + > + err = init_cpy_shd(dst_ctx, spv, 2); > + if (err < 0) { > + goto fail; > + } > + > + err = wavelet_init(dst_ctx, spv); > + if (err < 0) { > + goto fail; > + } > + > + dst_ctx->quant_val_buf_vk_ptr = NULL; > + dst_ctx->slice_buf_vk_ptr = NULL; > + dst_ctx->quant_buf_vk_ptr = NULL; > + > + dst_ctx->av_quant_val_buf = NULL; > + dst_ctx->av_quant_buf = NULL; > + dst_ctx->av_slice_buf = NULL; > + > + dst_ctx->thread_buf_size = 0; > + dst_ctx->n_slice_bufs = 0; > + > + err = ff_vk_create_buf(&dst_ctx->vkctx, &dst_ctx->subband_info, > + sizeof(SubbandOffset) * MAX_DWT_LEVELS * 12, NULL, > + NULL, > + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | > + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, > + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | > + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); > + if (err < 0) > + return err; Shouldnt this goto fail? > + > + err = ff_vk_map_buffer(&dst_ctx->vkctx, &dst_ctx->subband_info, > + (uint8_t **)&dst_ctx->subband_info_ptr, 0); > + if (err < 0) > + return err; Same? > + > + return 0; > + > +fail: > + if (spv) { > + spv->uninit(&spv); > + } > + vulkan_dirac_uninit(dst); > + > + return err; > +} > + > +static inline int decode_hq_slice(const DiracContext *s, int jobnr) { > + int i, level, orientation, quant_idx; > + DiracVulkanDecodeContext *dec = s->avctx->internal->hwaccel_priv_data; > + int32_t *qfactor = &dec->quant_buf_vk_ptr[jobnr * 8 * MAX_DWT_LEVELS]; > + int32_t *qoffset = &dec->quant_buf_vk_ptr[jobnr * 8 * MAX_DWT_LEVELS + 4]; > + int32_t *quant_val_base = dec->quant_val_buf_vk_ptr; > + DiracSlice *slice = &s->slice_params_buf[jobnr]; > + SliceCoeffVk *slice_vk = &dec->slice_buf_vk_ptr[jobnr * 3 * MAX_DWT_LEVELS]; > + GetBitContext *gb = &slice->gb; > + > + skip_bits_long(gb, 8 * s->highquality.prefix_bytes); > + quant_idx = get_bits(gb, 8); > + > + if (quant_idx > DIRAC_MAX_QUANT_INDEX - 1) { > + av_log(s->avctx, AV_LOG_ERROR, "Invalid quantization index - %i\n", > + quant_idx); > + return AVERROR_INVALIDDATA; > + } > + > + /* Slice quantization (slice_quantizers() in the specs) */ > + for (level = 0; level < s->wavelet_depth; level++) { > + for (orientation = !!level; orientation < 4; orientation++) { > + const int quant = > + FFMAX(quant_idx - s->lowdelay.quant[level][orientation], 0); > + qfactor[level * 8 + orientation] = ff_dirac_qscale_tab[quant]; > + qoffset[level * 8 + orientation] = > + ff_dirac_qoffset_intra_tab[quant] + 2; > + } > + } > + > + /* Luma + 2 Chroma planes */ > + for (i = 0; i < 3; i++) { > + int coef_num, coef_par; > + int64_t length = s->highquality.size_scaler * get_bits(gb, 8); > + int64_t bits_end = get_bits_count(gb) + 8 * length; > + const uint8_t *addr = align_get_bits(gb); > + int offs = dec->slice_vals_size * (3 * jobnr + i); > + uint8_t *tmp_buf = (uint8_t *)&quant_val_base[offs]; > + > + if (length * 8 > get_bits_left(gb)) { > + av_log(s->avctx, AV_LOG_ERROR, "end too far away\n"); > + return AVERROR_INVALIDDATA; > + } > + > + coef_num = subband_coeffs(s, slice->slice_x, slice->slice_y, i, offs, > + &slice_vk[MAX_DWT_LEVELS * i]); > + > + coef_par = ff_dirac_golomb_read_32bit(addr, length, tmp_buf, coef_num); > + > + if (coef_num > coef_par) { > + const int start_b = coef_par * sizeof(int32_t); > + const int end_b = coef_num * sizeof(int32_t); > + memset(&tmp_buf[start_b], 0, end_b - start_b); > + } > + > + skip_bits_long(gb, bits_end - get_bits_count(gb)); > + } > + > + return 0; > +} > + > +static int decode_hq_slice_row(AVCodecContext *avctx, void *arg, int jobnr, > + int threadnr) { > + const DiracContext *s = avctx->priv_data; > + int i, jobn = s->num_x * jobnr; > + > + for (i = 0; i < s->num_x; i++) { > + decode_hq_slice(s, jobn); > + jobn++; > + } > + > + return 0; > +} > + > +static int vulkan_dirac_decode_slice(AVCodecContext *avctx, const uint8_t *data, > + uint32_t size) { > + DiracContext *s = avctx->priv_data; > + > + /*avctx->execute2(avctx, decode_hq_slice_row, NULL, NULL, s->num_y);*/ > + for (int i = 0; i < s->num_y; i++) { > + decode_hq_slice_row(avctx, NULL, i, 0); > + } > + > + return 0; > +} > + > +const FFHWAccel ff_dirac_vulkan_hwaccel = { > + .p.name = "dirac_vulkan", > + .p.type = AVMEDIA_TYPE_VIDEO, > + .p.id = AV_CODEC_ID_DIRAC, > + .p.pix_fmt = AV_PIX_FMT_VULKAN, > + .start_frame = &vulkan_dirac_start_frame, > + .end_frame = &vulkan_dirac_end_frame, > + .decode_slice = &vulkan_dirac_decode_slice, > + .free_frame_priv = &vulkan_dirac_free_frame_priv, > + .uninit = &vulkan_dirac_uninit, > + .init = &vulkan_dirac_init, > + .frame_params = &vulkan_dirac_frame_params, > + .frame_priv_data_size = sizeof(DiracVulkanDecodePicture), > + .decode_params = &ff_vk_params_invalidate, > + .flush = &ff_vk_decode_flush, > + .update_thread_context = &vulkan_dirac_update_thread_context, > + .priv_data_size = sizeof(DiracVulkanDecodeContext), > + // .caps_internal = HWACCEL_CAP_ASYNC_SAFE | HWACCEL_CAP_THREAD_SAFE, > + .caps_internal = FF_CODEC_CAP_NOT_INIT_THREADSAFE, > +}; > -- > 2.46.0 > > _______________________________________________ > ffmpeg-devel mailing list > ffmpeg-devel@ffmpeg.org > https://ffmpeg.org/mailman/listinfo/ffmpeg-devel > > To unsubscribe, visit link above, or email > ffmpeg-devel-request@ffmpeg.org with subject "unsubscribe".
diff --git a/configure b/configure index d3bd46f382..fd7e4ab6d8 100755 --- a/configure +++ b/configure @@ -3172,6 +3172,8 @@ av1_vdpau_hwaccel_deps="vdpau VdpPictureInfoAV1" av1_vdpau_hwaccel_select="av1_decoder" av1_vulkan_hwaccel_deps="vulkan" av1_vulkan_hwaccel_select="av1_decoder" +dirac_vulkan_hwaccel_deps="vulkan spirv_compiler" +dirac_vulkan_hwaccel_select="dirac_decoder" h263_vaapi_hwaccel_deps="vaapi" h263_vaapi_hwaccel_select="h263_decoder" h263_videotoolbox_hwaccel_deps="videotoolbox" diff --git a/libavcodec/Makefile b/libavcodec/Makefile index b6243bbc82..90548ea2d5 100644 --- a/libavcodec/Makefile +++ b/libavcodec/Makefile @@ -1006,6 +1006,7 @@ OBJS-$(CONFIG_AV1_NVDEC_HWACCEL) += nvdec_av1.o OBJS-$(CONFIG_AV1_VAAPI_HWACCEL) += vaapi_av1.o OBJS-$(CONFIG_AV1_VDPAU_HWACCEL) += vdpau_av1.o OBJS-$(CONFIG_AV1_VULKAN_HWACCEL) += vulkan_decode.o vulkan_av1.o +OBJS-$(CONFIG_DIRAC_VULKAN_HWACCEL) += vulkan_dirac.o OBJS-$(CONFIG_H263_VAAPI_HWACCEL) += vaapi_mpeg4.o OBJS-$(CONFIG_H263_VIDEOTOOLBOX_HWACCEL) += videotoolbox.o OBJS-$(CONFIG_H264_D3D11VA_HWACCEL) += dxva2_h264.o diff --git a/libavcodec/diracdec.c b/libavcodec/diracdec.c index 76209aebba..542824f6e1 100644 --- a/libavcodec/diracdec.c +++ b/libavcodec/diracdec.c @@ -26,228 +26,11 @@ * @author Marco Gerards <marco@gnu.org>, David Conrad, Jordi Ortiz <nenjordi@gmail.com> */ -#include "libavutil/mem.h" -#include "libavutil/mem_internal.h" -#include "libavutil/pixdesc.h" -#include "libavutil/thread.h" -#include "avcodec.h" -#include "get_bits.h" -#include "codec_internal.h" -#include "decode.h" -#include "golomb.h" -#include "dirac_arith.h" -#include "dirac_vlc.h" -#include "mpegvideoencdsp.h" -#include "dirac_dwt.h" -#include "dirac.h" -#include "diractab.h" -#include "diracdsp.h" -#include "videodsp.h" - -#define EDGE_WIDTH 16 - -/** - * The spec limits this to 3 for frame coding, but in practice can be as high as 6 - */ -#define MAX_REFERENCE_FRAMES 8 -#define MAX_DELAY 5 /* limit for main profile for frame coding (TODO: field coding) */ -#define MAX_FRAMES (MAX_REFERENCE_FRAMES + MAX_DELAY + 1) -#define MAX_QUANT 255 /* max quant for VC-2 */ -#define MAX_BLOCKSIZE 32 /* maximum xblen/yblen we support */ - -/** - * DiracBlock->ref flags, if set then the block does MC from the given ref - */ -#define DIRAC_REF_MASK_REF1 1 -#define DIRAC_REF_MASK_REF2 2 -#define DIRAC_REF_MASK_GLOBAL 4 - -/** - * Value of Picture.reference when Picture is not a reference picture, but - * is held for delayed output. - */ -#define DELAYED_PIC_REF 4 - -#define CALC_PADDING(size, depth) \ - (((size + (1 << depth) - 1) >> depth) << depth) - -#define DIVRNDUP(a, b) (((a) + (b) - 1) / (b)) - -typedef struct { - AVFrame *avframe; - int interpolated[3]; /* 1 if hpel[] is valid */ - uint8_t *hpel[3][4]; - uint8_t *hpel_base[3][4]; - int reference; - unsigned picture_number; -} DiracFrame; - -typedef struct { - union { - int16_t mv[2][2]; - int16_t dc[3]; - } u; /* anonymous unions aren't in C99 :( */ - uint8_t ref; -} DiracBlock; - -typedef struct SubBand { - int level; - int orientation; - int stride; /* in bytes */ - int width; - int height; - int pshift; - int quant; - uint8_t *ibuf; - struct SubBand *parent; - - /* for low delay */ - unsigned length; - const uint8_t *coeff_data; -} SubBand; - -typedef struct Plane { - DWTPlane idwt; - - int width; - int height; - ptrdiff_t stride; - - /* block length */ - uint8_t xblen; - uint8_t yblen; - /* block separation (block n+1 starts after this many pixels in block n) */ - uint8_t xbsep; - uint8_t ybsep; - /* amount of overspill on each edge (half of the overlap between blocks) */ - uint8_t xoffset; - uint8_t yoffset; - - SubBand band[MAX_DWT_LEVELS][4]; -} Plane; - -/* Used by Low Delay and High Quality profiles */ -typedef struct DiracSlice { - GetBitContext gb; - int slice_x; - int slice_y; - int bytes; -} DiracSlice; - -typedef struct DiracContext { - AVCodecContext *avctx; - MpegvideoEncDSPContext mpvencdsp; - VideoDSPContext vdsp; - DiracDSPContext diracdsp; - DiracVersionInfo version; - GetBitContext gb; - AVDiracSeqHeader seq; - int seen_sequence_header; - int64_t frame_number; /* number of the next frame to display */ - Plane plane[3]; - int chroma_x_shift; - int chroma_y_shift; - - int bit_depth; /* bit depth */ - int pshift; /* pixel shift = bit_depth > 8 */ - - int zero_res; /* zero residue flag */ - int is_arith; /* whether coeffs use arith or golomb coding */ - int core_syntax; /* use core syntax only */ - int low_delay; /* use the low delay syntax */ - int hq_picture; /* high quality picture, enables low_delay */ - int ld_picture; /* use low delay picture, turns on low_delay */ - int dc_prediction; /* has dc prediction */ - int globalmc_flag; /* use global motion compensation */ - int num_refs; /* number of reference pictures */ - - /* wavelet decoding */ - unsigned wavelet_depth; /* depth of the IDWT */ - unsigned wavelet_idx; - - /** - * schroedinger older than 1.0.8 doesn't store - * quant delta if only one codebook exists in a band - */ - unsigned old_delta_quant; - unsigned codeblock_mode; - - unsigned num_x; /* number of horizontal slices */ - unsigned num_y; /* number of vertical slices */ - - uint8_t *thread_buf; /* Per-thread buffer for coefficient storage */ - int threads_num_buf; /* Current # of buffers allocated */ - int thread_buf_size; /* Each thread has a buffer this size */ - - DiracSlice *slice_params_buf; - int slice_params_num_buf; - - struct { - unsigned width; - unsigned height; - } codeblock[MAX_DWT_LEVELS+1]; - - struct { - AVRational bytes; /* average bytes per slice */ - uint8_t quant[MAX_DWT_LEVELS][4]; /* [DIRAC_STD] E.1 */ - } lowdelay; - - struct { - unsigned prefix_bytes; - uint64_t size_scaler; - } highquality; - - struct { - int pan_tilt[2]; /* pan/tilt vector */ - int zrs[2][2]; /* zoom/rotate/shear matrix */ - int perspective[2]; /* perspective vector */ - unsigned zrs_exp; - unsigned perspective_exp; - } globalmc[2]; - - /* motion compensation */ - uint8_t mv_precision; /* [DIRAC_STD] REFS_WT_PRECISION */ - int16_t weight[2]; /* [DIRAC_STD] REF1_WT and REF2_WT */ - unsigned weight_log2denom; /* [DIRAC_STD] REFS_WT_PRECISION */ - - int blwidth; /* number of blocks (horizontally) */ - int blheight; /* number of blocks (vertically) */ - int sbwidth; /* number of superblocks (horizontally) */ - int sbheight; /* number of superblocks (vertically) */ - - uint8_t *sbsplit; - DiracBlock *blmotion; - - uint8_t *edge_emu_buffer[4]; - uint8_t *edge_emu_buffer_base; - - uint16_t *mctmp; /* buffer holding the MC data multiplied by OBMC weights */ - uint8_t *mcscratch; - int buffer_stride; - - DECLARE_ALIGNED(16, uint8_t, obmc_weight)[3][MAX_BLOCKSIZE*MAX_BLOCKSIZE]; - - void (*put_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); - void (*avg_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); - void (*add_obmc)(uint16_t *dst, const uint8_t *src, int stride, const uint8_t *obmc_weight, int yblen); - dirac_weight_func weight_func; - dirac_biweight_func biweight_func; - - DiracFrame *current_picture; - DiracFrame *ref_pics[2]; - - DiracFrame *ref_frames[MAX_REFERENCE_FRAMES+1]; - DiracFrame *delay_frames[MAX_DELAY+1]; - DiracFrame all_frames[MAX_FRAMES]; -} DiracContext; - -enum dirac_subband { - subband_ll = 0, - subband_hl = 1, - subband_lh = 2, - subband_hh = 3, - subband_nb, -}; +#include "diracdec.h" +#include "hwaccels.h" +#include "hwconfig.h" +#include "libavutil/imgutils.c" +#include "config_components.h" /* magic number division by 3 from schroedinger */ static inline int divide3(int x) @@ -351,7 +134,7 @@ static int alloc_buffers(DiracContext *s, int stride) return 0; } -static av_cold void free_sequence_buffers(DiracContext *s) +static void free_sequence_buffers(DiracContext *s) { int i, j, k; @@ -403,8 +186,11 @@ static av_cold int dirac_decode_init(AVCodecContext *avctx) for (i = 0; i < MAX_FRAMES; i++) { s->all_frames[i].avframe = av_frame_alloc(); - if (!s->all_frames[i].avframe) + if (!s->all_frames[i].avframe) { + while (i > 0) + av_frame_free(&s->all_frames[--i].avframe); return AVERROR(ENOMEM); + } } ret = ff_thread_once(&dirac_arith_init, ff_dirac_init_arith_tables); if (ret != 0) @@ -413,7 +199,7 @@ static av_cold int dirac_decode_init(AVCodecContext *avctx) return 0; } -static av_cold void dirac_decode_flush(AVCodecContext *avctx) +static void dirac_decode_flush(AVCodecContext *avctx) { DiracContext *s = avctx->priv_data; free_sequence_buffers(s); @@ -426,9 +212,7 @@ static av_cold int dirac_decode_end(AVCodecContext *avctx) DiracContext *s = avctx->priv_data; int i; - // Necessary in case dirac_decode_init() failed - if (s->all_frames[MAX_FRAMES - 1].avframe) - free_sequence_buffers(s); + dirac_decode_flush(avctx); for (i = 0; i < MAX_FRAMES; i++) av_frame_free(&s->all_frames[i].avframe); @@ -812,14 +596,6 @@ static int decode_lowdelay_slice(AVCodecContext *avctx, void *arg) return 0; } -typedef struct SliceCoeffs { - int left; - int top; - int tot_h; - int tot_v; - int tot; -} SliceCoeffs; - static int subband_coeffs(const DiracContext *s, int x, int y, int p, SliceCoeffs c[MAX_DWT_LEVELS]) { @@ -1006,7 +782,10 @@ static int decode_lowdelay(DiracContext *s) return AVERROR_INVALIDDATA; } - avctx->execute2(avctx, decode_hq_slice_row, slices, NULL, s->num_y); + if (avctx->hwaccel) + FF_HW_CALL(avctx, decode_slice, NULL, 0); + else + avctx->execute2(avctx, decode_hq_slice_row, slices, NULL, s->num_y); } else { for (slice_y = 0; bufsize > 0 && slice_y < s->num_y; slice_y++) { for (slice_x = 0; bufsize > 0 && slice_x < s->num_x; slice_x++) { @@ -1873,7 +1652,13 @@ static int dirac_decode_frame_internal(DiracContext *s) { DWTContext d; int y, i, comp, dsty; - int ret; + int ret = -1; + + if (s->avctx->hwaccel) { + ret = FF_HW_CALL(s->avctx, start_frame, NULL, 0); + if (ret < 0) + return ret; + } if (s->low_delay) { /* [DIRAC_STD] 13.5.1 low_delay_transform_data() */ @@ -1889,6 +1674,14 @@ static int dirac_decode_frame_internal(DiracContext *s) } } + if (s->avctx->hwaccel) { + ret = ffhwaccel(s->avctx->hwaccel)->end_frame(s->avctx); + if (ret == 0) { + /* Hwaccel failed - fall back on software decoder */ + } + return ret; + } + for (comp = 0; comp < 3; comp++) { Plane *p = &s->plane[comp]; uint8_t *frame = s->current_picture->avframe->data[comp]; @@ -1904,6 +1697,7 @@ static int dirac_decode_frame_internal(DiracContext *s) if (ret < 0) return ret; } + ret = ff_spatial_idwt_init(&d, &p->idwt, s->wavelet_idx+2, s->wavelet_depth, s->bit_depth); if (ret < 0) @@ -1970,15 +1764,23 @@ static int get_buffer_with_edge(AVCodecContext *avctx, AVFrame *f, int flags) { int ret, i; int chroma_x_shift, chroma_y_shift; - ret = av_pix_fmt_get_chroma_sub_sample(avctx->pix_fmt, &chroma_x_shift, + DiracContext *s = avctx->priv_data; + ret = av_pix_fmt_get_chroma_sub_sample(s->sof_pix_fmt, &chroma_x_shift, &chroma_y_shift); if (ret < 0) return ret; + /*if (avctx->hwaccel) {*/ + /* f->width = s->plane[0].width;*/ + /* f->height = s->plane[0].height;*/ + /* ret = ff_get_buffer(avctx, f, flags);*/ + /* return ret;*/ + /*}*/ + f->width = avctx->width + 2 * EDGE_WIDTH; f->height = avctx->height + 2 * EDGE_WIDTH + 2; ret = ff_get_buffer(avctx, f, flags); - if (ret < 0) + if (ret < 0 || avctx->hwaccel) return ret; for (i = 0; f->data[i]; i++) { @@ -2136,6 +1938,7 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int init_get_bits(&s->gb, &buf[13], 8*(size - DATA_UNIT_HEADER_SIZE)); if (parse_code == DIRAC_PCODE_SEQ_HEADER) { + enum AVPixelFormat *pix_fmts; if (s->seen_sequence_header) return 0; @@ -2156,6 +1959,7 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int } ff_set_sar(avctx, dsh->sample_aspect_ratio); + s->sof_pix_fmt = dsh->pix_fmt; avctx->pix_fmt = dsh->pix_fmt; avctx->color_range = dsh->color_range; avctx->color_trc = dsh->color_trc; @@ -2172,7 +1976,20 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int s->pshift = s->bit_depth > 8; - ret = av_pix_fmt_get_chroma_sub_sample(avctx->pix_fmt, + /*if (s->pshift) {*/ + /* avctx->pix_fmt = s->sof_pix_fmt;*/ + /*} else {*/ + pix_fmts = (enum AVPixelFormat[]){ +#if CONFIG_DIRAC_VULKAN_HWACCEL + AV_PIX_FMT_VULKAN, +#endif + s->sof_pix_fmt, + AV_PIX_FMT_NONE, + }; + avctx->pix_fmt = ff_get_format(s->avctx, pix_fmts); + /*}*/ + + ret = av_pix_fmt_get_chroma_sub_sample(s->sof_pix_fmt, &s->chroma_x_shift, &s->chroma_y_shift); if (ret < 0) @@ -2202,9 +2019,10 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int } /* find an unused frame */ - for (i = 0; i < MAX_FRAMES; i++) + for (i = 0; i < MAX_FRAMES; i++) if (s->all_frames[i].avframe->data[0] == NULL) pic = &s->all_frames[i]; + if (!pic) { av_log(avctx, AV_LOG_ERROR, "framelist full\n"); return AVERROR_INVALIDDATA; @@ -2244,12 +2062,28 @@ static int dirac_decode_data_unit(AVCodecContext *avctx, const uint8_t *buf, int if ((ret = get_buffer_with_edge(avctx, pic->avframe, (parse_code & 0x0C) == 0x0C ? AV_GET_BUFFER_FLAG_REF : 0)) < 0) return ret; s->current_picture = pic; - s->plane[0].stride = pic->avframe->linesize[0]; - s->plane[1].stride = pic->avframe->linesize[1]; - s->plane[2].stride = pic->avframe->linesize[2]; - if (alloc_buffers(s, FFMAX3(FFABS(s->plane[0].stride), FFABS(s->plane[1].stride), FFABS(s->plane[2].stride))) < 0) - return AVERROR(ENOMEM); + if (s->avctx->hwaccel) { + if (!(s->low_delay && s->hq_picture)) { + av_log(avctx, AV_LOG_ERROR, "The HWaccel only supports VC-2\n"); + return AVERROR_INVALIDDATA; + } + + if (!s->hwaccel_picture_private) { + const FFHWAccel *hwaccel = ffhwaccel(s->avctx->hwaccel); + s->hwaccel_picture_private = + av_mallocz(hwaccel->frame_priv_data_size); + if (!s->hwaccel_picture_private) + return AVERROR(ENOMEM); + } + } else { + s->plane[0].stride = pic->avframe->linesize[0]; + s->plane[1].stride = pic->avframe->linesize[1]; + s->plane[2].stride = pic->avframe->linesize[2]; + + if (alloc_buffers(s, FFMAX3(FFABS(s->plane[0].stride), FFABS(s->plane[1].stride), FFABS(s->plane[2].stride))) < 0) + return AVERROR(ENOMEM); + } /* [DIRAC_STD] 11.1 Picture parse. picture_parse() */ ret = dirac_decode_picture_header(s); @@ -2359,6 +2193,7 @@ static int dirac_decode_frame(AVCodecContext *avctx, AVFrame *picture, return buf_idx; } + const FFCodec ff_dirac_decoder = { .p.name = "dirac", CODEC_LONG_NAME("BBC Dirac VC-2"), @@ -2370,5 +2205,10 @@ const FFCodec ff_dirac_decoder = { FF_CODEC_DECODE_CB(dirac_decode_frame), .p.capabilities = AV_CODEC_CAP_DELAY | AV_CODEC_CAP_SLICE_THREADS | AV_CODEC_CAP_DR1, .flush = dirac_decode_flush, - .caps_internal = FF_CODEC_CAP_INIT_CLEANUP, + .hw_configs = (const AVCodecHWConfigInternal *const []) { +#if CONFIG_DIRAC_VULKAN_HWACCEL + HWACCEL_VULKAN(dirac), +#endif + NULL + }, }; diff --git a/libavcodec/diracdec.h b/libavcodec/diracdec.h new file mode 100644 index 0000000000..4ca07342ac --- /dev/null +++ b/libavcodec/diracdec.h @@ -0,0 +1,263 @@ +/* + * This file is part of FFmpeg. + * + * FFmpeg is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * FFmpeg is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with FFmpeg; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +/** + * @file + * Dirac Decoder Header + * @author Marco Gerards <marco@gnu.org>, David Conrad, Jordi Ortiz <nenjordi@gmail.com> + */ + + +#ifndef AVCODEC_DIRACDEC_H +#define AVCODEC_DIRACDEC_H + +#include "libavutil/mem.h" +#include "libavutil/mem_internal.h" +#include "libavutil/pixdesc.h" +#include "libavutil/thread.h" +#include "avcodec.h" +#include "get_bits.h" +#include "codec_internal.h" +#include "decode.h" +#include "golomb.h" +#include "dirac_arith.h" +#include "dirac_vlc.h" +#include "mpegvideoencdsp.h" +#include "dirac_dwt.h" +#include "dirac.h" +#include "diractab.h" +#include "diracdsp.h" +#include "videodsp.h" +#include "hwaccel_internal.h" + +#define EDGE_WIDTH 16 + +/** + * The spec limits this to 3 for frame coding, but in practice can be as high as 6 + */ +#define MAX_REFERENCE_FRAMES 8 +#define MAX_DELAY 5 /* limit for main profile for frame coding (TODO: field coding) */ +#define MAX_FRAMES (MAX_REFERENCE_FRAMES + MAX_DELAY + 1) +#define MAX_QUANT 255 /* max quant for VC-2 */ +#define MAX_BLOCKSIZE 32 /* maximum xblen/yblen we support */ + +/** + * DiracBlock->ref flags, if set then the block does MC from the given ref + */ +#define DIRAC_REF_MASK_REF1 1 +#define DIRAC_REF_MASK_REF2 2 +#define DIRAC_REF_MASK_GLOBAL 4 + +/** + * Value of Picture.reference when Picture is not a reference picture, but + * is held for delayed output. + */ +#define DELAYED_PIC_REF 4 + +#define CALC_PADDING(size, depth) \ + (((size + (1 << depth) - 1) >> depth) << depth) + +#define DIVRNDUP(a, b) (((a) + (b) - 1) / (b)) + +typedef struct { + AVFrame *avframe; + int interpolated[3]; /* 1 if hpel[] is valid */ + uint8_t *hpel[3][4]; + uint8_t *hpel_base[3][4]; + int reference; + unsigned picture_number; +} DiracFrame; + +typedef struct { + union { + int16_t mv[2][2]; + int16_t dc[3]; + } u; /* anonymous unions aren't in C99 :( */ + uint8_t ref; +} DiracBlock; + +typedef struct SubBand { + int level; + int orientation; + int stride; /* in bytes */ + int width; + int height; + int pshift; + int quant; + uint8_t *ibuf; + struct SubBand *parent; + + /* for low delay */ + unsigned length; + const uint8_t *coeff_data; +} SubBand; + +typedef struct Plane { + DWTPlane idwt; + + int width; + int height; + ptrdiff_t stride; + + /* block length */ + uint8_t xblen; + uint8_t yblen; + /* block separation (block n+1 starts after this many pixels in block n) */ + uint8_t xbsep; + uint8_t ybsep; + /* amount of overspill on each edge (half of the overlap between blocks) */ + uint8_t xoffset; + uint8_t yoffset; + + SubBand band[MAX_DWT_LEVELS][4]; +} Plane; + +/* Used by Low Delay and High Quality profiles */ +typedef struct DiracSlice { + GetBitContext gb; + int slice_x; + int slice_y; + int bytes; +} DiracSlice; + +typedef struct DiracContext { + AVCodecContext *avctx; + MpegvideoEncDSPContext mpvencdsp; + VideoDSPContext vdsp; + DiracDSPContext diracdsp; + DiracVersionInfo version; + GetBitContext gb; + AVDiracSeqHeader seq; + enum AVPixelFormat sof_pix_fmt; + void *hwaccel_picture_private; + int seen_sequence_header; + int64_t frame_number; /* number of the next frame to display */ + Plane plane[3]; + int chroma_x_shift; + int chroma_y_shift; + + int bit_depth; /* bit depth */ + int pshift; /* pixel shift = bit_depth > 8 */ + + int zero_res; /* zero residue flag */ + int is_arith; /* whether coeffs use arith or golomb coding */ + int core_syntax; /* use core syntax only */ + int low_delay; /* use the low delay syntax */ + int hq_picture; /* high quality picture, enables low_delay */ + int ld_picture; /* use low delay picture, turns on low_delay */ + int dc_prediction; /* has dc prediction */ + int globalmc_flag; /* use global motion compensation */ + int num_refs; /* number of reference pictures */ + + /* wavelet decoding */ + unsigned wavelet_depth; /* depth of the IDWT */ + unsigned wavelet_idx; + + /** + * schroedinger older than 1.0.8 doesn't store + * quant delta if only one codebook exists in a band + */ + unsigned old_delta_quant; + unsigned codeblock_mode; + + unsigned num_x; /* number of horizontal slices */ + unsigned num_y; /* number of vertical slices */ + + uint8_t *thread_buf; /* Per-thread buffer for coefficient storage */ + int threads_num_buf; /* Current # of buffers allocated */ + int thread_buf_size; /* Each thread has a buffer this size */ + + DiracSlice *slice_params_buf; + int slice_params_num_buf; + + struct { + unsigned width; + unsigned height; + } codeblock[MAX_DWT_LEVELS+1]; + + struct { + AVRational bytes; /* average bytes per slice */ + uint8_t quant[MAX_DWT_LEVELS][4]; /* [DIRAC_STD] E.1 */ + } lowdelay; + + struct { + unsigned prefix_bytes; + uint64_t size_scaler; + } highquality; + + struct { + int pan_tilt[2]; /* pan/tilt vector */ + int zrs[2][2]; /* zoom/rotate/shear matrix */ + int perspective[2]; /* perspective vector */ + unsigned zrs_exp; + unsigned perspective_exp; + } globalmc[2]; + + /* motion compensation */ + uint8_t mv_precision; /* [DIRAC_STD] REFS_WT_PRECISION */ + int16_t weight[2]; /* [DIRAC_STD] REF1_WT and REF2_WT */ + unsigned weight_log2denom; /* [DIRAC_STD] REFS_WT_PRECISION */ + + int blwidth; /* number of blocks (horizontally) */ + int blheight; /* number of blocks (vertically) */ + int sbwidth; /* number of superblocks (horizontally) */ + int sbheight; /* number of superblocks (vertically) */ + + uint8_t *sbsplit; + DiracBlock *blmotion; + + uint8_t *edge_emu_buffer[4]; + uint8_t *edge_emu_buffer_base; + + uint16_t *mctmp; /* buffer holding the MC data multiplied by OBMC weights */ + uint8_t *mcscratch; + int buffer_stride; + + DECLARE_ALIGNED(16, uint8_t, obmc_weight)[3][MAX_BLOCKSIZE*MAX_BLOCKSIZE]; + + void (*put_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); + void (*avg_pixels_tab[4])(uint8_t *dst, const uint8_t *src[5], int stride, int h); + void (*add_obmc)(uint16_t *dst, const uint8_t *src, int stride, const uint8_t *obmc_weight, int yblen); + dirac_weight_func weight_func; + dirac_biweight_func biweight_func; + + DiracFrame *current_picture; + DiracFrame *ref_pics[2]; + + DiracFrame *ref_frames[MAX_REFERENCE_FRAMES+1]; + DiracFrame *delay_frames[MAX_DELAY+1]; + DiracFrame all_frames[MAX_FRAMES]; +} DiracContext; + +enum dirac_subband { + subband_ll = 0, + subband_hl = 1, + subband_lh = 2, + subband_hh = 3, + subband_nb, +}; + +typedef struct SliceCoeffs { + int left; + int top; + int tot_h; + int tot_v; + int tot; +} SliceCoeffs; + +#endif diff --git a/libavcodec/hwaccels.h b/libavcodec/hwaccels.h index 5171e4c7d7..f6d148b169 100644 --- a/libavcodec/hwaccels.h +++ b/libavcodec/hwaccels.h @@ -27,6 +27,7 @@ extern const struct FFHWAccel ff_av1_nvdec_hwaccel; extern const struct FFHWAccel ff_av1_vaapi_hwaccel; extern const struct FFHWAccel ff_av1_vdpau_hwaccel; extern const struct FFHWAccel ff_av1_vulkan_hwaccel; +extern const struct FFHWAccel ff_dirac_vulkan_hwaccel; extern const struct FFHWAccel ff_h263_vaapi_hwaccel; extern const struct FFHWAccel ff_h263_videotoolbox_hwaccel; extern const struct FFHWAccel ff_h264_d3d11va_hwaccel; diff --git a/libavcodec/vulkan_dirac.c b/libavcodec/vulkan_dirac.c new file mode 100644 index 0000000000..7f30e4f0fe --- /dev/null +++ b/libavcodec/vulkan_dirac.c @@ -0,0 +1,3817 @@ +/* + * This file is part of FFmpeg. + * + * FFmpeg is free software; you can redistribute it and/or + * modify it under the terms of the GNU Lesser General Public + * License as published by the Free Software Foundation; either + * version 2.1 of the License, or (at your option) any later version. + * + * FFmpeg is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU + * Lesser General Public License for more details. + * + * You should have received a copy of the GNU Lesser General Public + * License along with FFmpeg; if not, write to the Free Software + * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA + */ + +#include "diracdec.h" +#include "libavcodec/dirac_vlc.h" +#include "libavcodec/pthread_internal.h" +#include "libavfilter/vulkan_spirv.h" +#include "libavutil/vulkan_loader.h" +#include "vulkan.h" +#include "vulkan_decode.h" + +typedef struct SubbandOffset { + int base_off; + int stride; + int pad0; + int pad1; +} SubbandOffset; + +typedef struct SliceCoeffVk { + int left; + int top; + int tot_h; + int tot_v; + int tot; + int offs; + int pad0; + int pad1; +} SliceCoeffVk; + +typedef struct WaveletPushConst { + int real_plane_dims[6]; + int plane_offs[3]; + int plane_strides[3]; + int dw[3]; + int wavelet_depth; +} WaveletPushConst; + +typedef struct DiracVulkanDecodeContext { + FFVulkanContext vkctx; + VkSamplerYcbcrConversion yuv_sampler; + VkSampler sampler; + + FFVulkanPipeline vert_wavelet_pl[9]; + FFVkSPIRVShader vert_wavelet_shd[9]; + + FFVulkanPipeline horiz_wavelet_pl[9]; + FFVkSPIRVShader horiz_wavelet_shd[9]; + + FFVulkanPipeline cpy_to_image_pl[3]; + FFVkSPIRVShader cpy_to_image_shd[3]; + + FFVulkanPipeline quant_pl; + FFVkSPIRVShader quant_shd; + + FFVkQueueFamilyCtx qf; + FFVkExecPool exec_pool; + + int quant_val_buf_size; + int thread_buf_size; + int32_t *quant_val_buf_vk_ptr; + FFVkBuffer *quant_val_buf; + AVBufferRef *av_quant_val_buf; + size_t quant_val_buf_offs; + + int n_slice_bufs; + int slice_buf_size; + SliceCoeffVk *slice_buf_vk_ptr; + FFVkBuffer *quant_buf; + AVBufferRef *av_quant_buf; + size_t quant_buf_offs; + + int32_t *quant_buf_vk_ptr; + int quant_buf_size; + FFVkBuffer *slice_buf; + AVBufferRef *av_slice_buf; + size_t slice_buf_offs; + + FFVkBuffer tmp_buf; + FFVkBuffer tmp_interleave_buf; + + FFVkBuffer subband_info; + SubbandOffset *subband_info_ptr; + + int slice_vals_size; + + WaveletPushConst pConst; +} DiracVulkanDecodeContext; + +typedef struct DiracVulkanDecodePicture { + DiracFrame *frame; +} DiracVulkanDecodePicture; + +static void free_common(AVCodecContext *avctx) { + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; + DiracContext *ctx = avctx->priv_data; + FFVulkanContext *s = &dec->vkctx; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + if (ctx->hwaccel_picture_private) { + av_free(ctx->hwaccel_picture_private); + } + + /* Wait on and free execution pool */ + if (dec->exec_pool.cmd_bufs) { + ff_vk_exec_pool_free(s, &dec->exec_pool); + } + + ff_vk_pipeline_free(s, &dec->quant_pl); + ff_vk_shader_free(s, &dec->quant_shd); + + for (int i = 0; i < 3; i++) { + ff_vk_pipeline_free(s, &dec->cpy_to_image_pl[i]); + ff_vk_shader_free(s, &dec->cpy_to_image_shd[i]); + } + + for (int i = 0; i < 9; i++) { + ff_vk_pipeline_free(s, &dec->vert_wavelet_pl[i]); + ff_vk_shader_free(s, &dec->vert_wavelet_shd[i]); + + ff_vk_pipeline_free(s, &dec->horiz_wavelet_pl[i]); + ff_vk_shader_free(s, &dec->horiz_wavelet_shd[i]); + } + // TODO: Add freeing all pipelines and shaders for wavelets + // + + // if (dec->yuv_sampler) + // vk->DestroySamplerYcbcrConversion(s->hwctx->act_dev, + // dec->yuv_sampler, + // s->hwctx->alloc); + if (dec->sampler) + vk->DestroySampler(s->hwctx->act_dev, dec->sampler, s->hwctx->alloc); + + av_buffer_unref(&dec->av_quant_val_buf); + av_buffer_unref(&dec->av_quant_buf); + av_buffer_unref(&dec->av_slice_buf); + av_buffer_unref(&dec->av_slice_buf); + + ff_vk_free_buf(&dec->vkctx, &dec->subband_info); + + ff_vk_free_buf(&dec->vkctx, &dec->tmp_buf); + ff_vk_free_buf(&dec->vkctx, &dec->tmp_interleave_buf); + + ff_vk_uninit(s); +} + +static av_always_inline inline void bar_read(VkBufferMemoryBarrier2 *buf_bar, + int *nb_buf_bar, FFVkBuffer *buf) { + buf_bar[(*nb_buf_bar)++] = (VkBufferMemoryBarrier2){ + .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2, + .srcStageMask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, + .dstStageMask = VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, + .srcAccessMask = VK_ACCESS_2_MEMORY_READ_BIT, + .dstAccessMask = VK_ACCESS_2_MEMORY_READ_BIT, + .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, + .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, + .buffer = buf->buf, + .size = buf->size, + .offset = 0, + }; +} + +static av_always_inline inline void +bar_write(VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar, FFVkBuffer *buf) { + buf_bar[(*nb_buf_bar)++] = (VkBufferMemoryBarrier2){ + .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2, + .srcStageMask = VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, + .dstStageMask = VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, + .srcAccessMask = VK_ACCESS_2_MEMORY_READ_BIT, + .dstAccessMask = VK_ACCESS_2_MEMORY_WRITE_BIT, + .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, + .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED, + .buffer = buf->buf, + .size = buf->size, + .offset = 0, + }; +} + +static inline int alloc_tmp_bufs(DiracContext *ctx, + DiracVulkanDecodeContext *dec) { + int err, plane_size; + + plane_size = sizeof(int32_t) * + (ctx->plane[0].idwt.width * ctx->plane[0].idwt.height + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height + + ctx->plane[2].idwt.width * ctx->plane[2].idwt.height); + + if (dec->tmp_buf.buf != NULL) { + ff_vk_free_buf(&dec->vkctx, &dec->tmp_buf); + ff_vk_free_buf(&dec->vkctx, &dec->tmp_interleave_buf); + } + + err = ff_vk_create_buf(&dec->vkctx, &dec->tmp_buf, plane_size, NULL, NULL, + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); + if (err < 0) + return err; + + err = ff_vk_create_buf(&dec->vkctx, &dec->tmp_interleave_buf, plane_size, + NULL, NULL, + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); + if (err < 0) + return err; + + return 0; +} + +static inline int alloc_host_mapped_buf(DiracVulkanDecodeContext *dec, + size_t req_size, void **mem, + AVBufferRef **avbuf, FFVkBuffer **buf) { + // FFVulkanFunctions *vk = &dec->vkctx.vkfn; + // VkResult ret; + int err; + + err = ff_vk_create_avbuf(&dec->vkctx, avbuf, req_size, NULL, NULL, + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | + VK_BUFFER_USAGE_STORAGE_BUFFER_BIT, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT); + if (err < 0) + return err; + + *buf = (FFVkBuffer *)(*avbuf)->data; + err = ff_vk_map_buffer(&dec->vkctx, *buf, (uint8_t **)mem, 0); + if (err < 0) + return err; + + return 0; +} + +static int alloc_slices_buf(DiracContext *ctx, DiracVulkanDecodeContext *dec) { + int err, length = ctx->num_y * ctx->num_x; + + dec->n_slice_bufs = length; + + if (dec->slice_buf_vk_ptr) { + av_buffer_unref(&dec->av_slice_buf); + } + + dec->slice_buf_size = sizeof(SliceCoeffVk) * length * 3 * MAX_DWT_LEVELS; + err = alloc_host_mapped_buf(dec, dec->slice_buf_size, + (void **)&dec->slice_buf_vk_ptr, + &dec->av_slice_buf, &dec->slice_buf); + if (err < 0) + return err; + + return 0; +} + +static int alloc_dequant_buf(DiracContext *ctx, DiracVulkanDecodeContext *dec) { + int err, length = ctx->num_y * ctx->num_x; + + if (dec->quant_buf_vk_ptr) { + av_buffer_unref(&dec->av_quant_buf); + } + + dec->n_slice_bufs = length; + + dec->quant_buf_size = sizeof(int32_t) * MAX_DWT_LEVELS * 8 * length; + err = alloc_host_mapped_buf(dec, dec->quant_buf_size, + (void **)&dec->quant_buf_vk_ptr, + &dec->av_quant_buf, &dec->quant_buf); + if (err < 0) + return err; + + return 0; +} + +static int subband_coeffs(const DiracContext *s, int x, int y, int p, int off, + SliceCoeffVk *c) { + int level, coef = 0; + for (level = 0; level <= s->wavelet_depth; level++) { + SliceCoeffVk *o = &c[level]; + const SubBand *b = + &s->plane[p].band[level][3]; /* orientation doens't matter */ + o->top = b->height * y / s->num_y; + o->left = b->width * x / s->num_x; + o->tot_h = ((b->width * (x + 1)) / s->num_x) - o->left; + o->tot_v = ((b->height * (y + 1)) / s->num_y) - o->top; + o->tot = o->tot_h * o->tot_v; + o->offs = off + coef; + coef += o->tot * (4 - !!level); + } + return coef; +} + +static int alloc_quant_buf(DiracContext *ctx, DiracVulkanDecodeContext *dec) { + int err, length = ctx->num_y * ctx->num_x, coef_buf_size; + SliceCoeffVk tmp[MAX_DWT_LEVELS]; + coef_buf_size = + subband_coeffs(ctx, ctx->num_x - 1, ctx->num_y - 1, 0, 0, tmp) + 8; + coef_buf_size = coef_buf_size + 512; + dec->slice_vals_size = coef_buf_size / sizeof(int32_t); + // coef_buf_size *= sizeof(int32_t); + + if (dec->quant_val_buf_vk_ptr) { + av_buffer_unref(&dec->av_quant_val_buf); + } + + dec->thread_buf_size = coef_buf_size; + + dec->quant_val_buf_size = dec->thread_buf_size * 3 * length; + err = alloc_host_mapped_buf(dec, dec->quant_val_buf_size, + (void **)&dec->quant_val_buf_vk_ptr, + &dec->av_quant_val_buf, &dec->quant_val_buf); + if (err < 0) + return err; + + return 0; +} + +/* ----- Copy Shader init and pipeline pass ----- */ + +static int init_cpy_shd(DiracVulkanDecodeContext *s, FFVkSPIRVCompiler *spv, + int idx) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->cpy_to_image_shd[idx]; + FFVulkanPipeline *pl = &s->cpy_to_image_pl[idx]; + FFVkExecPool *exec = &s->exec_pool; + const int planes = av_pix_fmt_count_planes(s->vkctx.output_format); + + RET(ff_vk_shader_init(pl, shd, "cpy_to_image", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->cpy_to_image_shd[idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 1); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_img", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE, + .mem_quali = "writeonly", + // .mem_layout = ff_vk_shader_rep_fmt(vkctx->output_format), + .mem_layout = "rgba16f", + .dimensions = 2, + .elems = planes, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLC( + 0, void main() {); + GLSLC(1, int x = int(gl_GlobalInvocationID.x);); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC(1, int plane = int(gl_GlobalInvocationID.z);); + GLSLC(1, if (!IS_WITHIN(ivec2(x, y), + imageSize(out_img[plane]))) return;); + GLSLC(1, + int idx = plane_offs[plane] + y * plane_strides[plane] + x;); + if (idx == 2) { + GLSLC(1, int32_t ival = inBuf[idx] + 2048;); + GLSLC(1, float val = float(clamp(ival, 0, 4096)) / 65535.0;); + } else if (idx == 1) { + GLSLC(1, int32_t ival = inBuf[idx] + 512;); + GLSLC(1, float val = float(clamp(ival, 0, 1024)) / 65535.0;); + } else { + GLSLC(1, int32_t ival = inBuf[idx] + 128;); + GLSLC(1, float val = float(clamp(ival, 0, 256)) / 255.0;); + } + GLSLC(1, imageStore(out_img[plane], ivec2(x, y), vec4(val));); + GLSLC(1, memoryBarrier();); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline cpy_to_image_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkImageView *views, VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar, + VkImageMemoryBarrier2 *img_bar, int *nb_img_bar, int idx) { + int err, prev_nb_bar = *nb_buf_bar, prev_nb_img_bar = *nb_img_bar; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + DiracVulkanDecodePicture *pic = ctx->hwaccel_picture_private; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, &dec->cpy_to_image_pl[idx], + exec, 0, 0, 0, dec->tmp_buf.address, + dec->tmp_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + ff_vk_update_descriptor_img_array(&dec->vkctx, &dec->cpy_to_image_pl[idx], + exec, pic->frame->avframe, views, 0, 1, + VK_IMAGE_LAYOUT_GENERAL, dec->sampler); + + dec->pConst.real_plane_dims[0] = ctx->plane[0].idwt.width; + dec->pConst.real_plane_dims[1] = ctx->plane[0].idwt.height; + dec->pConst.real_plane_dims[2] = ctx->plane[1].idwt.width; + dec->pConst.real_plane_dims[3] = ctx->plane[1].idwt.height; + dec->pConst.real_plane_dims[4] = ctx->plane[2].idwt.width; + dec->pConst.real_plane_dims[5] = ctx->plane[2].idwt.height; + + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + ff_vk_update_push_exec(&dec->vkctx, exec, &dec->cpy_to_image_pl[idx], + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + + ff_vk_frame_barrier(&dec->vkctx, exec, pic->frame->avframe, img_bar, + nb_img_bar, VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT, + VK_ACCESS_SHADER_READ_BIT, VK_IMAGE_LAYOUT_GENERAL, + VK_QUEUE_FAMILY_IGNORED); + + vk->CmdPipelineBarrier2( + exec->buf, &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + prev_nb_bar, + .bufferMemoryBarrierCount = *nb_buf_bar - prev_nb_bar, + .pImageMemoryBarriers = img_bar + prev_nb_img_bar, + .imageMemoryBarrierCount = *nb_img_bar - prev_nb_img_bar, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, &dec->cpy_to_image_pl[idx]); + + vk->CmdDispatch(exec->buf, ctx->plane[0].width >> 3, + ctx->plane[0].height >> 3, 3); + + return 0; +} + +/* ----- LeGall Wavelet init and pipeline pass ----- */ + +static const char get_idx[] = {C( + 0, int getIdx(int plane, int x, int y) { ) + C(1, return plane_offs[plane] + plane_strides[plane] * y + x; ) + C(0, + })}; + +static const char legall_low_y[] = {C( + 0, int32_t legall_low_y(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int y_1 = ((y - 1) > 0) ? (y - 1) : 1; ) + C(1, const int32_t val_1 = inBuf[getIdx(plane, x, y_1)]; ) + C(1, const int y0 = y; ) + C(1, const int32_t val0 = inBuf[getIdx(plane, x, y0)]; ) + C(1, const int y1 = y + 1; ) + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y1)]; ) + C(1, return val0 - ((val1 + val_1 + 2) >> 2); ) + C(0, + })}; + +static const char legall_high[] = {C( + 0, int32_t legall_high(int32_t v1, int32_t v2, int32_t v3) { ) + C(1, return v1 + ((v2 + v3 + 1) >> 1); ) + C(0, + })}; + +static const char legall_vert[] = {C( + 0, void idwt_vert(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int32_t out0 = legall_low_y(plane, x, y); ) + C(1, const int32_t yy = ((y + 2) < h) ? (y + 2) : (h - 2); ) + C(1, const int32_t tmp1 = legall_low_y(plane, x, yy); ) + C(1, ) + C(1, const int y1 = y + 1; ) + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y1)]; ) + C(1, ) + C(1, const int32_t out1 = legall_high(val1, out0, tmp1); ) + C(1, ) + C(1, outBuf[getIdx(plane, x, y)] = out0; ) + C(1, outBuf[getIdx(plane, x, y + 1)] = out1; ) + C(0, + })}; + +static const char legall_low_x[] = {C( + 0, int32_t legall_low_x(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2; ) + C(1, ) + C(1, const int x_1 = (x > 0) ? x : 0; ) + C(1, const int32_t val_1 = inBuf[getIdx(plane, x_1, y)]; ) + C(1, ) + C(1, const int x1 = (x > 0) ? (x + dw) : dw; ) + C(1, const int32_t val1 = inBuf[getIdx(plane, x1, y)]; ) + C(1, ) + C(1, const int x0 = (x > 0) ? (x + dw - 1) : dw; ) + C(1, const int32_t val0 = inBuf[getIdx(plane, x0, y)]; ) + C(1, ) + C(1, return val_1 - ((val0 + val1 + 2) >> 2); ) + C(0, + })}; + +static const char legall_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2 - 1; ) + C(1, ) + C(1, const int32_t out0 = legall_low_x(plane, x, y); ) + C(1, const int32_t tmp1 = (x == dw) ? out0 : legall_low_x(plane, x + 1, y); ) + C(1, ) + C(1, const int x1 = x + dw + 1; ) + C(1, const int32_t val1 = inBuf[getIdx(plane, x1, y)]; ) + C(1, ) + C(1, const int32_t out1 = legall_high(val1, out0, tmp1); ) + C(1, ) + C(1, outBuf[getIdx(plane, 2 * x, y)] = (out0 + 1) >> 1; ) + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = (out1 + 1) >> 1; ) + C(0, + })}; + +static int init_wavelet_shd_legall_vert(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_LEGALL5_3; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "legall_vert", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->vert_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(legall_low_y); + GLSLD(legall_high); + GLSLD(legall_vert); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; 2 * y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; x < w; x += off_x) { ); + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static int init_wavelet_shd_legall_horiz(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_LEGALL5_3; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "legall_horiz", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->horiz_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(legall_low_x); + GLSLD(legall_high); + GLSLD(legall_horiz); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; 2 * x < w; x += off_x) { ); + GLSLC(3, idwt_horiz(pic_z, x, y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline wavelet_legall_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { + int err; + int barrier_num = *nb_buf_bar; + int wavelet_idx = DWT_DIRAC_LEGALL5_3; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); + + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; + + /* Vertical wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, + dec->pConst.real_plane_dims[1] >> 4, 1); + + /* Horizontal wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, + dec->pConst.real_plane_dims[1] >> 3, 1); + } + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + return 0; +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +/* ----- Fidelity init and pipeline pass ----- */ + +static const char fidelity_low[] = {C( + 0, int32_t fidelity_low(int32_t v0, int32_t v1, int32_t v2, int32_t v3, + int32_t v4, int32_t v5, int32_t v6, int32_t v7) {) + C(1, return (-2 * v0 + 10 * v1 - 25 * v2 + 81 * v3 + 81 * v4 - 25 * v5 + 10 * v6 - 2 * v7 + 128) >> 8;) + C(0, + })}; + +static const char fidelity_high[] = {C( + 0, int32_t fidelity_high(int32_t v0, int32_t v1, int32_t v2, int32_t v3, + int32_t v4, int32_t v5, int32_t v6, int32_t v7) {) + C(1, return (-8 * v0 + 21 * v1 - 46 * v2 + 161 * v3 + 161 * v4 - 46 * v5 + 21 * v6 - 8 * v7 + 128) >> 8;) + C(0, + })}; + +static const char fidelity_low_y[] = {C( + 0, int32_t fidelity_low_y(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int32_t v1 = inBuf[getIdx(plane, x, y + 1)]; ) + C(1, ) + C(1, const int y_6 = ((y - 6) > 0) ? (y - 6) : 0; ) + C(1, const int32_t v_6 = inBuf[getIdx(plane, x, y_6)]; ) + C(1, ) + C(1, const int y_4 = ((y - 4) > 0) ? (y - 4) : 0; ) + C(1, const int32_t v_4 = inBuf[getIdx(plane, x, y_4)]; ) + C(1, ) + C(1, const int y_2 = ((y - 2) > 0) ? (y - 2) : 0; ) + C(1, const int32_t v_2 = inBuf[getIdx(plane, x, y_2)]; ) + C(1, ) + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y)]; ) + C(1, ) + C(1, const int y2 = ((y + 2) < h) ? (y + 2) : (h - 2); ) + C(1, const int32_t v2 = inBuf[getIdx(plane, x, y2)]; ) + C(1, ) + C(1, const int y4 = ((y + 4) < h) ? (y + 4) : (h - 2); ) + C(1, const int32_t v4 = inBuf[getIdx(plane, x, y4)]; ) + C(1, ) + C(1, const int y6 = ((y + 6) < h) ? (y + 6) : (h - 2); ) + C(1, const int32_t v6 = inBuf[getIdx(plane, x, y6)]; ) + C(1, ) + C(1, const int y8 = ((y + 8) < h) ? (y + 8) : (h - 2); ) + C(1, const int32_t v8 = inBuf[getIdx(plane, x, y8)]; ) + C(1, ) + C(1, return v1 + fidelity_low(v_6, v_4, v_2, v0, v2, v4, v6, v8); ) + C(0, + })}; + +static const char fidelity_vert[] = {C( + 0, void idwt_vert(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y)]; ) + C(1, const int32_t v1 = fidelity_low_y(plane, x, y); ) + C(1, const int32_t v_7 = (y - 8 > 0) ? fidelity_low_y(plane, x, y - 8) : v1; ) + C(1, const int32_t v_5 = (y - 6 > 0) ? fidelity_low_y(plane, x, y - 6) : v1; ) + C(1, const int32_t v_3 = (y - 4 > 0) ? fidelity_low_y(plane, x, y - 4) : v1; ) + C(1, const int32_t v_1 = (y - 2 > 0) ? fidelity_low_y(plane, x, y - 2) : v1; ) + C(1, const int32_t v3 = (y + 2 < h) ? fidelity_low_y(plane, x, y + 2) : ) + C(1, fidelity_low_y(plane, x, h - 2); ) + C(1, const int32_t v5 = (y + 4 < h) ? fidelity_low_y(plane, x, y + 4) : ) + C(1, fidelity_low_y(plane, x, h - 2); ) + C(1, const int32_t v7 = (y + 6 < h) ? fidelity_low_y(plane, x, y + 6) : ) + C(1, fidelity_low_y(plane, x, h - 2); ) + C(1, outBuf[getIdx(plane, x, y)] = v0 - fidelity_high(v_7, v_5, v_3, v_1, v1, v3, v5, v7);) + C(1, outBuf[getIdx(plane, x, y + 1)] = v1; ) + C(0, + })}; + +static const char fidelity_low_x[] = {C( + 0, int32_t fidelity_low_x(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2 - 1; ) + C(1, ) + C(1, const int x_3 = clamp(x - 3, 0, dw); ) + C(1, const int32_t v_3 = inBuf[getIdx(plane, x_3, y)]; ) + C(1, ) + C(1, const int x_2 = clamp(x - 2, 0, dw); ) + C(1, const int32_t v_2 = inBuf[getIdx(plane, x_2, y)]; ) + C(1, ) + C(1, const int x_1 = clamp(x - 1, 0, dw); ) + C(1, const int32_t v_1 = inBuf[getIdx(plane, x_1, y)]; ) + C(1, ) + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y)]; ) + C(1, ) + C(1, const int x_w = x + dw + 1; ) + C(1, const int32_t v_w = inBuf[getIdx(plane, x_w, y)]; ) + C(1, ) + C(1, const int x1 = clamp(x + 1, 0, dw); ) + C(1, const int32_t v1 = inBuf[getIdx(plane, x1, y)]; ) + C(1, ) + C(1, const int x2 = clamp(x + 2, 0, dw); ) + C(1, const int32_t v2 = inBuf[getIdx(plane, x2, y)]; ) + C(1, ) + C(1, const int x3 = clamp(x + 3, 0, dw); ) + C(1, const int32_t v3 = inBuf[getIdx(plane, x3, y)]; ) + C(1, ) + C(1, const int x4 = clamp(x + 4, 0, dw); ) + C(1, const int32_t v4 = inBuf[getIdx(plane, x4, y)]; ) + C(1, ) + C(1, return v_w + fidelity_low(v_3, v_2, v_1, v0, v1, v2, v3, v4); ) + C(0, + })}; + +static const char fidelity_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2 - 1; ) + C(1, ) + C(1, const int32_t vo0 = inBuf[getIdx(plane, x, y)]; ) + C(1, ) + C(1, const int x_4 = clamp(x - 4, 0, dw); ) + C(1, const int32_t v_4 = fidelity_low_x(plane, x_4, y); ) + C(1, const int x_3 = clamp(x - 3, 0, dw); ) + C(1, const int32_t v_3 = fidelity_low_x(plane, x_3, y); ) + C(1, const int x_2 = clamp(x - 2, 0, dw); ) + C(1, const int32_t v_2 = fidelity_low_x(plane, x_2, y); ) + C(1, const int x_1 = clamp(x - 1, 0, dw); ) + C(1, const int32_t v_1 = fidelity_low_x(plane, x_1, y); ) + C(1, const int x0 = clamp(x, 0, dw); ) + C(1, const int32_t v0 = fidelity_low_x(plane, x0, y); ) + C(1, const int x1 = clamp(x + 1, 0, dw); ) + C(1, const int32_t v1 = fidelity_low_x(plane, x1, y); ) + C(1, const int x2 = clamp(x + 2, 0, dw); ) + C(1, const int32_t v2 = fidelity_low_x(plane, x2, y); ) + C(1, const int x3 = clamp(x + 3, 0, dw); ) + C(1, const int32_t v3 = fidelity_low_x(plane, x3, y); ) + C(1, ) + C(1, outBuf[getIdx(plane, 2 * x, y)] = vo0 - fidelity_high(v_4, v_3, v_2, v_1, v0, v1, v2, v3);) + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = v0; ) + C(0, + })}; + +static int init_wavelet_shd_fidelity_vert(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_FIDELITY; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "fidelity_vert", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->vert_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(fidelity_low); + GLSLD(fidelity_high); + GLSLD(fidelity_low_y); + GLSLD(fidelity_vert); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; 2 * y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; x < w; x += off_x) { ); + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static int init_wavelet_shd_fidelity_horiz(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_FIDELITY; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "fidelity_horiz", + VK_SHADER_STAGE_COMPUTE_BIT, 0)); + + shd = &s->horiz_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(fidelity_low); + GLSLD(fidelity_high); + GLSLD(fidelity_low_x); + GLSLD(fidelity_horiz); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; 2 * x < w; x += off_x) { ); + GLSLC(3, idwt_horiz(pic_z, x, y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline wavelet_fidelity_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { + int err; + int barrier_num = *nb_buf_bar; + int wavelet_idx = DWT_DIRAC_FIDELITY; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); + + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; + + /* Vertical wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, + dec->pConst.real_plane_dims[1] >> 4, 1); + + /* Horizontal wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, + dec->pConst.real_plane_dims[1] >> 3, 1); + } + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + return 0; +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +/* ----- Daubechies(9, 7) init and pipeline pass ----- */ + +static const char daub97_low1[] = {C( + 0, int32_t daub97_low1(int32_t v1, int32_t v2, int32_t v3) { ) + C(1, return v2 - ((1817 * (v1 + v2) + 2048) >> 12); ) + C(0, + })}; + +static const char daub97_high1[] = {C( + 0, int32_t daub97_high1(int32_t v1, int32_t v2, int32_t v3) { ) + C(1, return v2 - ((113 * (v1 + v2) + 64) >> 7); ) + C(0, + })}; + +static const char daub97_low0[] = {C( + 0, int32_t daub97_low0(int32_t v1, int32_t v2, int32_t v3) { ) + C(1, return v2 - ((217 * (v1 + v2) + 2048) >> 12); ) + C(0, + })}; + +static const char daub97_high0[] = {C( + 0, int32_t daub97_high0(int32_t v1, int32_t v2, int32_t v3) { ) + C(1, return v2 - ((6947 * (v1 + v2) + 2048) >> 12); ) + C(0, + })}; + +static const char daub97_low_x0[] = {C( + 0, int32_t daub97_low_x0(int plane, int x, int y) { ) + C(1, int w = plane_sizes[plane].x; ) + C(1, int dw = plane_sizes[plane].x / 2; ) + C(1, ) + C(1, int x0 = (x == 0) ? dw : x + dw; ) + C(1, int32_t v0 = inBuf[getIdx(plane, x0, y)]; ) + C(1, ) + C(1, int32_t v1 = inBuf[getIdx(plane, x, y)]; ) + C(1, ) + C(1, int x2 = x + dw; ) + C(1, int32_t v2 = inBuf[getIdx(plane, x0, y)]; ) + C(1, ) + C(1, return daub97_low1(v0, v1, v2); ) + C(0, + })}; + +static const char daub97_high_x0[] = {C( + 0, int32_t daub97_high_x0(int plane, int x, int y) { ) + C(1, int w = plane_sizes[plane].x; ) + C(1, int dw = plane_sizes[plane].x / 2; ) + C(1, ) + C(1, int x0 = (x == dw - 1) ? (dw - 1) : (x - 1); ) + C(1, int32_t v0 = daub97_low_x0(plane, x0, y); ) + C(1, ) + C(1, int32_t v1 = inBuf[getIdx(plane, x + dw - 1, y)]; ) + C(1, ) + C(1, int32_t v2 = daub97_low_x0(plane, x, y); ) + C(1, ) + C(1, return daub97_high1(v0, v1, v2); ) + C(0, + })}; + +static const char daub97_low_x1[] = {C( + 0, int32_t daub97_low_x1(int plane, int x, int y) { ) + C(1, int w = plane_sizes[plane].x; ) + C(1, int dw = plane_sizes[plane].x / 2; ) + C(1, ) + C(1, int32_t v0 = daub97_high_x0(plane, x, y); ) + C(1, ) + C(1, int32_t v1 = daub97_low_x0(plane, x, y); ) + C(1, ) + C(1, int32_t v2 = daub97_high_x0(plane, x + 1, y); ) + C(1, ) + C(1, return daub97_low0(v0, v1, v2); ) + C(0, + })}; + +static const char daub97_high_x1[] = {C( + 0, int32_t daub97_high_x1(int plane, int x, int y) { ) + C(1, int w = plane_sizes[plane].x; ) + C(1, int dw = plane_sizes[plane].x / 2; ) + C(1, ) + C(1, int x0 = clamp(x - 1, 0, dw); ) + C(1, int32_t v0 = daub97_low_x1(plane, x0, y); ) + C(1, ) + C(1, int32_t v1 = daub97_high_x0(plane, x + 1, y); ) + C(1, ) + C(1, int x2 = clamp(x, 0, dw); ) + C(1, int32_t v2 = daub97_low_x1(plane, x2, y); ) + C(1, ) + C(1, return daub97_high0(v0, v1, v2); ) + C(0, + })}; + +static const char daub97_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, int w = plane_sizes[plane].x; ) + C(1, int dw = plane_sizes[plane].x / 2; ) + C(1, ) + C(1, int32_t v0 = daub97_low_x1(plane, x, y); ) + C(1, int32_t v1 = daub97_high_x1(plane, x, y); ) + C(1, ) + C(1, outBuf[getIdx(plane, 2 * x, y)] = ~((~v0) >> 1); ) + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = ~((~v1) >> 1); ) + C(0, + })}; + +static int init_wavelet_shd_daub97_vert(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_DAUB9_7; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "daub97_vert", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->vert_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 1, 3); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + /*.mem_quali = "readonly",*/ + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(1, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 1, for (; x < w; x += off_x) { ); + GLSLC( + 2, for (int y = 0; y < h; y += 2) { ); + GLSLC(3, int32_t v0 = inBuf[getIdx( + pic_z, x, int(clamp(y - 1, 0, h)))];); + GLSLC(3, + int32_t v1 = inBuf[getIdx(pic_z, x, y + 1)];); + GLSLC(3, inBuf[getIdx(pic_z, x, y)] -= + (1817 * (v0 + v1 + 2048)) >> 12;); + GLSLC(2, + }); + GLSLC( + 2, for (int y = 0; y < h; y += 2) { ); + GLSLC(3, int32_t v0 = inBuf[getIdx(pic_z, x, y)];); + GLSLC(3, + int32_t v1 = inBuf[getIdx( + pic_z, x, int(clamp(y + 2, 0, h - 2)))];); + GLSLC(3, inBuf[getIdx(pic_z, x, y + 1)] -= + (3616 * (v0 + v1 + 2048)) >> 12;); + GLSLC(2, + }); + GLSLC( + 2, for (int y = 0; y < h; y += 2) { ); + GLSLC(3, int32_t v0 = inBuf[getIdx( + pic_z, x, int(clamp(y - 1, 0, h)))];); + GLSLC(3, + int32_t v1 = inBuf[getIdx(pic_z, x, y + 1)];); + GLSLC(3, int32_t v2 = inBuf[getIdx(pic_z, x, y)];); + GLSLC(3, outBuf[getIdx(pic_z, x, y)] = + v2 + (217 * (v0 + v1 + 2048)) >> 12;); + GLSLC(2, + }); + GLSLC( + 2, for (int y = 0; y < h; y += 2) { ); + GLSLC(3, int32_t v0 = inBuf[getIdx(pic_z, x, y)];); + GLSLC(3, + int32_t v1 = inBuf[getIdx( + pic_z, x, int(clamp(y + 2, 0, h - 2)))];); + GLSLC(3, + int32_t v2 = inBuf[getIdx(pic_z, x, y + 1)];); + GLSLC(3, outBuf[getIdx(pic_z, x, y + 1)] = + v2 + (6497 * (v0 + v1 + 2048)) >> 12;); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static int init_wavelet_shd_daub97_horiz(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_DAUB9_7; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "daub97_horiz", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->horiz_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(daub97_low1); + GLSLD(daub97_low0); + GLSLD(daub97_high1); + GLSLD(daub97_high0); + GLSLD(daub97_low_x0); + GLSLD(daub97_high_x0); + GLSLD(daub97_low_x1); + GLSLD(daub97_high_x1); + GLSLD(daub97_horiz); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; 2 * x < w; x += off_x) { ); + GLSLC(3, idwt_horiz(pic_z, x, y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline wavelet_daub97_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { + int err; + int barrier_num = *nb_buf_bar; + int wavelet_idx = DWT_DIRAC_DAUB9_7; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); + + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; + + /* Vertical wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0], 1, 1); + + /* Horizontal wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, + dec->pConst.real_plane_dims[1] >> 3, 1); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + } + + return 0; +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +/* ----- Deslauriers-Dubuc(9, 7) init and pipeline pass ----- */ + +static const char dd97_high[] = {C( + 0, int32_t dd97_high(int32_t v1, int32_t v2, int32_t v3, int32_t v4, + int32_t v5) { ) + C(1, return v3 + ((9 * v4 + 9 * v2 - v5 - v1 + 8) >> 4); ) + C(0, + })}; + +static const char dd97_vert[] = {C( + 0, void idwt_vert(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int32_t out0 = legall_low_y(plane, x, y); ) + C(1, const int32_t out_2 = (y - 2 > 0) ? legall_low_y(plane, x, y - 2) : ) + C(1, legall_low_y(plane, x, 0); ) + C(1, const int32_t out2 = (y + 2 < h) ? legall_low_y(plane, x, y + 2) : ) + C(1, legall_low_y(plane, x, h - 2); ) + C(1, const int32_t out4 = (y + 4 < h) ? legall_low_y(plane, x, y + 4) : ) + C(1, legall_low_y(plane, x, h - 2); ) + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y + 1)]; ) + C(1, ) + C(1, outBuf[getIdx(plane, x, y)] = out0; ) + C(1, outBuf[getIdx(plane, x, y + 1)] = dd97_high(out_2, out0, val1, out2, out4); ) + C(1, + })}; + +static const char dd97_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2 - 1; ) + C(1, ) + C(1, const int32_t out0 = legall_low_x(plane, x, y); ) + C(1, ) + C(1, const int32_t out_1 = ((x - 1) > 0) ? legall_low_x(plane, x - 1, y) : out0; ) + C(1, const int32_t val3 = inBuf[getIdx(plane, x + dw + 1, y)]; ) + C(1, const int32_t out1 = ((x + 1) <= dw) ? legall_low_x(plane, x + 1, y) : ) + C(1, legall_low_x(plane, dw, y); ) + C(1, const int32_t out2 = ((x + 2) <= dw) ? legall_low_x(plane, x + 2, y) : ) + C(1, legall_low_x(plane, dw, y); ) + C(1, const int32_t res = dd97_high(out_1, out0, val3, out1, out2); ) + C(1, ) + C(1, outBuf[getIdx(plane, 2 * x, y)] = (out0 + 1) >> 1; ) + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = (res + 1) >> 1; ) + C(0, + })}; + +static int init_wavelet_shd_dd97_vert(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_DD9_7; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "dd97_vert", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->vert_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(legall_low_y); + GLSLD(dd97_high); + GLSLD(dd97_vert); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; 2 * y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; x < w; x += off_x) { ); + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static int init_wavelet_shd_dd97_horiz(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_DD9_7; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "dd97_horiz", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->horiz_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(legall_low_x); + GLSLD(dd97_high); + GLSLD(dd97_horiz); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; 2 * x < w; x += off_x) { ); + GLSLC(3, idwt_horiz(pic_z, x, y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline wavelet_dd97_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { + int err; + int barrier_num = *nb_buf_bar; + int wavelet_idx = DWT_DIRAC_DD9_7; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); + + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; + + /* Vertical wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, + dec->pConst.real_plane_dims[1] >> 4, 1); + + /* Horizontal wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, + dec->pConst.real_plane_dims[1] >> 3, 1); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + } + + return 0; +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +/* ----- Deslauriers-Dubuc(13, 7) init and pipeline pass ----- */ +static const char dd137_low[] = {C( + 0, int32_t dd137_low(int32_t v0, int32_t v1, int32_t v2, int32_t v3, + int32_t v4) { ) + C(0, return v2 - ((9 * v1 + 9 * v3 - v4 - v0 + 16) >> 5); ) + C(0, + })}; + +static const char dd137_low_y[] = {C( + 0, int32_t dd137_low_y(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int y0 = (x > 3) ? (y - 3) : 1; ) + C(1, const int32_t v0 = inBuf[getIdx(plane, x, y0)]; ) + C(1, ) + C(1, const int y1 = (y > 1) ? (y - 1) : 1; ) + C(1, const int32_t v1 = inBuf[getIdx(plane, x, y1)]; ) + C(1, ) + C(1, const int y2 = y; ) + C(1, const int32_t v2 = inBuf[getIdx(plane, x, y2)]; ) + C(1, ) + C(1, const int y3 = y + 1; ) + C(1, const int32_t v3 = inBuf[getIdx(plane, x, y3)]; ) + C(1, ) + C(1, const int y4 = (y + 3 < h) ? (y + 3) : (h - 1); ) + C(1, const int32_t v4 = inBuf[getIdx(plane, x, y4)]; ) + C(1, ) + C(1, return dd137_low(v0, v1, v2, v3, v4); ) + C(0, + })}; + +static const char dd137_vert[] = {C( + 0, void idwt_vert(int plane, int x, int y) { ) + C(1, const int h = plane_sizes[plane].y; ) + C(1, ) + C(1, const int32_t out0 = dd137_low_y(plane, x, y); ) + C(1, const int32_t out_2 = (y - 2 > 0) ? dd137_low_y(plane, x, y - 2) : ) + C(1, dd137_low_y(plane, x, 0); ) + C(1, const int32_t out2 = (y + 2 < h) ? dd137_low_y(plane, x, y + 2) : ) + C(1, dd137_low_y(plane, x, h - 2); ) + C(1, const int32_t out4 = (y + 4 < h) ? dd137_low_y(plane, x, y + 4) : ) + C(1, dd137_low_y(plane, x, h - 2); ) + C(1, const int32_t val1 = inBuf[getIdx(plane, x, y + 1)]; ) + C(1, ) + C(1, outBuf[getIdx(plane, x, y)] = out0; ) + C(1, outBuf[getIdx(plane, x, y + 1)] = dd97_high(out_2, out0, val1, out2, out4); ) + C(1, + })}; + +static const char dd137_low_x[] = {C( + 0, int32_t dd137_low_x(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2; ) + C(1, ) + C(1, const int x0 = (x > 1) ? x : dw; ) + C(1, const int32_t v0 = inBuf[getIdx(plane, x0, y)]; ) + C(1, ) + C(1, const int x1 = (x > 1) ? (x + dw - 2) : dw; ) + C(1, const int32_t v1 = inBuf[getIdx(plane, x1, y)]; ) + C(1, ) + C(1, const int x2 = x; ) + C(1, const int32_t v2 = inBuf[getIdx(plane, x2, y)]; ) + C(1, ) + C(1, const int x3 = x + dw; ) + C(1, const int32_t v3 = inBuf[getIdx(plane, x3, y)]; ) + C(1, ) + C(1, const int x4 = (x != dw - 1) ? (x + dw + 1) : (dw - 1); ) + C(1, const int32_t v4 = inBuf[getIdx(plane, x4, y)]; ) + C(1, ) + C(1, return dd137_low(v0, v1, v2, v3, v4); ) + C(0, + })}; + +static const char dd137_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, const int w = plane_sizes[plane].x; ) + C(1, const int dw = w / 2 - 1; ) + C(1, ) + C(1, const int32_t out0 = dd137_low_x(plane, x, y); ) + C(1, ) + C(1, const int32_t out_1 = ((x - 1) > 0) ? dd137_low_x(plane, x - 1, y) : out0; ) + C(1, const int32_t val3 = inBuf[getIdx(plane, x + dw + 1, y)]; ) + C(1, const int32_t out1 = ((x + 1) <= dw) ? dd137_low_x(plane, x + 1, y) : ) + C(1, dd137_low_x(plane, dw, y); ) + C(1, const int32_t out2 = ((x + 2) <= dw) ? dd137_low_x(plane, x + 2, y) : ) + C(1, dd137_low_x(plane, dw, y); ) + C(1, const int32_t res = dd97_high(out_1, out0, val3, out1, out2); ) + C(1, ) + C(1, outBuf[getIdx(plane, 2 * x, y)] = (out0 + 1) >> 1; ) + C(1, outBuf[getIdx(plane, 2 * x + 1, y)] = (res + 1) >> 1; ) + C(0, + })}; + +static int init_wavelet_shd_dd137_vert(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_DD13_7; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "dd137_vert", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->vert_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(dd97_high); + GLSLD(dd137_low); + GLSLD(dd137_low_y); + GLSLD(dd137_vert); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; 2 * y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; x < w; x += off_x) { ); + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static int init_wavelet_shd_dd137_horiz(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_DD13_7; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "dd137_horiz", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->horiz_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(get_idx); + GLSLD(dd97_high); + GLSLD(dd137_low); + GLSLD(dd137_low_x); + GLSLD(dd137_horiz); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; 2 * x < w; x += off_x) { ); + GLSLC(3, idwt_horiz(pic_z, x, y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline wavelet_dd137_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { + int err; + int barrier_num = *nb_buf_bar; + int wavelet_idx = DWT_DIRAC_DD13_7; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); + + dec->pConst.real_plane_dims[0] = (ctx->plane[0].idwt.width) >> i; + dec->pConst.real_plane_dims[1] = (ctx->plane[0].idwt.height) >> i; + dec->pConst.real_plane_dims[2] = (ctx->plane[1].idwt.width) >> i; + dec->pConst.real_plane_dims[3] = (ctx->plane[1].idwt.height) >> i; + dec->pConst.real_plane_dims[4] = (ctx->plane[2].idwt.width) >> i; + dec->pConst.real_plane_dims[5] = (ctx->plane[2].idwt.height) >> i; + + /* Vertical wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, + dec->pConst.real_plane_dims[1] >> 4, 1); + + /* Horizontal wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, + dec->pConst.real_plane_dims[1] >> 3, 1); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + } + + return 0; +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +/* ----- Haar Wavelet init and pipeline pass ----- */ + +static const char haari_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, int offs0 = plane_offs[plane] + plane_strides[plane] * y + x; ) + C(1, int offs1 = offs0 + plane_sizes[plane].x / 2; ) + C(1, int outIdx = plane_offs[plane] + plane_strides[plane] * y + x * 2; ) + C(1, int32_t val_orig0 = inBuf[offs0]; ) + C(1, int32_t val_orig1 = inBuf[offs1]; ) + C(1, int32_t val_new0 = val_orig0 - ((val_orig1 + 1) >> 1); ) + C(1, int32_t val_new1 = val_orig1 + val_new0; ) + C(1, outBuf[outIdx] = val_new0; ) + C(1, outBuf[outIdx + 1] = val_new1; ) + C(0, + })}; + +static const char haari_shift_horiz[] = {C( + 0, void idwt_horiz(int plane, int x, int y) { ) + C(1, int offs0 = plane_offs[plane] + plane_strides[plane] * y + x; ) + C(1, int offs1 = offs0 + plane_sizes[plane].x / 2; ) + C(1, int outIdx = plane_offs[plane] + plane_strides[plane] * y + x * 2; ) + C(1, int32_t val_orig0 = inBuf[offs0]; ) + C(1, int32_t val_orig1 = inBuf[offs1]; ) + C(1, int32_t val_new0 = val_orig0 - ((val_orig1 + 1) >> 1); ) + C(1, int32_t val_new1 = val_orig1 + val_new0; ) + C(1, outBuf[outIdx] = (val_new0 + 1) >> 1; ) + C(1, outBuf[outIdx + 1] = (val_new1 + 1) >> 1; ) + C(0, + })}; + +static const char haari_vert[] = {C( + 0, void idwt_vert(int plane, int x, int y) { ) + C(1, int offs0 = plane_offs[plane] + plane_strides[plane] * y + x; ) + C(1, int offs1 = plane_offs[plane] + plane_strides[plane] * (y + 1) + x; ) + C(2, int32_t val_orig0 = inBuf[offs0]; ) + C(1, int32_t val_orig1 = inBuf[offs1]; ) + C(1, int32_t val_new0 = val_orig0 - ((val_orig1 + 1) >> 1); ) + C(1, int32_t val_new1 = val_orig1 + val_new0; ) + C(1, outBuf[offs0] = val_new0; ) + C(1, outBuf[offs1] = val_new1; ) + C(0, + })}; + +static int init_wavelet_shd_haari_vert(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv, int shift) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_HAAR0 + shift; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->vert_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->vert_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "haari_vert", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->vert_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(haari_vert); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(2, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; 2 * y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; x < w; x += off_x) { ); + GLSLC(3, idwt_vert(pic_z, x, 2 * y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static int init_wavelet_shd_haari_horiz(DiracVulkanDecodeContext *s, + FFVkSPIRVCompiler *spv, int shift) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + int wavelet_idx = DWT_DIRAC_HAAR0 + shift; + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->horiz_wavelet_shd[wavelet_idx]; + FFVulkanPipeline *pl = &s->horiz_wavelet_pl[wavelet_idx]; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "haari_horiz", VK_SHADER_STAGE_COMPUTE_BIT, + 0)); + + shd = &s->horiz_wavelet_shd[wavelet_idx]; + ff_vk_shader_set_compute_sizes(shd, 8, 8, 3); + + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "in_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t inBuf[];", + .mem_quali = "readonly", + .dimensions = 1, + }, + { + .name = "out_buf", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 2, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLD(shift ? haari_shift_horiz : haari_horiz); + + GLSLC( + 0, void main() { ); + GLSLC(1, int off_y = int(gl_WorkGroupSize.y * gl_NumWorkGroups.y);); + GLSLC(1, int off_x = int(gl_WorkGroupSize.x * gl_NumWorkGroups.x);); + GLSLC(1, int pic_z = int(gl_GlobalInvocationID.z);); + GLSLC(1, ); + GLSLC(1, uint w = int(plane_sizes[pic_z].x);); + GLSLC(1, uint h = int(plane_sizes[pic_z].y);); + GLSLC(1, ); + GLSLC(1, int y = int(gl_GlobalInvocationID.y);); + GLSLC( + 1, for (; y < h; y += off_y) { ); + GLSLC(2, int x = int(gl_GlobalInvocationID.x);); + GLSLC( + 2, for (; 2 * x < w; x += off_x) { ); + GLSLC(3, idwt_horiz(pic_z, x, y);); + GLSLC(2, + }); + GLSLC(1, + }); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline wavelet_haari_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar, int shift) { + int err; + int barrier_num = *nb_buf_bar; + + const int wavelet_idx = DWT_DIRAC_HAAR0 + shift; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + FFVulkanPipeline *pl_hor = &dec->horiz_wavelet_pl[wavelet_idx]; + FFVulkanPipeline *pl_vert = &dec->vert_wavelet_pl[wavelet_idx]; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_vert, exec, 0, 0, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_vert, exec, 0, 1, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, pl_hor, exec, 0, 0, 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + err = ff_vk_set_descriptor_buffer(&dec->vkctx, pl_hor, exec, 0, 1, 0, + dec->tmp_buf.address, dec->tmp_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + goto fail; + + for (int i = ctx->wavelet_depth - 1; i >= 0; i--) { + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width << i; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width << i; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width << i; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.dw[0] = ctx->plane[0].idwt.width >> (i + 1); + dec->pConst.dw[1] = ctx->plane[1].idwt.width >> (i + 1); + dec->pConst.dw[2] = ctx->plane[2].idwt.width >> (i + 1); + + dec->pConst.real_plane_dims[0] = ctx->plane[0].idwt.width >> i; + dec->pConst.real_plane_dims[1] = ctx->plane[0].idwt.height >> i; + dec->pConst.real_plane_dims[2] = ctx->plane[1].idwt.width >> i; + dec->pConst.real_plane_dims[3] = ctx->plane[1].idwt.height >> i; + dec->pConst.real_plane_dims[4] = ctx->plane[2].idwt.width >> i; + dec->pConst.real_plane_dims[5] = ctx->plane[2].idwt.height >> i; + + dec->pConst.wavelet_depth = ctx->wavelet_depth; + + /* Vertical wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_vert, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_vert); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 3, + dec->pConst.real_plane_dims[1] >> 4, 1); + + /* Horizontal wavelet pass */ + ff_vk_update_push_exec(&dec->vkctx, exec, pl_hor, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, pl_hor); + vk->CmdDispatch(exec->buf, dec->pConst.real_plane_dims[0] >> 4, + dec->pConst.real_plane_dims[1] >> 3, 1); + } + + barrier_num = *nb_buf_bar; + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + vk->CmdPipelineBarrier2( + exec->buf, &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + barrier_num, + .bufferMemoryBarrierCount = *nb_buf_bar - barrier_num, + }); + + return 0; +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +/* ----- Dequant Shader init and pipeline pass ----- */ + +static const char dequant[] = {C( + 0, void dequant(int outIdx, int idx, int qf, int qs) { ) + C(1, int32_t val = inBuffer[idx]; ) + C(1, val = sign(val) * ((abs(val) * qf + qs) >> 2); ) + C(1, outBuf0[outIdx] = outBuf1[outIdx] = val; ) + C(0, + })}; + +static const char proc_slice[] = {C( + 0, void proc_slice(int slice_idx) { ) + C(1, const int plane = int(gl_GlobalInvocationID.y); ) + C(1, const int level = int(gl_GlobalInvocationID.z); ) + C(1, if (level >= wavelet_depth) return; ) + C(1, const int base_idx = slice_idx * DWT_LEVELS * 8; ) + C(1, const int base_slice_idx = slice_idx * DWT_LEVELS * 3 + plane * DWT_LEVELS; ) + C(1, ) + C(1, const Slice s = slices[base_slice_idx + level]; ) + C(1, int offs = s.offs; ) + C(1, ) + C(1, for(int orient = int(bool(level)); orient < 4; orient++) { ) + C(2, int32_t qf = quantMatrix[base_idx + level * 8 + orient]; ) + C(2, int32_t qs = quantMatrix[base_idx + level * 8 + orient + 4]; ) + C(2, ) + C(2, const int subband_idx = plane * DWT_LEVELS * 4 ) + C(2, + 4 * level + orient; ) + C(2, ) + C(2, const SubbandOffset sub_off = subband_offs[subband_idx]; ) + C(2, int img_idx = plane_offs[plane] + sub_off.base_off ) + C(2, + s.top * sub_off.stride + s.left; ) + C(2, ) + C(2, for(int y = 0; y < s.tot_v; y++) { ) + C(3, int img_x = img_idx; ) + C(3, for(int x = 0; x < s.tot_h; x++) { ) + C(4, dequant(img_x, offs, qf, qs); ) + C(4, img_x++; ) + C(4, offs++; ) + C(3, } ) + C(3, img_idx += sub_off.stride; ) + C(2, } ) + C(1, } ) + C(0, + })}; + +static int init_quant_shd(DiracVulkanDecodeContext *s, FFVkSPIRVCompiler *spv) { + int err = 0; + uint8_t *spv_data; + size_t spv_len; + void *spv_opaque = NULL; + // const int planes = av_pix_fmt_count_planes(s->vkctx.output_format); + FFVulkanContext *vkctx = &s->vkctx; + FFVulkanDescriptorSetBinding *desc; + FFVkSPIRVShader *shd = &s->quant_shd; + FFVulkanPipeline *pl = &s->quant_pl; + FFVkExecPool *exec = &s->exec_pool; + + RET(ff_vk_shader_init(pl, shd, "dequant", VK_SHADER_STAGE_COMPUTE_BIT, 0)); + + shd = &s->quant_shd; + ff_vk_shader_set_compute_sizes(shd, 1, 1, 1); + + GLSLC(0, #extension GL_EXT_debug_printf : enable); + GLSLC(0, #extension GL_EXT_scalar_block_layout : require); + GLSLC(0, #extension GL_EXT_shader_explicit_arithmetic_types : require); + + GLSLC( + 0, struct Slice { ); + GLSLC(1, int left;); + GLSLC(1, int top;); + GLSLC(1, int tot_h;); + GLSLC(1, int tot_v;); + GLSLC(1, int tot;); + GLSLC(1, int offs;); + GLSLC(1, int pad0;); + GLSLC(1, int pad1;); + GLSLC(0, + };); + + GLSLC( + 0, struct SubbandOffset { ); + GLSLC(1, int base_off;); + GLSLC(1, int stride;); + GLSLC(1, int pad0;); + GLSLC(1, int pad1;); + GLSLC(0, + };); + + desc = (FFVulkanDescriptorSetBinding[]){ + { + .name = "out_buf_0", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf0[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + { + .name = "out_buf_1", + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .buf_content = "int32_t outBuf1[];", + .mem_quali = "writeonly", + .dimensions = 1, + }, + { + .name = "quant_in_buf", + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .buf_content = "int32_t inBuffer[];", + .mem_quali = "readonly", + }, + { + .name = "quant_vals_buf", + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .buf_content = "int32_t quantMatrix[];", + .mem_quali = "readonly", + }, + { + .name = "slices_buf", + .type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER, + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .buf_content = "Slice slices[];", + .mem_quali = "readonly", + .mem_layout = "std430", + }, + { + .name = "subband_buf", + .type = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER, + .stages = VK_SHADER_STAGE_COMPUTE_BIT, + .buf_content = "SubbandOffset subband_offs[60];", + .mem_quali = "readonly", + .mem_layout = "std430", + }, + }; + RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc, 6, 0, 0)); + + ff_vk_add_push_constant(pl, 0, sizeof(WaveletPushConst), + VK_SHADER_STAGE_COMPUTE_BIT); + + GLSLC( + 0, layout(push_constant, std430) uniform pushConstants { ); + GLSLC(1, ivec2 plane_sizes[3];); + GLSLC(1, int plane_offs[3];); + GLSLC(1, int plane_strides[3];); + GLSLC(1, int dw[3];); + GLSLC(1, int wavelet_depth;); + GLSLC(0, + };); + GLSLC(0, ); + + GLSLF(0, #define DWT_LEVELS % i, MAX_DWT_LEVELS); + + GLSLD(dequant); + GLSLD(proc_slice); + GLSLC(0, void main()); + GLSLC(0, { ); + GLSLC(1, int idx = int(gl_GlobalInvocationID.x);); + GLSLC(1, proc_slice(idx);); + GLSLC(0, + }); + + RET(spv->compile_shader(spv, vkctx, shd, &spv_data, &spv_len, "main", + &spv_opaque)); + RET(ff_vk_shader_create(vkctx, shd, spv_data, spv_len, "main")); + RET(ff_vk_init_compute_pipeline(vkctx, pl, shd)); + RET(ff_vk_exec_pipeline_register(vkctx, exec, pl)); + +fail: + if (spv_opaque) + spv->free_shader(spv, &spv_opaque); + + return err; +} + +static av_always_inline int inline quant_pl_pass( + DiracVulkanDecodeContext *dec, DiracContext *ctx, FFVkExecContext *exec, + VkBufferMemoryBarrier2 *buf_bar, int *nb_buf_bar) { + int err, nb_bar; + FFVulkanFunctions *vk = &dec->vkctx.vkfn; + + ff_vk_exec_bind_pipeline(&dec->vkctx, exec, &dec->quant_pl); + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, &dec->quant_pl, exec, 0, 0, + 0, dec->tmp_buf.address, + dec->tmp_buf.size, VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + err = ff_vk_set_descriptor_buffer(&dec->vkctx, &dec->quant_pl, exec, 0, 1, + 0, dec->tmp_interleave_buf.address, + dec->tmp_interleave_buf.size, + VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, &dec->quant_pl, exec, 0, 2, 0, dec->quant_val_buf->address, + dec->quant_val_buf->size, VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, &dec->quant_pl, exec, 0, 3, 0, dec->quant_buf->address, + dec->quant_buf->size, VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, &dec->quant_pl, exec, 0, 4, 0, dec->slice_buf->address, + dec->slice_buf->size, VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + err = ff_vk_set_descriptor_buffer( + &dec->vkctx, &dec->quant_pl, exec, 0, 5, 0, dec->subband_info.address, + dec->subband_info.size, VK_FORMAT_UNDEFINED); + if (err < 0) + return err; + + dec->pConst.real_plane_dims[0] = ctx->plane[0].idwt.width; + dec->pConst.real_plane_dims[1] = ctx->plane[0].idwt.height; + dec->pConst.real_plane_dims[2] = ctx->plane[1].idwt.width; + dec->pConst.real_plane_dims[3] = ctx->plane[1].idwt.height; + dec->pConst.real_plane_dims[4] = ctx->plane[2].idwt.width; + dec->pConst.real_plane_dims[5] = ctx->plane[2].idwt.height; + + dec->pConst.plane_strides[0] = ctx->plane[0].idwt.width; + dec->pConst.plane_strides[1] = ctx->plane[1].idwt.width; + dec->pConst.plane_strides[2] = ctx->plane[2].idwt.width; + + dec->pConst.plane_offs[0] = 0; + dec->pConst.plane_offs[1] = + ctx->plane[0].idwt.width * ctx->plane[0].idwt.height; + dec->pConst.plane_offs[2] = + dec->pConst.plane_offs[1] + + ctx->plane[1].idwt.width * ctx->plane[1].idwt.height; + + dec->pConst.wavelet_depth = ctx->wavelet_depth; + + ff_vk_update_push_exec(&dec->vkctx, exec, &dec->quant_pl, + VK_SHADER_STAGE_COMPUTE_BIT, 0, + sizeof(WaveletPushConst), &dec->pConst); + + bar_read(buf_bar, nb_buf_bar, dec->quant_val_buf); + bar_read(buf_bar, nb_buf_bar, dec->slice_buf); + bar_read(buf_bar, nb_buf_bar, dec->quant_buf); + bar_read(buf_bar, nb_buf_bar, &dec->subband_info); + + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2(exec->buf, + &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar, + .bufferMemoryBarrierCount = *nb_buf_bar, + }); + + vk->CmdDispatch(exec->buf, ctx->num_x * ctx->num_y, 3, ctx->wavelet_depth); + + nb_bar = *nb_buf_bar; + bar_write(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_write(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_buf); + bar_read(buf_bar, nb_buf_bar, &dec->tmp_interleave_buf); + + vk->CmdPipelineBarrier2( + exec->buf, &(VkDependencyInfo){ + .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO, + .pBufferMemoryBarriers = buf_bar + nb_bar, + .bufferMemoryBarrierCount = *nb_buf_bar - nb_bar, + }); + + return 0; +} + +static int vulkan_dirac_uninit(AVCodecContext *avctx) { + // DiracContext *d = avctx->priv_data; + // if (d->hwaccel_picture_private) { + // av_freep(d->hwaccel_picture_private); + // } + + free_common(avctx); + + return 0; +} + +static inline int wavelet_init(DiracVulkanDecodeContext *dec, + FFVkSPIRVCompiler *spv) { + int err; + + err = init_wavelet_shd_daub97_horiz(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_daub97_vert(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_haari_vert(dec, spv, 0); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_haari_horiz(dec, spv, 0); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_haari_vert(dec, spv, 1); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_haari_horiz(dec, spv, 1); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_legall_vert(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_legall_horiz(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_dd97_vert(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_dd97_horiz(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_fidelity_vert(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_fidelity_horiz(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_dd137_vert(dec, spv); + if (err < 0) { + return err; + } + + err = init_wavelet_shd_dd137_horiz(dec, spv); + if (err < 0) { + return err; + } + + return 0; +} + +static int vulkan_dirac_init(AVCodecContext *avctx) { + int err = 0; + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; + FFVulkanContext *s; + FFVkSPIRVCompiler *spv; + + spv = ff_vk_spirv_init(); + if (!spv) { + av_log(avctx, AV_LOG_ERROR, "Unable to initialize SPIR-V compiler!\n"); + return AVERROR_EXTERNAL; + } + + err = ff_decode_get_hw_frames_ctx(avctx, AV_HWDEVICE_TYPE_VULKAN); + if (err < 0) + goto fail; + + /* Initialize contexts */ + s = &dec->vkctx; + + err = ff_vk_init(s, avctx, NULL, avctx->hw_frames_ctx); + if (err < 0) + return err; + + /* Create queue context */ + ff_vk_qf_init(s, &dec->qf, VK_QUEUE_COMPUTE_BIT); + + err = ff_vk_exec_pool_init(s, &dec->qf, &dec->exec_pool, 4, 0, 0, 0, NULL); + + err = ff_vk_init_sampler(&dec->vkctx, &dec->sampler, 1, VK_FILTER_NEAREST); + if (err < 0) { + goto fail; + } + + av_log(avctx, AV_LOG_VERBOSE, "Vulkan decoder initialization sucessful\n"); + + err = init_quant_shd(dec, spv); + if (err < 0) { + goto fail; + } + + err = init_cpy_shd(dec, spv, 0); + if (err < 0) { + goto fail; + } + + err = init_cpy_shd(dec, spv, 1); + if (err < 0) { + goto fail; + } + + err = init_cpy_shd(dec, spv, 2); + if (err < 0) { + goto fail; + } + + err = wavelet_init(dec, spv); + if (err < 0) { + goto fail; + } + + dec->quant_val_buf_vk_ptr = NULL; + dec->slice_buf_vk_ptr = NULL; + dec->quant_buf_vk_ptr = NULL; + + dec->av_quant_val_buf = NULL; + dec->av_quant_buf = NULL; + dec->av_slice_buf = NULL; + + dec->thread_buf_size = 0; + dec->n_slice_bufs = 0; + + err = ff_vk_create_buf(&dec->vkctx, &dec->subband_info, + sizeof(SubbandOffset) * MAX_DWT_LEVELS * 12, NULL, + NULL, + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); + if (err < 0) + return err; + + err = ff_vk_map_buffer(&dec->vkctx, &dec->subband_info, + (uint8_t **)&dec->subband_info_ptr, 0); + if (err < 0) + return err; + + return 0; + +fail: + if (spv) { + spv->uninit(&spv); + } + vulkan_dirac_uninit(avctx); + + return err; +} + +static int vulkan_decode_bootstrap(AVCodecContext *avctx, + AVBufferRef *frames_ref) { + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; + AVHWFramesContext *frames = (AVHWFramesContext *)frames_ref->data; + AVHWDeviceContext *device = (AVHWDeviceContext *)frames->device_ref->data; + AVVulkanDeviceContext *hwctx = device->hwctx; + + dec->vkctx.extensions = ff_vk_extensions_to_mask( + hwctx->enabled_dev_extensions, hwctx->nb_enabled_dev_extensions); + + return 0; +} + +static int vulkan_dirac_frame_params(AVCodecContext *avctx, + AVBufferRef *hw_frames_ctx) { + int err; + AVHWFramesContext *frames_ctx = (AVHWFramesContext *)hw_frames_ctx->data; + AVVulkanFramesContext *hwfc = frames_ctx->hwctx; + DiracContext *s = avctx->priv_data; + + frames_ctx->sw_format = s->sof_pix_fmt; + + err = vulkan_decode_bootstrap(avctx, hw_frames_ctx); + if (err < 0) + return err; + + frames_ctx->width = avctx->coded_width; + frames_ctx->height = avctx->coded_height; + frames_ctx->format = AV_PIX_FMT_VULKAN; + + for (int i = 0; i < AV_NUM_DATA_POINTERS; i++) { + hwfc->format[i] = av_vkfmt_from_pixfmt(frames_ctx->sw_format)[i]; + } + + hwfc->tiling = VK_IMAGE_TILING_OPTIMAL; + hwfc->usage = VK_IMAGE_USAGE_TRANSFER_SRC_BIT | VK_IMAGE_USAGE_STORAGE_BIT; + + return err; +} + +static void vulkan_dirac_free_frame_priv(FFRefStructOpaque _hwctx, void *data) { + // AVHWDeviceContext *hwctx = _hwctx.nc; + DiracVulkanDecodePicture *dp = data; + + /* Free frame resources */ + av_free(dp); +} + +static void setup_subbands(DiracContext *ctx, DiracVulkanDecodeContext *dec) { + SubbandOffset *offs = dec->subband_info_ptr; + memset(offs, 0, dec->subband_info.size); + + for (int plane = 0; plane < 3; plane++) { + Plane *p = &ctx->plane[plane]; + int w = p->idwt.width; + int s = FFALIGN(p->idwt.width, 8); + + for (int level = ctx->wavelet_depth - 1; level >= 0; level--) { + w >>= 1; + s <<= 1; + for (int orient = 0; orient < 4; orient++) { + const int idx = plane * MAX_DWT_LEVELS * 4 + level * 4 + orient; + SubbandOffset *off = &offs[idx]; + off->stride = s; + off->base_off = 0; + + if (orient & 1) + off->base_off += w; + if (orient > 1) + off->base_off += (s >> 1); + + /*SubBand *b = &p->band[level][orient];*/ + /*int w = (b->ibuf - p->idwt.buf) >> (1 + b->pshift);*/ + /*off->stride = b->stride >> (1 + b->pshift);*/ + /*off->base_off = w;*/ + } + } + } +} + +static int vulkan_dirac_start_frame(AVCodecContext *avctx, + av_unused const uint8_t *buffer, + av_unused uint32_t size) { + int err; + DiracVulkanDecodeContext *s = avctx->internal->hwaccel_priv_data; + DiracContext *c = avctx->priv_data; + DiracVulkanDecodePicture *pic = c->hwaccel_picture_private; + WaveletPushConst *pConst = &s->pConst; + + pic->frame = c->current_picture; + setup_subbands(c, s); + + pConst->real_plane_dims[0] = c->plane[0].idwt.width; + pConst->real_plane_dims[1] = c->plane[0].idwt.height; + pConst->real_plane_dims[2] = c->plane[1].idwt.width; + pConst->real_plane_dims[3] = c->plane[1].idwt.height; + pConst->real_plane_dims[4] = c->plane[2].idwt.width; + pConst->real_plane_dims[5] = c->plane[2].idwt.height; + + pConst->plane_strides[0] = c->plane[0].idwt.width; + pConst->plane_strides[1] = c->plane[1].idwt.width; + pConst->plane_strides[0] = c->plane[0].idwt.width; + + pConst->plane_offs[0] = 0; + pConst->plane_offs[1] = c->plane[0].idwt.width * c->plane[0].idwt.height; + pConst->plane_offs[2] = pConst->plane_offs[1] + + c->plane[1].idwt.width * c->plane[1].idwt.height; + + pConst->wavelet_depth = c->wavelet_depth; + + if (s->quant_buf_vk_ptr == NULL || s->slice_buf_vk_ptr == NULL || + s->quant_val_buf_vk_ptr == NULL || + c->num_x * c->num_y != s->n_slice_bufs) { + err = alloc_quant_buf(c, s); + if (err < 0) + return err; + err = alloc_dequant_buf(c, s); + if (err < 0) + return err; + err = alloc_slices_buf(c, s); + if (err < 0) + return err; + err = alloc_tmp_bufs(c, s); + if (err < 0) + return err; + } + + return 0; +} + +static int vulkan_dirac_end_frame(AVCodecContext *avctx) { + int err, nb_img_bar = 0, nb_buf_bar = 0; + DiracVulkanDecodeContext *dec = avctx->internal->hwaccel_priv_data; + DiracContext *ctx = avctx->priv_data; + VkImageView views[AV_NUM_DATA_POINTERS]; + VkBufferMemoryBarrier2 buf_bar[80]; + VkImageMemoryBarrier2 img_bar[80]; + DiracVulkanDecodePicture *pic = ctx->hwaccel_picture_private; + FFVkExecContext *exec = ff_vk_exec_get(&dec->exec_pool); + enum dwt_type wavelet_idx = ctx->wavelet_idx + 2; + + ff_vk_exec_start(&dec->vkctx, exec); + + err = + ff_vk_exec_add_dep_buf(&dec->vkctx, exec, &dec->av_quant_val_buf, 1, 1); + if (err < 0) + goto fail; + + err = ff_vk_exec_add_dep_buf(&dec->vkctx, exec, &dec->av_quant_buf, 1, 1); + if (err < 0) + goto fail; + + err = ff_vk_exec_add_dep_buf(&dec->vkctx, exec, &dec->av_slice_buf, 1, 1); + if (err < 0) + goto fail; + + err = quant_pl_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); + if (err < 0) + goto fail; + + err = ff_vk_exec_add_dep_frame(&dec->vkctx, exec, pic->frame->avframe, + VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT, + VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT); + if (err < 0) + goto fail; + + err = + ff_vk_create_imageviews(&dec->vkctx, exec, views, pic->frame->avframe); + if (err < 0) + goto fail; + + switch (wavelet_idx) { + case DWT_DIRAC_DAUB9_7: + err = wavelet_daub97_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); + break; + + case DWT_DIRAC_FIDELITY: + err = wavelet_fidelity_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); + break; + + case DWT_DIRAC_DD9_7: + err = wavelet_dd97_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); + break; + + case DWT_DIRAC_DD13_7: + err = wavelet_dd137_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); + break; + + case DWT_DIRAC_LEGALL5_3: + err = wavelet_legall_pass(dec, ctx, exec, buf_bar, &nb_buf_bar); + break; + + case DWT_DIRAC_HAAR0: + err = wavelet_haari_pass(dec, ctx, exec, buf_bar, &nb_buf_bar, 0); + break; + + case DWT_DIRAC_HAAR1: + err = wavelet_haari_pass(dec, ctx, exec, buf_bar, &nb_buf_bar, 1); + break; + + default: + err = AVERROR_PATCHWELCOME; + break; + } + + err = cpy_to_image_pass(dec, ctx, exec, views, buf_bar, &nb_buf_bar, + img_bar, &nb_img_bar, (ctx->bit_depth - 8) >> 1); + if (err < 0) + goto fail; + + err = ff_vk_exec_submit(&dec->vkctx, exec); + if (err < 0) + goto fail; + + ff_vk_exec_wait(&dec->vkctx, exec); + + return 0; + +fail: + ff_vk_exec_discard_deps(&dec->vkctx, exec); + return err; +} + +static int vulkan_dirac_update_thread_context(AVCodecContext *dst, + const AVCodecContext *src) { + int err; + DiracVulkanDecodeContext *src_ctx = src->internal->hwaccel_priv_data; + DiracVulkanDecodeContext *dst_ctx = dst->internal->hwaccel_priv_data; + FFVkSPIRVCompiler *spv; + + spv = ff_vk_spirv_init(); + if (!spv) { + av_log(dst, AV_LOG_ERROR, "Unable to initialize SPIR-V compiler!\n"); + return AVERROR_EXTERNAL; + } + + memset(dst_ctx, 0, sizeof(DiracVulkanDecodeContext)); + + dst_ctx->vkctx = src_ctx->vkctx; + dst_ctx->sampler = src_ctx->sampler; + dst_ctx->qf = src_ctx->qf; + dst_ctx->exec_pool = src_ctx->exec_pool; + dst_ctx->quant_pl = src_ctx->quant_pl; + + err = init_quant_shd(dst_ctx, spv); + if (err < 0) { + goto fail; + } + + err = init_cpy_shd(dst_ctx, spv, 0); + if (err < 0) { + goto fail; + } + + err = init_cpy_shd(dst_ctx, spv, 1); + if (err < 0) { + goto fail; + } + + err = init_cpy_shd(dst_ctx, spv, 2); + if (err < 0) { + goto fail; + } + + err = wavelet_init(dst_ctx, spv); + if (err < 0) { + goto fail; + } + + dst_ctx->quant_val_buf_vk_ptr = NULL; + dst_ctx->slice_buf_vk_ptr = NULL; + dst_ctx->quant_buf_vk_ptr = NULL; + + dst_ctx->av_quant_val_buf = NULL; + dst_ctx->av_quant_buf = NULL; + dst_ctx->av_slice_buf = NULL; + + dst_ctx->thread_buf_size = 0; + dst_ctx->n_slice_bufs = 0; + + err = ff_vk_create_buf(&dst_ctx->vkctx, &dst_ctx->subband_info, + sizeof(SubbandOffset) * MAX_DWT_LEVELS * 12, NULL, + NULL, + VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT | + VK_BUFFER_USAGE_UNIFORM_BUFFER_BIT, + VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | + VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT); + if (err < 0) + return err; + + err = ff_vk_map_buffer(&dst_ctx->vkctx, &dst_ctx->subband_info, + (uint8_t **)&dst_ctx->subband_info_ptr, 0); + if (err < 0) + return err; + + return 0; + +fail: + if (spv) { + spv->uninit(&spv); + } + vulkan_dirac_uninit(dst); + + return err; +} + +static inline int decode_hq_slice(const DiracContext *s, int jobnr) { + int i, level, orientation, quant_idx; + DiracVulkanDecodeContext *dec = s->avctx->internal->hwaccel_priv_data; + int32_t *qfactor = &dec->quant_buf_vk_ptr[jobnr * 8 * MAX_DWT_LEVELS]; + int32_t *qoffset = &dec->quant_buf_vk_ptr[jobnr * 8 * MAX_DWT_LEVELS + 4]; + int32_t *quant_val_base = dec->quant_val_buf_vk_ptr; + DiracSlice *slice = &s->slice_params_buf[jobnr]; + SliceCoeffVk *slice_vk = &dec->slice_buf_vk_ptr[jobnr * 3 * MAX_DWT_LEVELS]; + GetBitContext *gb = &slice->gb; + + skip_bits_long(gb, 8 * s->highquality.prefix_bytes); + quant_idx = get_bits(gb, 8); + + if (quant_idx > DIRAC_MAX_QUANT_INDEX - 1) { + av_log(s->avctx, AV_LOG_ERROR, "Invalid quantization index - %i\n", + quant_idx); + return AVERROR_INVALIDDATA; + } + + /* Slice quantization (slice_quantizers() in the specs) */ + for (level = 0; level < s->wavelet_depth; level++) { + for (orientation = !!level; orientation < 4; orientation++) { + const int quant = + FFMAX(quant_idx - s->lowdelay.quant[level][orientation], 0); + qfactor[level * 8 + orientation] = ff_dirac_qscale_tab[quant]; + qoffset[level * 8 + orientation] = + ff_dirac_qoffset_intra_tab[quant] + 2; + } + } + + /* Luma + 2 Chroma planes */ + for (i = 0; i < 3; i++) { + int coef_num, coef_par; + int64_t length = s->highquality.size_scaler * get_bits(gb, 8); + int64_t bits_end = get_bits_count(gb) + 8 * length; + const uint8_t *addr = align_get_bits(gb); + int offs = dec->slice_vals_size * (3 * jobnr + i); + uint8_t *tmp_buf = (uint8_t *)&quant_val_base[offs]; + + if (length * 8 > get_bits_left(gb)) { + av_log(s->avctx, AV_LOG_ERROR, "end too far away\n"); + return AVERROR_INVALIDDATA; + } + + coef_num = subband_coeffs(s, slice->slice_x, slice->slice_y, i, offs, + &slice_vk[MAX_DWT_LEVELS * i]); + + coef_par = ff_dirac_golomb_read_32bit(addr, length, tmp_buf, coef_num); + + if (coef_num > coef_par) { + const int start_b = coef_par * sizeof(int32_t); + const int end_b = coef_num * sizeof(int32_t); + memset(&tmp_buf[start_b], 0, end_b - start_b); + } + + skip_bits_long(gb, bits_end - get_bits_count(gb)); + } + + return 0; +} + +static int decode_hq_slice_row(AVCodecContext *avctx, void *arg, int jobnr, + int threadnr) { + const DiracContext *s = avctx->priv_data; + int i, jobn = s->num_x * jobnr; + + for (i = 0; i < s->num_x; i++) { + decode_hq_slice(s, jobn); + jobn++; + } + + return 0; +} + +static int vulkan_dirac_decode_slice(AVCodecContext *avctx, const uint8_t *data, + uint32_t size) { + DiracContext *s = avctx->priv_data; + + /*avctx->execute2(avctx, decode_hq_slice_row, NULL, NULL, s->num_y);*/ + for (int i = 0; i < s->num_y; i++) { + decode_hq_slice_row(avctx, NULL, i, 0); + } + + return 0; +} + +const FFHWAccel ff_dirac_vulkan_hwaccel = { + .p.name = "dirac_vulkan", + .p.type = AVMEDIA_TYPE_VIDEO, + .p.id = AV_CODEC_ID_DIRAC, + .p.pix_fmt = AV_PIX_FMT_VULKAN, + .start_frame = &vulkan_dirac_start_frame, + .end_frame = &vulkan_dirac_end_frame, + .decode_slice = &vulkan_dirac_decode_slice, + .free_frame_priv = &vulkan_dirac_free_frame_priv, + .uninit = &vulkan_dirac_uninit, + .init = &vulkan_dirac_init, + .frame_params = &vulkan_dirac_frame_params, + .frame_priv_data_size = sizeof(DiracVulkanDecodePicture), + .decode_params = &ff_vk_params_invalidate, + .flush = &ff_vk_decode_flush, + .update_thread_context = &vulkan_dirac_update_thread_context, + .priv_data_size = sizeof(DiracVulkanDecodeContext), + // .caps_internal = HWACCEL_CAP_ASYNC_SAFE | HWACCEL_CAP_THREAD_SAFE, + .caps_internal = FF_CODEC_CAP_NOT_INIT_THREADSAFE, +};
This patch contains the code for the VC2 vulkan hwaccel, as well as changes to configure and makefiles needed to compile them. Signed-off-by: Petro Mozil <mozil.petryk@gmail.com> --- configure | 2 + libavcodec/Makefile | 1 + libavcodec/diracdec.c | 336 +--- libavcodec/diracdec.h | 263 +++ libavcodec/hwaccels.h | 1 + libavcodec/vulkan_dirac.c | 3817 +++++++++++++++++++++++++++++++++++++ 6 files changed, 4172 insertions(+), 248 deletions(-) create mode 100644 libavcodec/diracdec.h create mode 100644 libavcodec/vulkan_dirac.c