
[FFmpeg-devel,v2] mdct15: add inverse transform postrotation SIMD

Message ID 20170730004816.15799-1-atomnuker@gmail.com
State New

Commit Message

Rostislav Pehlivanov July 30, 2017, 12:48 a.m. UTC
Speeds up decoding by 8% in total in the avx2 case.

20ms frames:
Before   (c):  17774 decicycles in postrotate,  262065 runs,     79 skips
After (sse3):   9624 decicycles in postrotate,  262113 runs,     31 skips
After (avx2):   7169 decicycles in postrotate,  262104 runs,     40 skips

10ms frames:
Before   (c):   9058 decicycles in postrotate,  524209 runs,     79 skips
After (sse3):   4964 decicycles in postrotate,  524236 runs,     52 skips
After (avx2):   3915 decicycles in postrotate,  524236 runs,     52 skips

5ms frames:
Before   (c):   4764 decicycles in postrotate, 1048466 runs,    110 skips
After (sse3):   2670 decicycles in postrotate, 1048507 runs,     69 skips
After (avx2):   2161 decicycles in postrotate, 1048515 runs,     61 skips

2.5ms frames:
Before   (c):   2608 decicycles in postrotate, 2097030 runs,    122 skips
After (sse3):   1507 decicycles in postrotate, 2097089 runs,     63 skips
After (avx2):   1377 decicycles in postrotate, 2097097 runs,     55 skips

Needs to overwrite the start of some buffers as well as the
end of them, hence the OVERALLOC stuff.

Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>
---
 libavcodec/mdct15.c          | 74 ++++++++++++++++++++++++-----------
 libavcodec/mdct15.h          |  3 ++
 libavcodec/x86/mdct15.asm    | 93 +++++++++++++++++++++++++++++++++++++++++++-
 libavcodec/x86/mdct15_init.c |  9 +++++
 4 files changed, 155 insertions(+), 24 deletions(-)
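
For illustration, a minimal sketch of what the OVERALLOC/OVERFREEP macros added
here boil down to; the helper names over_alloc()/over_freep() are placeholders
for this sketch, not part of the patch:

    #include "libavutil/mem.h"
    #include "avcodec.h"    /* AV_INPUT_BUFFER_PADDING_SIZE */

    /* Allocate len elements of 'size' bytes plus AV_INPUT_BUFFER_PADDING_SIZE
     * of slack, and return a pointer advanced past the slack, so vector code
     * may safely store a little before the nominal start of the buffer. */
    static void *over_alloc(size_t len, size_t size)
    {
        const size_t pad = AV_INPUT_BUFFER_PADDING_SIZE / size;
        uint8_t *raw = av_mallocz_array(len + pad, size);
        return raw ? raw + AV_INPUT_BUFFER_PADDING_SIZE : NULL;
    }

    /* Step back to the true allocation start before freeing, then clear. */
    static void over_freep(void *arg)
    {
        uint8_t **ptr = arg;
        if (*ptr)
            av_free(*ptr - AV_INPUT_BUFFER_PADDING_SIZE);
        *ptr = NULL;
    }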

Comments

James Almer July 30, 2017, 1:30 a.m. UTC | #1
On 7/29/2017 9:48 PM, Rostislav Pehlivanov wrote:
> Speeds up decoding by 8% in total in the avx2 case.
> 
> 20ms frames:
> Before   (c):  17774 decicycles in postrotate,  262065 runs,     79 skips
> After (sse3):   9624 decicycles in postrotate,  262113 runs,     31 skips
> After (avx2):   7169 decicycles in postrotate,  262104 runs,     40 skips
> 
> 10ms frames:
> Before   (c):   9058 decicycles in postrotate,  524209 runs,     79 skips
> After (sse3):   4964 decicycles in postrotate,  524236 runs,     52 skips
> After (avx2):   3915 decicycles in postrotate,  524236 runs,     52 skips
> 
> 5ms frames:
> Before   (c):   4764 decicycles in postrotate, 1048466 runs,    110 skips
> After (sse3):   2670 decicycles in postrotate, 1048507 runs,     69 skips
> After (avx2):   2161 decicycles in postrotate, 1048515 runs,     61 skips
> 
> 2.5ms frames:
> Before   (c):   2608 decicycles in postrotate, 2097030 runs,    122 skips
> After (sse3):   1507 decicycles in postrotate, 2097089 runs,     63 skips
> After (avx2):   1377 decicycles in postrotate, 2097097 runs,     55 skips
> 
> Needs to overwrite the start of some buffers as well as the
> end of them, hence the OVERALLOC stuff.
> 
> Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>
> ---
>  libavcodec/mdct15.c          | 74 ++++++++++++++++++++++++-----------
>  libavcodec/mdct15.h          |  3 ++
>  libavcodec/x86/mdct15.asm    | 93 +++++++++++++++++++++++++++++++++++++++++++-
>  libavcodec/x86/mdct15_init.c |  9 +++++
>  4 files changed, 155 insertions(+), 24 deletions(-)
> 
> diff --git a/libavcodec/mdct15.c b/libavcodec/mdct15.c
> index d68372c344..9838082c7e 100644
> --- a/libavcodec/mdct15.c
> +++ b/libavcodec/mdct15.c
> @@ -28,6 +28,7 @@
>  #include <math.h>
>  #include <stddef.h>
>  
> +#include "avcodec.h"
>  #include "config.h"
>  
>  #include "libavutil/attributes.h"
> @@ -40,6 +41,25 @@
>  
>  #define CMUL3(c, a, b) CMUL((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
>  
> +#define OVERALLOC(val, len, size)                                           \
> +    {                                                                       \
> +        const int pad = AV_INPUT_BUFFER_PADDING_SIZE/size;                  \
> +        (val) = NULL;                                                       \
> +        uint8_t *temp = av_mallocz_array(len + pad, size);                  \
> +        if (temp)                                                           \
> +            (val) = (void *)(temp + AV_INPUT_BUFFER_PADDING_SIZE);          \
> +    }
> +
> +#define OVERFREEP(val)                                                      \
> +    {                                                                       \
> +        uint8_t *temp = (uint8_t *)(val);                                   \
> +        if (temp) {                                                         \
> +            temp -= AV_INPUT_BUFFER_PADDING_SIZE;                           \
> +            av_free(temp);                                                  \
> +        }                                                                   \
> +        val = NULL;                                                         \
> +    }
> +
>  av_cold void ff_mdct15_uninit(MDCT15Context **ps)
>  {
>      MDCT15Context *s = *ps;
> @@ -50,9 +70,9 @@ av_cold void ff_mdct15_uninit(MDCT15Context **ps)
>      ff_fft_end(&s->ptwo_fft);
>  
>      av_freep(&s->pfa_prereindex);
> -    av_freep(&s->pfa_postreindex);
> -    av_freep(&s->twiddle_exptab);
> -    av_freep(&s->tmp);
> +    OVERFREEP(s->pfa_postreindex);
> +    OVERFREEP(s->twiddle_exptab);
> +    OVERFREEP(s->tmp);
>  
>      av_freep(ps);
>  }
> @@ -65,11 +85,11 @@ static inline int init_pfa_reindex_tabs(MDCT15Context *s)
>      const int inv_1 = l_ptwo << ((4 - b_ptwo) & 3); /* (2^b_ptwo)^-1 mod 15 */
>      const int inv_2 = 0xeeeeeeef & ((1U << b_ptwo) - 1); /* 15^-1 mod 2^b_ptwo */
>  
> -    s->pfa_prereindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_prereindex));
> +    s->pfa_prereindex = av_malloc_array(15 * l_ptwo, sizeof(*s->pfa_prereindex));
>      if (!s->pfa_prereindex)
>          return 1;
>  
> -    s->pfa_postreindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_postreindex));
> +    OVERALLOC(s->pfa_postreindex, 15 * l_ptwo, sizeof(*s->pfa_postreindex));
>      if (!s->pfa_postreindex)
>          return 1;
>  
> @@ -203,6 +223,21 @@ static void mdct15(MDCT15Context *s, float *dst, const float *src, ptrdiff_t str
>      }
>  }
>  
> +static void postrotate_c(FFTComplex *out, FFTComplex *in, FFTComplex *exp,
> +                         int *lut, ptrdiff_t len8)
> +{
> +    int i;
> +
> +    /* Reindex again, apply twiddles and output */
> +    for (i = 0; i < len8; i++) {
> +        const int i0 = len8 + i, i1 = len8 - i - 1;
> +        const int s0 = lut[i0], s1 = lut[i1];
> +
> +        CMUL(out[i1].re, out[i0].im, in[s1].im, in[s1].re, exp[i1].im, exp[i1].re);
> +        CMUL(out[i0].re, out[i1].im, in[s0].im, in[s0].re, exp[i0].im, exp[i0].re);
> +    }
> +}
> +
>  static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
>                           ptrdiff_t stride)
>  {
> @@ -226,15 +261,7 @@ static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
>          s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
>  
>      /* Reindex again, apply twiddles and output */
> -    for (i = 0; i < len8; i++) {
> -        const int i0 = len8 + i, i1 = len8 - i - 1;
> -        const int s0 = s->pfa_postreindex[i0], s1 = s->pfa_postreindex[i1];
> -
> -        CMUL(z[i1].re, z[i0].im, s->tmp[s1].im, s->tmp[s1].re,
> -             s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
> -        CMUL(z[i0].re, z[i1].im, s->tmp[s0].im, s->tmp[s0].re,
> -             s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
> -    }
> +    s->postreindex(z, s->tmp, s->twiddle_exptab, s->pfa_postreindex, len8);
>  }
>  
>  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
> @@ -253,13 +280,14 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
>      if (!s)
>          return AVERROR(ENOMEM);
>  
> -    s->fft_n      = N - 1;
> -    s->len4       = len2 / 2;
> -    s->len2       = len2;
> -    s->inverse    = inverse;
> -    s->fft15      = fft15_c;
> -    s->mdct       = mdct15;
> -    s->imdct_half = imdct15_half;
> +    s->fft_n       = N - 1;
> +    s->len4        = len2 / 2;
> +    s->len2        = len2;
> +    s->inverse     = inverse;
> +    s->fft15       = fft15_c;
> +    s->mdct        = mdct15;
> +    s->imdct_half  = imdct15_half;
> +    s->postreindex = postrotate_c;
>  
>      if (ff_fft_init(&s->ptwo_fft, N - 1, s->inverse) < 0)
>          goto fail;
> @@ -267,11 +295,11 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
>      if (init_pfa_reindex_tabs(s))
>          goto fail;
>  
> -    s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
> +    OVERALLOC(s->tmp, 2*len, sizeof(*s->tmp));
>      if (!s->tmp)
>          goto fail;
>  
> -    s->twiddle_exptab  = av_malloc_array(s->len4, sizeof(*s->twiddle_exptab));
> +    OVERALLOC(s->twiddle_exptab, s->len4, sizeof(*s->twiddle_exptab));
>      if (!s->twiddle_exptab)
>          goto fail;
>  
> diff --git a/libavcodec/mdct15.h b/libavcodec/mdct15.h
> index 1c2149d436..42e60f3e10 100644
> --- a/libavcodec/mdct15.h
> +++ b/libavcodec/mdct15.h
> @@ -42,6 +42,9 @@ typedef struct MDCT15Context {
>      /* 15-point FFT */
>      void (*fft15)(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
>  
> +    /* PFA postrotate and exptab */
> +    void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
> +
>      /* Calculate a full 2N -> N MDCT */
>      void (*mdct)(struct MDCT15Context *s, float *dst, const float *src, ptrdiff_t stride);
>  
> diff --git a/libavcodec/x86/mdct15.asm b/libavcodec/x86/mdct15.asm
> index f8b895944d..b42adb4aa9 100644
> --- a/libavcodec/x86/mdct15.asm
> +++ b/libavcodec/x86/mdct15.asm
> @@ -24,7 +24,11 @@
>  
>  %if ARCH_X86_64
>  
> -SECTION_RODATA
> +SECTION_RODATA 32
> +
> +perm_neg: dd 2, 5, 3, 4, 6, 1, 7, 0
> +perm_pos: dd 0, 7, 1, 6, 4, 3, 5, 2
> +sign_adjust_r: times 4 dd 0x80000000, 0x00000000
>  
>  sign_adjust_5: dd 0x00000000, 0x80000000, 0x80000000, 0x00000000
>  
> @@ -138,4 +142,91 @@ cglobal fft15, 4, 6, 14, out, in, exptab, stride, stride3, stride5
>  
>      RET
>  
> +%macro LUT_LOAD_4D 3
> +    mov      r7d, [lutq + %3q*4 +  0]
> +    movsd  xmm%1, [inq +  r7q*8]
> +    mov      r7d, [lutq + %3q*4 +  4]
> +    movhps xmm%1, [inq +  r7q*8]
> +%if cpuflag(avx2)
> +    mov      r7d, [lutq + %3q*4 +  8]
> +    movsd     %2, [inq +  r7q*8]
> +    mov      r7d, [lutq + %3q*4 + 12]
> +    movhps    %2, [inq +  r7q*8]
> +    vinsertf128 %1, %1, %2, 1
> +%endif
> +%endmacro
> +
> +%macro POSTROTATE_FN 0
> +;**********************************************************************************************************
> +;void ff_mdct15_postreindex(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
> +;**********************************************************************************************************

Nit: Move this above the LUT_LOAD_4D macro, so it's clear where all the
postreindex stuff starts.
Also, you forgot to replace the uint32_t and int64_t here.
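
For reference, the matching C declaration takes int *lut and ptrdiff_t len8,
so the prototype in that comment would presumably become:

    void ff_mdct15_postreindex(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);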

> +cglobal mdct15_postreindex, 5, 8, 12, out, in, exp, lut, len8, offset_p, offset_n
> +%if cpuflag(avx2)
> +    %define INCREMENT 4
> +%else
> +    %define INCREMENT 2

You could make this a POSTROTATE_FN macro argument instead.

> +%endif
> +
> +    mova m7, [perm_pos]
> +    mova m8, [perm_neg]
> +    mova m9, [sign_adjust_r]

Change these three to movaps, since initializing the functions with sse3
and avx2 makes mova/u aliases of movdqa/u.

> +
> +    mov offset_pq, len8q
> +    lea offset_nq, [len8q - INCREMENT]
> +
> +    shl len8q, 1
> +
> +    movups m10, [outq - mmsize]         ; backup from start - mmsize to start
> +    movups m11, [outq + len8q*8]        ; backup from end to end + mmsize
> +
> +.loop:
> +    movups m0, [expq + offset_pq*8]     ; exp[p0].re, exp[p0].im, exp[p1].re, exp[p1].im, exp[p2].re, exp[p2].im, exp[p3].re, exp[p3].im
> +    movups m1, [expq + offset_nq*8]     ; exp[n3].re, exp[n3].im, exp[n2].re, exp[n2].im, exp[n1].re, exp[n1].im, exp[n0].re, exp[n0].im
> +
> +    LUT_LOAD_4D m3, xmm4, offset_p      ; in[p0].re, in[p0].im, in[p1].re, in[p1].im, in[p2].re, in[p2].im, in[p3].re, in[p3].im
> +    LUT_LOAD_4D m4, xmm5, offset_n      ; in[n3].re, in[n3].im, in[n2].re, in[n2].im, in[n1].re, in[n1].im, in[n0].re, in[n0].im

Nit: xm4 and xm5

> +
> +    mulps m5, m3, m0                    ; in[p].reim * exp[p].reim
> +    mulps m6, m4, m1                    ; in[n].reim * exp[n].reim
> +
> +    xorps m5, m9                        ; in[p].re *= -1, in[p].im *= 1
> +    xorps m6, m9                        ; in[n].re *= -1, in[n].im *= 1
> +
> +    shufps m3, m3, m3, q2301            ; in[p].imre
> +    shufps m4, m4, m4, q2301            ; in[n].imre
> +
> +    mulps m3, m0                        ; in[p].imre * exp[p].reim
> +    mulps m4, m1                        ; in[n].imre * exp[n].reim
> +
> +    haddps m5, m4                       ; out[p0].re, out[p1].re, out[p3].im, out[p2].im, out[p2].re, out[p3].re, out[p1].im, out[p0].im
> +    haddps m3, m6                       ; out[n0].im, out[n1].im, out[n3].re, out[n2].re, out[n2].im, out[n3].im, out[n1].re, out[n0].re
> +
> +%if cpuflag(avx2)
> +    vpermps m5, m7, m5                  ; out[p0].re, out[p0].im, out[p1].re, out[p1].im, out[p2].re, out[p2].im, out[p3].re, out[p3].im
> +    vpermps m3, m8, m3                  ; out[n3].im, out[n3].re, out[n2].im, out[n2].re, out[n1].im, out[n1].re, out[n0].im, out[n0].re
> +%else
> +    shufps m5, m5, m5, q2130
> +    shufps m3, m3, m3, q0312
> +%endif
> +
> +    movups [outq + offset_pq*8], m5
> +    movups [outq + offset_nq*8], m3
> +
> +    sub offset_nq, INCREMENT
> +    add offset_pq, INCREMENT
> +
> +    cmp offset_pq, len8q
> +    jl .loop
> +
> +    movups [outq - mmsize],  m10
> +    movups [outq + len8q*8], m11
> +
> +    RET
> +%endmacro
> +
> +INIT_XMM sse3
> +POSTROTATE_FN
> +INIT_YMM avx2
> +POSTROTATE_FN

Wrap the two avx2 lines in a HAVE_AVX2_EXTERNAL check or it will fail to
assemble with Yasm 1.1.0 and older.

> +
>  %endif
> diff --git a/libavcodec/x86/mdct15_init.c b/libavcodec/x86/mdct15_init.c
> index ba3d94c2ec..ec4ff42bb6 100644
> --- a/libavcodec/x86/mdct15_init.c
> +++ b/libavcodec/x86/mdct15_init.c
> @@ -25,6 +25,9 @@
>  #include "libavutil/x86/cpu.h"
>  #include "libavcodec/mdct15.h"
>  
> +void ff_mdct15_postreindex_sse3(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
> +void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
> +
>  void ff_fft15_avx(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
>  
>  static void perm_twiddles(MDCT15Context *s)
> @@ -90,6 +93,12 @@ av_cold void ff_mdct15_init_x86(MDCT15Context *s)
>          adjust_twiddles = 1;
>      }
>  
> +    if (ARCH_X86_64 && EXTERNAL_SSE3(cpu_flags))
> +        s->postreindex = ff_mdct15_postreindex_sse3;

SSE3 goes before AVX.

> +
> +    if (ARCH_X86_64 && EXTERNAL_AVX2(cpu_flags))

EXTERNAL_AVX2_FAST(cpu_flags)

> +        s->postreindex = ff_mdct15_postreindex_avx2;
> +
>      if (adjust_twiddles)
>          perm_twiddles(s);
>  }

Maybe poke Hendrik for his opinion, but it seems to work, so LGTM.
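
Putting the two init-time suggestions together, the checks in
ff_mdct15_init_x86() would presumably end up roughly like this (a sketch only;
the existing EXTERNAL_AVX() block that selects ff_fft15_avx is paraphrased as
a comment):

    if (ARCH_X86_64 && EXTERNAL_SSE3(cpu_flags))
        s->postreindex = ff_mdct15_postreindex_sse3;

    /* ... existing block selecting ff_fft15_avx and setting adjust_twiddles
     * stays here, between the SSE3 and AVX2 checks ... */

    if (ARCH_X86_64 && EXTERNAL_AVX2_FAST(cpu_flags))
        s->postreindex = ff_mdct15_postreindex_avx2;
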
Rostislav Pehlivanov July 30, 2017, 6:46 a.m. UTC | #2
On 30 July 2017 at 02:30, James Almer <jamrial@gmail.com> wrote:

>
>
> Maybe poke Hendrik for his opinion, but it seems to work, so LGTM.
>
>
Managed to simplify the code and the crazy alignment requirements a lot
by just iterating over the buffer in reverse. That way the overlap in
the middle solved itself, since the positive part gets written last.
No need for the overalloc macros either, or for zeroing any buffers
during init.

Thanks for the review, pushed.

Patch

diff --git a/libavcodec/mdct15.c b/libavcodec/mdct15.c
index d68372c344..9838082c7e 100644
--- a/libavcodec/mdct15.c
+++ b/libavcodec/mdct15.c
@@ -28,6 +28,7 @@ 
 #include <math.h>
 #include <stddef.h>
 
+#include "avcodec.h"
 #include "config.h"
 
 #include "libavutil/attributes.h"
@@ -40,6 +41,25 @@ 
 
 #define CMUL3(c, a, b) CMUL((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
 
+#define OVERALLOC(val, len, size)                                           \
+    {                                                                       \
+        const int pad = AV_INPUT_BUFFER_PADDING_SIZE/size;                  \
+        (val) = NULL;                                                       \
+        uint8_t *temp = av_mallocz_array(len + pad, size);                  \
+        if (temp)                                                           \
+            (val) = (void *)(temp + AV_INPUT_BUFFER_PADDING_SIZE);          \
+    }
+
+#define OVERFREEP(val)                                                      \
+    {                                                                       \
+        uint8_t *temp = (uint8_t *)(val);                                   \
+        if (temp) {                                                         \
+            temp -= AV_INPUT_BUFFER_PADDING_SIZE;                           \
+            av_free(temp);                                                  \
+        }                                                                   \
+        val = NULL;                                                         \
+    }
+
 av_cold void ff_mdct15_uninit(MDCT15Context **ps)
 {
     MDCT15Context *s = *ps;
@@ -50,9 +70,9 @@  av_cold void ff_mdct15_uninit(MDCT15Context **ps)
     ff_fft_end(&s->ptwo_fft);
 
     av_freep(&s->pfa_prereindex);
-    av_freep(&s->pfa_postreindex);
-    av_freep(&s->twiddle_exptab);
-    av_freep(&s->tmp);
+    OVERFREEP(s->pfa_postreindex);
+    OVERFREEP(s->twiddle_exptab);
+    OVERFREEP(s->tmp);
 
     av_freep(ps);
 }
@@ -65,11 +85,11 @@  static inline int init_pfa_reindex_tabs(MDCT15Context *s)
     const int inv_1 = l_ptwo << ((4 - b_ptwo) & 3); /* (2^b_ptwo)^-1 mod 15 */
     const int inv_2 = 0xeeeeeeef & ((1U << b_ptwo) - 1); /* 15^-1 mod 2^b_ptwo */
 
-    s->pfa_prereindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_prereindex));
+    s->pfa_prereindex = av_malloc_array(15 * l_ptwo, sizeof(*s->pfa_prereindex));
     if (!s->pfa_prereindex)
         return 1;
 
-    s->pfa_postreindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_postreindex));
+    OVERALLOC(s->pfa_postreindex, 15 * l_ptwo, sizeof(*s->pfa_postreindex));
     if (!s->pfa_postreindex)
         return 1;
 
@@ -203,6 +223,21 @@  static void mdct15(MDCT15Context *s, float *dst, const float *src, ptrdiff_t str
     }
 }
 
+static void postrotate_c(FFTComplex *out, FFTComplex *in, FFTComplex *exp,
+                         int *lut, ptrdiff_t len8)
+{
+    int i;
+
+    /* Reindex again, apply twiddles and output */
+    for (i = 0; i < len8; i++) {
+        const int i0 = len8 + i, i1 = len8 - i - 1;
+        const int s0 = lut[i0], s1 = lut[i1];
+
+        CMUL(out[i1].re, out[i0].im, in[s1].im, in[s1].re, exp[i1].im, exp[i1].re);
+        CMUL(out[i0].re, out[i1].im, in[s0].im, in[s0].re, exp[i0].im, exp[i0].re);
+    }
+}
+
 static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
                          ptrdiff_t stride)
 {
@@ -226,15 +261,7 @@  static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
         s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
 
     /* Reindex again, apply twiddles and output */
-    for (i = 0; i < len8; i++) {
-        const int i0 = len8 + i, i1 = len8 - i - 1;
-        const int s0 = s->pfa_postreindex[i0], s1 = s->pfa_postreindex[i1];
-
-        CMUL(z[i1].re, z[i0].im, s->tmp[s1].im, s->tmp[s1].re,
-             s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
-        CMUL(z[i0].re, z[i1].im, s->tmp[s0].im, s->tmp[s0].re,
-             s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
-    }
+    s->postreindex(z, s->tmp, s->twiddle_exptab, s->pfa_postreindex, len8);
 }
 
 av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
@@ -253,13 +280,14 @@  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
     if (!s)
         return AVERROR(ENOMEM);
 
-    s->fft_n      = N - 1;
-    s->len4       = len2 / 2;
-    s->len2       = len2;
-    s->inverse    = inverse;
-    s->fft15      = fft15_c;
-    s->mdct       = mdct15;
-    s->imdct_half = imdct15_half;
+    s->fft_n       = N - 1;
+    s->len4        = len2 / 2;
+    s->len2        = len2;
+    s->inverse     = inverse;
+    s->fft15       = fft15_c;
+    s->mdct        = mdct15;
+    s->imdct_half  = imdct15_half;
+    s->postreindex = postrotate_c;
 
     if (ff_fft_init(&s->ptwo_fft, N - 1, s->inverse) < 0)
         goto fail;
@@ -267,11 +295,11 @@  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
     if (init_pfa_reindex_tabs(s))
         goto fail;
 
-    s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
+    OVERALLOC(s->tmp, 2*len, sizeof(*s->tmp));
     if (!s->tmp)
         goto fail;
 
-    s->twiddle_exptab  = av_malloc_array(s->len4, sizeof(*s->twiddle_exptab));
+    OVERALLOC(s->twiddle_exptab, s->len4, sizeof(*s->twiddle_exptab));
     if (!s->twiddle_exptab)
         goto fail;
 
diff --git a/libavcodec/mdct15.h b/libavcodec/mdct15.h
index 1c2149d436..42e60f3e10 100644
--- a/libavcodec/mdct15.h
+++ b/libavcodec/mdct15.h
@@ -42,6 +42,9 @@  typedef struct MDCT15Context {
     /* 15-point FFT */
     void (*fft15)(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
 
+    /* PFA postrotate and exptab */
+    void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
+
     /* Calculate a full 2N -> N MDCT */
     void (*mdct)(struct MDCT15Context *s, float *dst, const float *src, ptrdiff_t stride);
 
diff --git a/libavcodec/x86/mdct15.asm b/libavcodec/x86/mdct15.asm
index f8b895944d..b42adb4aa9 100644
--- a/libavcodec/x86/mdct15.asm
+++ b/libavcodec/x86/mdct15.asm
@@ -24,7 +24,11 @@ 
 
 %if ARCH_X86_64
 
-SECTION_RODATA
+SECTION_RODATA 32
+
+perm_neg: dd 2, 5, 3, 4, 6, 1, 7, 0
+perm_pos: dd 0, 7, 1, 6, 4, 3, 5, 2
+sign_adjust_r: times 4 dd 0x80000000, 0x00000000
 
 sign_adjust_5: dd 0x00000000, 0x80000000, 0x80000000, 0x00000000
 
@@ -138,4 +142,91 @@  cglobal fft15, 4, 6, 14, out, in, exptab, stride, stride3, stride5
 
     RET
 
+%macro LUT_LOAD_4D 3
+    mov      r7d, [lutq + %3q*4 +  0]
+    movsd  xmm%1, [inq +  r7q*8]
+    mov      r7d, [lutq + %3q*4 +  4]
+    movhps xmm%1, [inq +  r7q*8]
+%if cpuflag(avx2)
+    mov      r7d, [lutq + %3q*4 +  8]
+    movsd     %2, [inq +  r7q*8]
+    mov      r7d, [lutq + %3q*4 + 12]
+    movhps    %2, [inq +  r7q*8]
+    vinsertf128 %1, %1, %2, 1
+%endif
+%endmacro
+
+%macro POSTROTATE_FN 0
+;**********************************************************************************************************
+;void ff_mdct15_postreindex(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
+;**********************************************************************************************************
+cglobal mdct15_postreindex, 5, 8, 12, out, in, exp, lut, len8, offset_p, offset_n
+%if cpuflag(avx2)
+    %define INCREMENT 4
+%else
+    %define INCREMENT 2
+%endif
+
+    mova m7, [perm_pos]
+    mova m8, [perm_neg]
+    mova m9, [sign_adjust_r]
+
+    mov offset_pq, len8q
+    lea offset_nq, [len8q - INCREMENT]
+
+    shl len8q, 1
+
+    movups m10, [outq - mmsize]         ; backup from start - mmsize to start
+    movups m11, [outq + len8q*8]        ; backup from end to end + mmsize
+
+.loop:
+    movups m0, [expq + offset_pq*8]     ; exp[p0].re, exp[p0].im, exp[p1].re, exp[p1].im, exp[p2].re, exp[p2].im, exp[p3].re, exp[p3].im
+    movups m1, [expq + offset_nq*8]     ; exp[n3].re, exp[n3].im, exp[n2].re, exp[n2].im, exp[n1].re, exp[n1].im, exp[n0].re, exp[n0].im
+
+    LUT_LOAD_4D m3, xmm4, offset_p      ; in[p0].re, in[p0].im, in[p1].re, in[p1].im, in[p2].re, in[p2].im, in[p3].re, in[p3].im
+    LUT_LOAD_4D m4, xmm5, offset_n      ; in[n3].re, in[n3].im, in[n2].re, in[n2].im, in[n1].re, in[n1].im, in[n0].re, in[n0].im
+
+    mulps m5, m3, m0                    ; in[p].reim * exp[p].reim
+    mulps m6, m4, m1                    ; in[n].reim * exp[n].reim
+
+    xorps m5, m9                        ; in[p].re *= -1, in[p].im *= 1
+    xorps m6, m9                        ; in[n].re *= -1, in[n].im *= 1
+
+    shufps m3, m3, m3, q2301            ; in[p].imre
+    shufps m4, m4, m4, q2301            ; in[n].imre
+
+    mulps m3, m0                        ; in[p].imre * exp[p].reim
+    mulps m4, m1                        ; in[n].imre * exp[n].reim
+
+    haddps m5, m4                       ; out[p0].re, out[p1].re, out[p3].im, out[p2].im, out[p2].re, out[p3].re, out[p1].im, out[p0].im
+    haddps m3, m6                       ; out[n0].im, out[n1].im, out[n3].re, out[n2].re, out[n2].im, out[n3].im, out[n1].re, out[n0].re
+
+%if cpuflag(avx2)
+    vpermps m5, m7, m5                  ; out[p0].re, out[p0].im, out[p1].re, out[p1].im, out[p2].re, out[p2].im, out[p3].re, out[p3].im
+    vpermps m3, m8, m3                  ; out[n3].im, out[n3].re, out[n2].im, out[n2].re, out[n1].im, out[n1].re, out[n0].im, out[n0].re
+%else
+    shufps m5, m5, m5, q2130
+    shufps m3, m3, m3, q0312
+%endif
+
+    movups [outq + offset_pq*8], m5
+    movups [outq + offset_nq*8], m3
+
+    sub offset_nq, INCREMENT
+    add offset_pq, INCREMENT
+
+    cmp offset_pq, len8q
+    jl .loop
+
+    movups [outq - mmsize],  m10
+    movups [outq + len8q*8], m11
+
+    RET
+%endmacro
+
+INIT_XMM sse3
+POSTROTATE_FN
+INIT_YMM avx2
+POSTROTATE_FN
+
 %endif
diff --git a/libavcodec/x86/mdct15_init.c b/libavcodec/x86/mdct15_init.c
index ba3d94c2ec..ec4ff42bb6 100644
--- a/libavcodec/x86/mdct15_init.c
+++ b/libavcodec/x86/mdct15_init.c
@@ -25,6 +25,9 @@ 
 #include "libavutil/x86/cpu.h"
 #include "libavcodec/mdct15.h"
 
+void ff_mdct15_postreindex_sse3(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
+void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, int *lut, ptrdiff_t len8);
+
 void ff_fft15_avx(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
 
 static void perm_twiddles(MDCT15Context *s)
@@ -90,6 +93,12 @@  av_cold void ff_mdct15_init_x86(MDCT15Context *s)
         adjust_twiddles = 1;
     }
 
+    if (ARCH_X86_64 && EXTERNAL_SSE3(cpu_flags))
+        s->postreindex = ff_mdct15_postreindex_sse3;
+
+    if (ARCH_X86_64 && EXTERNAL_AVX2(cpu_flags))
+        s->postreindex = ff_mdct15_postreindex_avx2;
+
     if (adjust_twiddles)
         perm_twiddles(s);
 }