diff mbox series

[FFmpeg-devel,03/11] lavu/tx: add a 7-point FFT and (i)MDCT

Message ID MYfn4l3--3-2@lynne.ee
State Accepted
Commit bd9ea917a3eadc97377cffd0bc66c90dcd22a748
Headers show
Series lavu/tx: FFT improvements, additions and assembly
Related show

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Lynne April 19, 2021, 8:22 p.m. UTC
Patch attached.
Subject: [PATCH 03/11] lavu/tx: add a 7-point FFT and (i)MDCT

---
 libavutil/tx_template.c | 126 ++++++++++++++++++++++++++++++++++++----
 1 file changed, 116 insertions(+), 10 deletions(-)
diff mbox series

Patch

diff --git a/libavutil/tx_template.c b/libavutil/tx_template.c
index f78e7abfb1..2946c039be 100644
--- a/libavutil/tx_template.c
+++ b/libavutil/tx_template.c
@@ -40,6 +40,7 @@  COSTABLE(32768);
 COSTABLE(65536);
 COSTABLE(131072);
 DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_53))[4];
+DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_7))[3];
 
 static FFTSample * const cos_tabs[18] = {
     NULL,
@@ -103,9 +104,16 @@  static av_cold void ff_init_53_tabs(void)
     TX_NAME(ff_cos_53)[3] = (FFTComplex){ RESCALE(cos(2 * M_PI / 10)), RESCALE(sin(2 * M_PI / 10)) };
 }
 
+static av_cold void ff_init_7_tabs(void)
+{
+    TX_NAME(ff_cos_7)[0] = (FFTComplex){ RESCALE(cos(2 * M_PI /  7)), RESCALE(sin(2 * M_PI /  7)) };
+    TX_NAME(ff_cos_7)[1] = (FFTComplex){ RESCALE(sin(2 * M_PI / 28)), RESCALE(cos(2 * M_PI / 28)) };
+    TX_NAME(ff_cos_7)[2] = (FFTComplex){ RESCALE(cos(2 * M_PI / 14)), RESCALE(sin(2 * M_PI / 14)) };
+}
+
 static CosTabsInitOnce cos_tabs_init_once[] = {
     { ff_init_53_tabs, AV_ONCE_INIT },
-    { NULL },
+    { ff_init_7_tabs, AV_ONCE_INIT },
     { NULL },
     { NULL },
     { init_cos_tabs_16, AV_ONCE_INIT },
@@ -204,6 +212,93 @@  DECL_FFT5(fft5_m1,  0,  6, 12,  3,  9)
 DECL_FFT5(fft5_m2, 10,  1,  7, 13,  4)
 DECL_FFT5(fft5_m3,  5, 11,  2,  8, 14)
 
+static av_always_inline void fft7(FFTComplex *out, FFTComplex *in,
+                                  ptrdiff_t stride)
+{
+    FFTComplex t[6], z[3];
+    const FFTComplex *tab = TX_NAME(ff_cos_7);
+#ifdef TX_INT32
+    int64_t mtmp[12];
+#endif
+
+    BF(t[1].re, t[0].re, in[1].re, in[6].re);
+    BF(t[1].im, t[0].im, in[1].im, in[6].im);
+    BF(t[3].re, t[2].re, in[2].re, in[5].re);
+    BF(t[3].im, t[2].im, in[2].im, in[5].im);
+    BF(t[5].re, t[4].re, in[3].re, in[4].re);
+    BF(t[5].im, t[4].im, in[3].im, in[4].im);
+
+    out[0*stride].re = in[0].re + t[0].re + t[2].re + t[4].re;
+    out[0*stride].im = in[0].im + t[0].im + t[2].im + t[4].im;
+
+#ifdef TX_INT32 /* NOTE: it's possible to do this with 16 mults but 72 adds */
+    mtmp[ 0] = ((int64_t)tab[0].re)*t[0].re - ((int64_t)tab[2].re)*t[4].re;
+    mtmp[ 1] = ((int64_t)tab[0].re)*t[4].re - ((int64_t)tab[1].re)*t[0].re;
+    mtmp[ 2] = ((int64_t)tab[0].re)*t[2].re - ((int64_t)tab[2].re)*t[0].re;
+    mtmp[ 3] = ((int64_t)tab[0].re)*t[0].im - ((int64_t)tab[1].re)*t[2].im;
+    mtmp[ 4] = ((int64_t)tab[0].re)*t[4].im - ((int64_t)tab[1].re)*t[0].im;
+    mtmp[ 5] = ((int64_t)tab[0].re)*t[2].im - ((int64_t)tab[2].re)*t[0].im;
+
+    mtmp[ 6] = ((int64_t)tab[2].im)*t[1].im + ((int64_t)tab[1].im)*t[5].im;
+    mtmp[ 7] = ((int64_t)tab[0].im)*t[5].im + ((int64_t)tab[2].im)*t[3].im;
+    mtmp[ 8] = ((int64_t)tab[2].im)*t[5].im + ((int64_t)tab[1].im)*t[3].im;
+    mtmp[ 9] = ((int64_t)tab[0].im)*t[1].re + ((int64_t)tab[1].im)*t[3].re;
+    mtmp[10] = ((int64_t)tab[2].im)*t[3].re + ((int64_t)tab[0].im)*t[5].re;
+    mtmp[11] = ((int64_t)tab[2].im)*t[1].re + ((int64_t)tab[1].im)*t[5].re;
+
+    z[0].re = (int32_t)(mtmp[ 0] - ((int64_t)tab[1].re)*t[2].re + 0x40000000 >> 31);
+    z[1].re = (int32_t)(mtmp[ 1] - ((int64_t)tab[2].re)*t[2].re + 0x40000000 >> 31);
+    z[2].re = (int32_t)(mtmp[ 2] - ((int64_t)tab[1].re)*t[4].re + 0x40000000 >> 31);
+    z[0].im = (int32_t)(mtmp[ 3] - ((int64_t)tab[2].re)*t[4].im + 0x40000000 >> 31);
+    z[1].im = (int32_t)(mtmp[ 4] - ((int64_t)tab[2].re)*t[2].im + 0x40000000 >> 31);
+    z[2].im = (int32_t)(mtmp[ 5] - ((int64_t)tab[1].re)*t[4].im + 0x40000000 >> 31);
+
+    t[0].re = (int32_t)(mtmp[ 6] - ((int64_t)tab[0].im)*t[3].im + 0x40000000 >> 31);
+    t[2].re = (int32_t)(mtmp[ 7] - ((int64_t)tab[1].im)*t[1].im + 0x40000000 >> 31);
+    t[4].re = (int32_t)(mtmp[ 8] + ((int64_t)tab[0].im)*t[1].im + 0x40000000 >> 31);
+    t[0].im = (int32_t)(mtmp[ 9] + ((int64_t)tab[2].im)*t[5].re + 0x40000000 >> 31);
+    t[2].im = (int32_t)(mtmp[10] - ((int64_t)tab[1].im)*t[1].re + 0x40000000 >> 31);
+    t[4].im = (int32_t)(mtmp[11] - ((int64_t)tab[0].im)*t[3].re + 0x40000000 >> 31);
+#else
+    z[0].re = tab[0].re*t[0].re - tab[2].re*t[4].re - tab[1].re*t[2].re;
+    z[1].re = tab[0].re*t[4].re - tab[1].re*t[0].re - tab[2].re*t[2].re;
+    z[2].re = tab[0].re*t[2].re - tab[2].re*t[0].re - tab[1].re*t[4].re;
+    z[0].im = tab[0].re*t[0].im - tab[1].re*t[2].im - tab[2].re*t[4].im;
+    z[1].im = tab[0].re*t[4].im - tab[1].re*t[0].im - tab[2].re*t[2].im;
+    z[2].im = tab[0].re*t[2].im - tab[2].re*t[0].im - tab[1].re*t[4].im;
+
+    /* It's possible to do t[4].re and t[0].im with 2 multiplies only by
+     * multiplying the sum of all with the average of the twiddles */
+
+    t[0].re = tab[2].im*t[1].im + tab[1].im*t[5].im - tab[0].im*t[3].im;
+    t[2].re = tab[0].im*t[5].im + tab[2].im*t[3].im - tab[1].im*t[1].im;
+    t[4].re = tab[2].im*t[5].im + tab[1].im*t[3].im + tab[0].im*t[1].im;
+    t[0].im = tab[0].im*t[1].re + tab[1].im*t[3].re + tab[2].im*t[5].re;
+    t[2].im = tab[2].im*t[3].re + tab[0].im*t[5].re - tab[1].im*t[1].re;
+    t[4].im = tab[2].im*t[1].re + tab[1].im*t[5].re - tab[0].im*t[3].re;
+#endif
+
+    BF(t[1].re, z[0].re, z[0].re, t[4].re);
+    BF(t[3].re, z[1].re, z[1].re, t[2].re);
+    BF(t[5].re, z[2].re, z[2].re, t[0].re);
+    BF(t[1].im, z[0].im, z[0].im, t[0].im);
+    BF(t[3].im, z[1].im, z[1].im, t[2].im);
+    BF(t[5].im, z[2].im, z[2].im, t[4].im);
+
+    out[1*stride].re = in[0].re + z[0].re;
+    out[1*stride].im = in[0].im + t[1].im;
+    out[2*stride].re = in[0].re + t[3].re;
+    out[2*stride].im = in[0].im + z[1].im;
+    out[3*stride].re = in[0].re + z[2].re;
+    out[3*stride].im = in[0].im + t[5].im;
+    out[4*stride].re = in[0].re + t[5].re;
+    out[4*stride].im = in[0].im + z[2].im;
+    out[5*stride].re = in[0].re + z[1].re;
+    out[5*stride].im = in[0].im + t[3].im;
+    out[6*stride].re = in[0].re + t[1].re;
+    out[6*stride].im = in[0].im + z[0].im;
+}
+
 static av_always_inline void fft15(FFTComplex *out, FFTComplex *in,
                                    ptrdiff_t stride)
 {
@@ -376,6 +471,7 @@  static void compound_fft_##N##xM(AVTXContext *s, void *_out,                   \
 
 DECL_COMP_FFT(3)
 DECL_COMP_FFT(5)
+DECL_COMP_FFT(7)
 DECL_COMP_FFT(15)
 
 static void split_radix_fft(AVTXContext *s, void *_out, void *_in,
@@ -473,6 +569,7 @@  static void compound_imdct_##N##xM(AVTXContext *s, void *_dst, void *_src,     \
 
 DECL_COMP_IMDCT(3)
 DECL_COMP_IMDCT(5)
+DECL_COMP_IMDCT(7)
 DECL_COMP_IMDCT(15)
 
 #define DECL_COMP_MDCT(N)                                                      \
@@ -521,6 +618,7 @@  static void compound_mdct_##N##xM(AVTXContext *s, void *_dst, void *_src,      \
 
 DECL_COMP_MDCT(3)
 DECL_COMP_MDCT(5)
+DECL_COMP_MDCT(7)
 DECL_COMP_MDCT(15)
 
 static void monolithic_imdct(AVTXContext *s, void *_dst, void *_src,
@@ -675,6 +773,7 @@  int TX_NAME(ff_tx_init_mdct_fft)(AVTXContext *s, av_tx_fn *tx,
         SRC /= FACTOR;                                                         \
     }
     CHECK_FACTOR(n, 15, len)
+    CHECK_FACTOR(n,  7, len)
     CHECK_FACTOR(n,  5, len)
     CHECK_FACTOR(n,  3, len)
 #undef CHECK_FACTOR
@@ -714,22 +813,29 @@  int TX_NAME(ff_tx_init_mdct_fft)(AVTXContext *s, av_tx_fn *tx,
             return err;
         if (!(s->tmp = av_malloc(n*m*sizeof(*s->tmp))))
             return AVERROR(ENOMEM);
-        *tx = n == 3 ? compound_fft_3xM :
-              n == 5 ? compound_fft_5xM :
-                       compound_fft_15xM;
-        if (is_mdct)
-            *tx = n == 3 ? inv ? compound_imdct_3xM  : compound_mdct_3xM :
-                  n == 5 ? inv ? compound_imdct_5xM  : compound_mdct_5xM :
-                           inv ? compound_imdct_15xM : compound_mdct_15xM;
+        if (!(m & (m - 1))) {
+            *tx = n == 3 ? compound_fft_3xM :
+                  n == 5 ? compound_fft_5xM :
+                  n == 7 ? compound_fft_7xM :
+                           compound_fft_15xM;
+            if (is_mdct)
+                *tx = n == 3 ? inv ? compound_imdct_3xM  : compound_mdct_3xM :
+                      n == 5 ? inv ? compound_imdct_5xM  : compound_mdct_5xM :
+                      n == 7 ? inv ? compound_imdct_7xM  : compound_mdct_7xM :
+                               inv ? compound_imdct_15xM : compound_mdct_15xM;
+        }
     } else { /* Direct transform case */
         *tx = split_radix_fft;
         if (is_mdct)
             *tx = inv ? monolithic_imdct : monolithic_mdct;
     }
 
-    if (n != 1)
+    if (n == 3 || n == 5 || n == 15)
         init_cos_tabs(0);
-    if (m != 1) {
+    else if (n == 7)
+        init_cos_tabs(1);
+
+    if (m != 1 && !(m & (m - 1))) {
         if ((err = ff_tx_gen_ptwo_revtab(s, n == 1 && !is_mdct && !(flags & AV_TX_INPLACE))))
             return err;
         if (flags & AV_TX_INPLACE) {