diff mbox series

[FFmpeg-devel] x86/tx_float: AVX2 SIMD for R2C and C2R RDFTs

Message ID NoSuodr--3-9@lynne.ee
State New
Headers show
Series [FFmpeg-devel] x86/tx_float: AVX2 SIMD for R2C and C2R RDFTs | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Lynne Jan. 18, 2024, 6:52 p.m. UTC
Adds full assembly for R2C and C2R transforms

R2C Before:
145370 decicycles in           av_tx (r2c),  131072 runs,      0 skips
R2C After:
56897 decicycles in           av_tx (r2c),  131072 runs,      0 skips

C2R Before:
140958 decicycles in           av_tx (c2r),  131071 runs,      1 skips
C2R After:
50427 decicycles in           av_tx (c2r),  131061 runs,     11 skips

C2R does an in-place scatter for the FFT.
R2C could be made a little faster by adding an assembly-only
version of the regular lookup-enabled FFT. In theory, may only
help for really large transforms.
diff mbox series

Patch

From f5281404f5789498b854d33f808c133820540281 Mon Sep 17 00:00:00 2001
From: Lynne <dev@lynne.ee>
Date: Mon, 11 Dec 2023 13:29:21 +0100
Subject: [PATCH] x86/tx_float: AVX2 SIMD for R2C and C2R RDFTs

R2C Before:
145370 decicycles in           av_tx (r2c),  131072 runs,      0 skips
R2C After:
56897 decicycles in           av_tx (r2c),  131072 runs,      0 skips

C2R Before:
140958 decicycles in           av_tx (c2r),  131071 runs,      1 skips
C2R After:
50427 decicycles in           av_tx (c2r),  131061 runs,     11 skips

C2R does an in-place scatter for the FFT.
R2C could be made a little faster by adding an assembly-only
version of the regular lookup-enabled FFT.

This also adds a small optimization to the C version
and makes it more in-line with what the assembly does.
---
 libavutil/tx_template.c       |  20 ++-
 libavutil/x86/tx_float.asm    | 226 ++++++++++++++++++++++++++++++++++
 libavutil/x86/tx_float_init.c |  65 ++++++++++
 tests/checkasm/av_tx.c        |   7 ++
 4 files changed, 306 insertions(+), 12 deletions(-)

diff --git a/libavutil/tx_template.c b/libavutil/tx_template.c
index a2c27465cb..08a0243a2e 100644
--- a/libavutil/tx_template.c
+++ b/libavutil/tx_template.c
@@ -1647,13 +1647,10 @@  static av_cold int TX_NAME(ff_tx_rdft_init)(AVTXContext *s,
     *tab++ = RESCALE( (0.5 - inv) * m);
     *tab++ = RESCALE(-(0.5 - inv) * m);
 
-    for (int i = 0; i < len4; i++)
+    for (int i = 0; i < len4; i++) {
         *tab++ = RESCALE(cos(i*f));
-
-    tab = ((TXSample *)s->exp) + len4 + 8;
-
-    for (int i = 0; i < len4; i++)
         *tab++ = RESCALE(cos(((len - i*4)/4.0)*f)) * (inv ? 1 : -1);
+    }
 
     return 0;
 }
@@ -1665,8 +1662,7 @@  static void TX_NAME(ff_tx_rdft_ ##n)(AVTXContext *s, void *_dst,               \
     const int len2 = s->len >> 1;                                              \
     const int len4 = s->len >> 2;                                              \
     const TXSample *fact = (void *)s->exp;                                     \
-    const TXSample *tcos = fact + 8;                                           \
-    const TXSample *tsin = tcos + len4;                                        \
+    const TXSample *exp = fact + 8;                                            \
     TXComplex *data = inv ? _src : _dst;                                       \
     TXComplex t[3];                                                            \
                                                                                \
@@ -1688,18 +1684,18 @@  static void TX_NAME(ff_tx_rdft_ ##n)(AVTXContext *s, void *_dst,               \
                                                                                \
     for (int i = 1; i < len4; i++) {                                           \
         /* Separate even and odd FFTs */                                       \
-        t[0].re = MULT(fact[4], (data[i].re + data[len2 - i].re));             \
         t[0].im = MULT(fact[5], (data[i].im - data[len2 - i].im));             \
-        t[1].re = MULT(fact[6], (data[i].im + data[len2 - i].im));             \
+        t[0].re = MULT(fact[4], (data[i].re + data[len2 - i].re));             \
         t[1].im = MULT(fact[7], (data[i].re - data[len2 - i].re));             \
+        t[1].re = MULT(fact[6], (data[i].im + data[len2 - i].im));             \
                                                                                \
         /* Apply twiddle factors to the odd FFT and add to the even FFT */     \
-        CMUL(t[2].re, t[2].im, t[1].re, t[1].im, tcos[i], tsin[i]);            \
+        CMUL(t[2].re, t[2].im, t[1].re, t[1].im, exp[i*2 + 0], exp[i*2 + 1]);  \
                                                                                \
-        data[       i].re = t[0].re + t[2].re;                                 \
+        data[       i].re = t[2].re + t[0].re;                                 \
         data[       i].im = t[2].im - t[0].im;                                 \
         data[len2 - i].re = t[0].re - t[2].re;                                 \
-        data[len2 - i].im = t[2].im + t[0].im;                                 \
+        data[len2 - i].im = t[0].im + t[2].im;                                 \
     }                                                                          \
                                                                                \
     if (inv) {                                                                 \
diff --git a/libavutil/x86/tx_float.asm b/libavutil/x86/tx_float.asm
index e1533a8595..11d5e946db 100644
--- a/libavutil/x86/tx_float.asm
+++ b/libavutil/x86/tx_float.asm
@@ -91,6 +91,11 @@  s16_perm:      dd 0, 1, 2, 3, 1, 0, 3, 2
 
 s15_perm:      dd 0, 6, 5, 3, 2, 4, 7, 1
 
+rdft_perm_pos: dd 3, 2, 2, 3, 1, 0, 0, 1
+rdft_perm_neg: dd 1, 0, 0, 1, 3, 2, 2, 3
+rdft_perm_exp: dd 1, 1, 0, 0, 3, 3, 2, 2
+rdft_m11:      times 2 dd 0x0, 0x0, 0x3f800000, 0x3f800000 ; 0, 0, 1, 1
+
 mask_mmppmmmm: dd NEG, NEG, POS, POS, NEG, NEG, NEG, NEG
 mask_mmmmpppm: dd NEG, NEG, NEG, NEG, POS, POS, POS, NEG
 mask_ppmpmmpm: dd POS, POS, NEG, POS, NEG, NEG, POS, NEG
@@ -1934,3 +1939,224 @@  cglobal fft_pfa_15xM_ns_float, 4, 14, 16, 320, ctx, out, in, stride, len, lut, b
 PFA_15_FN avx2, 0
 PFA_15_FN avx2, 1
 %endif
+
+%macro RDFT_CONV_LOAD 0
+    movaps m10, [rdft_perm_neg]
+    movaps m11, [rdft_perm_pos]
+    vbroadcastf128 m12, [expq + 4*4]                ; fact[5476]
+    movaps m13, [mask_pmmppmmp]                     ; +--+
+
+    movaps m14, [rdft_m11]                          ; 0.0, 0.0, 1.0, 1.0
+    movaps m15, [rdft_perm_exp]
+%endmacro
+
+; %1 - source, front
+; %2 - source, rear
+; Results are left in m0 (front) and m2 (rear)
+%macro RDFT_CONV_ITER 2
+    movups m8, [expq + (8 + 2)*4]
+
+    vperm2f128 m4, m8, m8, 0x00                     ; cos,sin,cos,sin x2
+    vperm2f128 m6, m8, m8, 0x11                     ; cos,sin,cos,sin x2
+
+    vpermilps m5, m4, m15                           ; cos1,cos1,cos1,cos1,sin2,sin2,cos2,cos2
+    vpermilps m7, m6, m15                           ; cos1,cos1,cos1,cos1,sin2,sin2,cos2,cos2
+
+    shufpd m4, m14, m5, 1111b                       ; 1,1,cos1,cos1,1,1,cos2,cos2
+    shufpd m5, m14, m5, 0000b                       ; 0,0,sin1,sin1,0,0,sin2,sin2
+    shufpd m6, m14, m7, 1111b                       ; 1,1,cos1,cos1,1,1,cos2,cos2
+    shufpd m7, m14, m7, 0000b                       ; 0,0,sin1,sin1,0,0,sin2,sin2
+
+    movups m2, [%1q]
+    movups m3, [%2q]
+
+    vperm2f128 m0, m2, m2, 0x00
+    vperm2f128 m1, m3, m3, 0x11
+    vperm2f128 m2, m2, m2, 0x11
+    vperm2f128 m3, m3, m3, 0x00
+
+    vpermilps m0, m0, m10                           ; data[0].imrereim, data[1].imrereim
+    vpermilps m1, m1, m11                           ; data[len - 01].imrereim
+    vpermilps m2, m2, m10                           ; data[0].imrereim, data[1].imrereim
+    vpermilps m3, m3, m11                           ; data[len - 01].imrereim
+
+    addsubps m0, m1                                 ; data[0] - data[len - 0] x2
+    addsubps m2, m3                                 ; data[0] - data[len - 0] x2
+
+    mulps m0, m12                                   ; t[01].imre
+    mulps m2, m12                                   ; t[01].imre
+
+    shufps m1, m0, m0, q2301                        ; t[01].reim
+    shufps m3, m2, m2, q2301                        ; t[01].reim
+
+    mulps m1, m4                                    ; 1, 1, tcos, tcos x2
+    mulps m0, m5                                    ; 0, 0, tsin, tsin x2
+    mulps m3, m6                                    ; 1, 1, tcos, tcos x2
+    mulps m2, m7                                    ; 0, 0, tsin, tsin x2
+
+    addsubps m1, m0                                 ; t[02].reim
+    addsubps m3, m2                                 ; t[02].reim
+
+    shufpd m0, m1, m1, 0101b                        ; t[20].reim
+    shufpd m2, m3, m3, 0101b                        ; t[20].reim
+
+    xorps m1, m13                                   ; +--+t[02].reim
+    xorps m3, m13                                   ; +--+t[02].reim
+
+    addps m0, m1                                    ; data[0].reim, data[len2 - 0].reim x2
+    addps m2, m3                                    ; data[0].reim, data[len2 - 0].reim x2
+
+    shufpd m1, m0, m2, 0000b                        ; high
+    shufpd m3, m0, m2, 1111b                        ; low
+
+    vpermpd m0, m1, q3120
+    vpermpd m2, m3, q0213
+%endmacro
+
+%macro RDFT_R2C 1
+INIT_YMM %1
+cglobal rdft_r2c_float, 4, 14, 16, 320, ctx, out, in, stride, len, lut, exp, t1, t2, t3, \
+                                        t4, t5, btmp
+    ; FFT setup
+    mov btmpq, ctxq                                 ; backup original context
+    mov t3q, [ctxq + AVTXContext.fn]                ; subtransform's jump point
+
+    mov ctxq, [ctxq + AVTXContext.sub]              ; load subtransform's context
+    mov lutq, [ctxq + AVTXContext.map]              ; load subtransform's map
+    movsxd lenq, dword [ctxq + AVTXContext.len]     ; load subtransform's length
+
+    mov expq, outq
+.preshuf:
+    LOAD64_LUT m0, inq, lutq, 0, t4q, m1, m2
+    movaps [outq], m0
+    add outq, mmsize
+    add lutq, (mmsize/2)
+    sub lenq, (mmsize/8)
+    jg .preshuf
+
+    mov outq, expq
+    mov inq, expq
+    movsxd lenq, dword [ctxq + AVTXContext.len]     ; load subtransform's length
+
+    call t3q                                        ; call the FFT
+
+    mov ctxq, btmpq                                 ; restore original context
+
+    movsxd lenq, dword [ctxq + AVTXContext.len]
+    mov expq, [ctxq + AVTXContext.exp]
+
+    movsd  xm0, [outq]                              ; data[0].reim
+    movhps xm0, [outq + lenq*2]                     ; data[len4].reim
+
+    shufps xm1, xm0, xm0, q2301                     ; data[0].imre, junk
+    addsubps xm2, xm1, xm0                          ; t[0].imre, junk
+    shufps xm1, xm2, xm1, q2301                     ; t[0].reim, data[len4].reim
+    mulps xm9, xm1, [expq]                          ; data[0,len4].reim
+
+    mov inq, outq
+    lea t1q, [outq + lenq*4 - mmsize]
+    mov t2q, lenq
+    add outq, 8
+
+    ; Perform in-place RDFT conversion
+    RDFT_CONV_LOAD
+.loop:
+    RDFT_CONV_ITER out, t1
+    movups [outq], m0
+    movups [t1q],  m2
+
+    add expq, mmsize
+    add outq, mmsize
+    sub t1q,  mmsize
+
+    sub t2q, mmsize/2
+    jg .loop
+
+    ; Write DC, middle and tail
+    movhps [inq + lenq*2], xm9
+    xorps xm0, xm0
+    shufps xm9, xm9, xm0, q3210
+    shufps xm8, xm9, xm9, q2120
+    movsd [inq], xm8
+    movhps [inq + lenq*4], xm8
+
+    RET
+%endmacro
+
+%macro RDFT_C2R 1
+INIT_YMM %1
+cglobal rdft_c2r_float, 4, 14, 16, 320, ctx, out, in, stride, len, lut, exp, t1, t2, t3, \
+                                        t4, t5, btmp
+    movsxd lenq, dword [ctxq + AVTXContext.len]
+    mov expq, [ctxq + AVTXContext.exp]
+    mov btmpq, [ctxq + AVTXContext.fn]             ; subtransform's jump point
+
+    mov ctxq, [ctxq + AVTXContext.sub]             ; load subtransform's context
+    mov lutq, [ctxq + AVTXContext.map]             ; load subtransform's map
+
+    movss xm0, [inq]                               ; data[0].re
+    insertps xm0, [inq + lenq*4], 0b00_01_0000     ; src, dst, zero flags
+    movhps xm0, [inq + lenq*2]                     ; data[0,len4]
+
+    shufps xm1, xm0, xm0, q2301                    ; data[0].imre, junk
+    addsubps xm2, xm1, xm0                         ; t[0].imre, junk
+    shufps xm1, xm2, xm1, q2301                    ; t[0].reim, data[len4].reim
+    mulps xm9, xm1, [expq]                         ; data[0,len4].reim
+
+    mov t1q, lenq
+    lea t2q, [inq + lenq*4 - mmsize]
+    lea inq, [inq + 8]
+
+    lea t4q, [lutq + 4]
+    lea t5q, [lutq + lenq*2 - 4*4]
+
+    RDFT_CONV_LOAD                                 ; Perform in-place RDFT conversion
+.loop:
+    RDFT_CONV_ITER in, t2
+
+    vextractf128 xm1, m0, 1
+    vextractf128 xm3, m2, 1
+
+    movsxd t3q, dword [t4q + 4*0]
+    movlpd [outq + 8*t3q], xm0
+    movsxd t3q, dword [t4q + 4*1]
+    movhpd [outq + 8*t3q], xm0
+    movsxd t3q, dword [t4q + 4*2]
+    movlpd [outq + 8*t3q], xm1
+    movsxd t3q, dword [t4q + 4*3]
+    movhpd [outq + 8*t3q], xm1
+
+    movsxd t3q, dword [t5q + 4*0]
+    movlpd [outq + 8*t3q], xm2
+    movsxd t3q, dword [t5q + 4*1]
+    movhpd [outq + 8*t3q], xm2
+    movsxd t3q, dword [t5q + 4*2]
+    movlpd [outq + 8*t3q], xm3
+    movsxd t3q, dword [t5q + 4*3]
+    movhpd [outq + 8*t3q], xm3
+
+    add expq, mmsize
+    add inq,  mmsize
+    sub t2q,  mmsize
+    add t4q, 4*4
+    sub t5q, 4*4
+
+    sub t1q, mmsize/2
+    jg .loop
+
+    movsxd t3q, dword [lutq + 0]
+    movsd [outq + 8*t3q], xm9
+    movsxd t3q, dword [lutq + lenq]
+    movhps [outq + 8*t3q], xm9
+
+    mov inq, outq
+    movsxd lenq, dword [ctxq + AVTXContext.len]     ; load subtransform's length
+    call btmpq                                      ; call the FFT
+
+    RET
+%endmacro
+
+%if ARCH_X86_64 && HAVE_AVX2_EXTERNAL
+RDFT_R2C avx2
+RDFT_C2R avx2
+%endif
diff --git a/libavutil/x86/tx_float_init.c b/libavutil/x86/tx_float_init.c
index d3c0beb50f..2f3d7899a9 100644
--- a/libavutil/x86/tx_float_init.c
+++ b/libavutil/x86/tx_float_init.c
@@ -52,6 +52,9 @@  TX_DECL_FN(fft_pfa_15xM_ns, avx2)
 
 TX_DECL_FN(mdct_inv, avx2)
 
+TX_DECL_FN(rdft_r2c, avx2)
+TX_DECL_FN(rdft_c2r, avx2)
+
 TX_DECL_FN(fft2_asm, sse3)
 TX_DECL_FN(fft4_fwd_asm, sse2)
 TX_DECL_FN(fft4_inv_asm, sse2)
@@ -167,6 +170,63 @@  static av_cold int m_inv_init(AVTXContext *s, const FFTXCodelet *cd,
     return 0;
 }
 
+static av_cold int rdft_init(AVTXContext *s, const FFTXCodelet *cd,
+                             uint64_t flags, FFTXCodeletOptions *opts,
+                             int len, int inv, const void *scale)
+{
+    int ret;
+    double f, m;
+    TXSample *tab;
+    uint64_t r2r = flags & AV_TX_REAL_TO_REAL;
+    int len4 = FFALIGN(len, 4) / 4;
+    FFTXCodeletOptions sub_opts = { .map_dir = inv ? FF_TX_MAP_SCATTER : FF_TX_MAP_GATHER };
+
+    s->scale_d = *((SCALE_TYPE *)scale);
+    s->scale_f = s->scale_d;
+
+    flags &= ~(AV_TX_REAL_TO_REAL | AV_TX_REAL_TO_IMAGINARY);
+    flags |=  FF_TX_PRESHUFFLE;   /* This function handles the permute step */
+    flags |=  AV_TX_INPLACE;      /* in-place */
+    flags |=  FF_TX_ASM_CALL;     /* We want an assembly function, not C */
+
+    if ((ret = ff_tx_init_subtx(s, TX_TYPE(FFT), flags, &sub_opts,
+                                len >> 1, inv, scale)))
+        return ret;
+
+    if (!(s->exp = av_mallocz((8 + 2*len4)*sizeof(*s->exp))))
+        return AVERROR(ENOMEM);
+
+    if (!(s->tmp = av_malloc(len*sizeof(*s->tmp))))
+        return AVERROR(ENOMEM);
+
+    tab = (TXSample *)s->exp;
+
+    f = 2*M_PI/len;
+
+    m = (inv ? 2*s->scale_d : s->scale_d);
+
+    *tab++ =  RESCALE((inv ? 0.5 : 1.0) * m);
+    *tab++ = -RESCALE(inv ? 0.5*m : 1.0*m);
+    *tab++ =  RESCALE( m);
+    *tab++ =  RESCALE(-m);
+
+    if (r2r)
+        *tab++ = 1 / s->scale_f;
+    else
+        *tab++ = RESCALE( (0.0 - 0.5) * m);
+    *tab++ = RESCALE( (0.5 - 0.0) * m);
+
+    *tab++ = RESCALE(-(0.5 - inv) * m);
+    *tab++ = RESCALE( (0.5 - inv) * m);
+
+    for (int i = 0; i < len4; i++) {
+        *tab++ = RESCALE(cos(i*f));
+        *tab++ = RESCALE(cos(((len - i*4)/4.0)*f)) * (inv ? 1 : -1);
+    }
+
+    return 0;
+}
+
 static av_cold int fft_pfa_init(AVTXContext *s,
                                 const FFTXCodelet *cd,
                                 uint64_t flags,
@@ -303,6 +363,11 @@  const FFTXCodelet * const ff_tx_codelet_list_float_x86[] = {
 
     TX_DEF(mdct_inv, MDCT, 16, TX_LEN_UNLIMITED, 2, TX_FACTOR_ANY, 384, m_inv_init, avx2, AVX2,
            FF_TX_INVERSE_ONLY, AV_CPU_FLAG_AVXSLOW | AV_CPU_FLAG_SLOW_GATHER),
+
+    TX_DEF(rdft_r2c, RDFT, 16, TX_LEN_UNLIMITED, 2, TX_FACTOR_ANY, 384, rdft_init, avx2, AVX2,
+           FF_TX_FORWARD_ONLY, AV_CPU_FLAG_AVXSLOW | AV_CPU_FLAG_SLOW_GATHER),
+    TX_DEF(rdft_c2r, RDFT, 16, TX_LEN_UNLIMITED, 2, TX_FACTOR_ANY, 384, rdft_init, avx2, AVX2,
+           FF_TX_INVERSE_ONLY, AV_CPU_FLAG_AVXSLOW | AV_CPU_FLAG_SLOW_GATHER),
 #endif
 #endif
 
diff --git a/tests/checkasm/av_tx.c b/tests/checkasm/av_tx.c
index aa8fc6b4e9..676c39ed86 100644
--- a/tests/checkasm/av_tx.c
+++ b/tests/checkasm/av_tx.c
@@ -43,6 +43,10 @@  static const int check_lens[] = {
     2, 4, 8, 16, 32, 64, 120, 960, 1024, 1920, 16384,
 };
 
+static const int rdft_check_lens[] = {
+    32, 1024,
+};
+
 static AVTXContext *tx_refs[AV_TX_NB][2 /* Direction */][FF_ARRAY_ELEMS(check_lens)] = { 0 };
 static int init = 0;
 
@@ -113,6 +117,9 @@  void checkasm_check_av_tx(void)
     CHECK_TEMPLATE("float_imdct", AV_TX_FLOAT_MDCT, 1, float, float, check_lens,
                    !float_near_abs_eps_array(out_ref, out_new, EPS, len));
 
+    CHECK_TEMPLATE("float_r2c", AV_TX_FLOAT_RDFT, 0, float, float, rdft_check_lens,
+                   !float_near_abs_eps_array(out_ref, out_new, EPS, len));
+
     randomize_complex(in, 16384, AVComplexDouble, SCALE_NOOP);
     CHECK_TEMPLATE("double_fft", AV_TX_DOUBLE_FFT, 0, AVComplexDouble, double, check_lens,
                    !double_near_abs_eps_array(out_ref, out_new, EPS, len*2));
-- 
2.43.0