[FFmpeg-devel] mdct15: add inverse transform postrotation SIMD

Message ID 20170729203809.31398-1-atomnuker@gmail.com
State Superseded

Commit Message

Rostislav Pehlivanov July 29, 2017, 8:38 p.m. UTC
Speeds up decoding by 8% in total.

20ms frames:
Before: 17774 decicycles in postrotate,  262065 runs,     79 skips
After:   7169 decicycles in postrotate,  262104 runs,     40 skips

10ms frames:
Before: 9058 decicycles in postrotate,  524209 runs,     79 skips
After:  3915 decicycles in postrotate,  524236 runs,     52 skips

5ms frames:
Before: 4764 decicycles in postrotate, 1048466 runs,    110 skips
After:  2161 decicycles in postrotate, 1048515 runs,     61 skips

2.5ms frames:
Before: 2608 decicycles in postrotate, 2097030 runs,    122 skips
After:  1377 decicycles in postrotate, 2097097 runs,     55 skips

Needs to overwrite the start of some buffers as well as the
end of them, hence the OVERALLOC stuff.

Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>
---
 libavcodec/mdct15.c          | 75 +++++++++++++++++++++++++++++++-------------
 libavcodec/mdct15.h          |  7 +++--
 libavcodec/x86/mdct15.asm    | 72 ++++++++++++++++++++++++++++++++++++++++++
 libavcodec/x86/mdct15_init.c |  4 +++
 4 files changed, 134 insertions(+), 24 deletions(-)

Comments

James Almer July 29, 2017, 9:37 p.m. UTC | #1
On 7/29/2017 5:38 PM, Rostislav Pehlivanov wrote:
> Speeds up decoding by 8% in total.
> 
> 20ms frames:
> Before: 17774 decicycles in postrotate,  262065 runs,     79 skips
> After:   7169 decicycles in postrotate,  262104 runs,     40 skips
> 
> 10ms frames:
> Before: 9058 decicycles in postrotate,  524209 runs,     79 skips
> After:  3915 decicycles in postrotate,  524236 runs,     52 skips
> 
> 5ms frames:
> Before: 4764 decicycles in postrotate, 1048466 runs,    110 skips
> After:  2161 decicycles in postrotate, 1048515 runs,     61 skips
> 
> 2.5ms frames:
> Before: 2608 decicycles in postrotate, 2097030 runs,    122 skips
> After:  1377 decicycles in postrotate, 2097097 runs,     55 skips
> 
> Needs to overwrite the start of some buffers as well as the
> end of them, hence the OVERALLOC stuff.
> 
> Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>
> ---
>  libavcodec/mdct15.c          | 75 +++++++++++++++++++++++++++++++-------------
>  libavcodec/mdct15.h          |  7 +++--
>  libavcodec/x86/mdct15.asm    | 72 ++++++++++++++++++++++++++++++++++++++++++
>  libavcodec/x86/mdct15_init.c |  4 +++
>  4 files changed, 134 insertions(+), 24 deletions(-)
> 
> diff --git a/libavcodec/mdct15.c b/libavcodec/mdct15.c
> index d68372c344..0a6c0069db 100644
> --- a/libavcodec/mdct15.c
> +++ b/libavcodec/mdct15.c
> @@ -40,6 +40,29 @@
>  
>  #define CMUL3(c, a, b) CMUL((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
>  
> +#define OVERALLOC_AMOUNT 32 /* 1 ymm register */

Use AV_INPUT_BUFFER_PADDING_SIZE

> +
> +#define OVERALLOC(val, size)                                                \
> +    {                                                                       \
> +        (val) = NULL;                                                       \
> +        uint8_t *temp = av_malloc((size) + 2*OVERALLOC_AMOUNT);             \

Why two times?

> +        if (temp) {                                                         \
> +            memset(temp, 0, OVERALLOC_AMOUNT);                              \
> +            memset(temp + (size) + OVERALLOC_AMOUNT, 0, OVERALLOC_AMOUNT);  \
> +            (val) = (void *)(temp + OVERALLOC_AMOUNT);                      \

Can't you just keep the zero padding at the end like in every other
padded buffer used by simd functions?
This seems pretty unconventional, especially with the custom free code
below.

Also, you could just use av_mallocz*.
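
Something along these lines maybe (untested sketch, keeping your macro names,
but with av_mallocz() and the existing padding constant):

#define OVERALLOC(val, size)                                                 \
do {                                                                         \
    uint8_t *buf = av_mallocz((size) + 2 * AV_INPUT_BUFFER_PADDING_SIZE);    \
    /* av_mallocz() already zeroes the padding on both sides */              \
    (val) = buf ? (void *)(buf + AV_INPUT_BUFFER_PADDING_SIZE) : NULL;       \
} while (0)

#define OVERFREEP(val)                                                       \
do {                                                                         \
    if (val)                                                                 \
        av_free((uint8_t *)(val) - AV_INPUT_BUFFER_PADDING_SIZE);            \
    (val) = NULL;                                                            \
} while (0)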

> +        }                                                                   \
> +    }
> +
> +#define OVERFREEP(val)                                                      \
> +    {                                                                       \
> +        uint8_t *temp = (uint8_t *)(val);                                   \
> +        if (temp) {                                                         \
> +            temp -= OVERALLOC_AMOUNT;                                       \
> +            av_free(temp);                                                  \
> +        }                                                                   \
> +        val = NULL;                                                         \
> +    }
> +
>  av_cold void ff_mdct15_uninit(MDCT15Context **ps)
>  {
>      MDCT15Context *s = *ps;
> @@ -50,9 +73,9 @@ av_cold void ff_mdct15_uninit(MDCT15Context **ps)
>      ff_fft_end(&s->ptwo_fft);
>  
>      av_freep(&s->pfa_prereindex);
> -    av_freep(&s->pfa_postreindex);
> -    av_freep(&s->twiddle_exptab);
> -    av_freep(&s->tmp);
> +    OVERFREEP(s->pfa_postreindex);
> +    OVERFREEP(s->twiddle_exptab);
> +    OVERFREEP(s->tmp);
>  
>      av_freep(ps);
>  }
> @@ -69,7 +92,7 @@ static inline int init_pfa_reindex_tabs(MDCT15Context *s)
>      if (!s->pfa_prereindex)
>          return 1;
>  
> -    s->pfa_postreindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_postreindex));
> +    OVERALLOC(s->pfa_postreindex, 15 * l_ptwo * sizeof(*s->pfa_postreindex));

Not strictly related to this patch, but this should probably be using
av_malloc_array().
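
i.e. (if a plain allocation were kept here) something like:

s->pfa_postreindex = av_malloc_array(15 * l_ptwo, sizeof(*s->pfa_postreindex));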

>      if (!s->pfa_postreindex)
>          return 1;
>  
> @@ -203,6 +226,21 @@ static void mdct15(MDCT15Context *s, float *dst, const float *src, ptrdiff_t str
>      }
>  }
>  
> +static void postrotate_c(FFTComplex *out, FFTComplex *in, FFTComplex *exp,
> +                         uint32_t *lut, int64_t len8)
> +{
> +    int i;
> +
> +    /* Reindex again, apply twiddles and output */
> +    for (i = 0; i < len8; i++) {
> +        const int i0 = len8 + i, i1 = len8 - i - 1;
> +        const int s0 = lut[i0], s1 = lut[i1];
> +
> +        CMUL(out[i1].re, out[i0].im, in[s1].im, in[s1].re, exp[i1].im, exp[i1].re);
> +        CMUL(out[i0].re, out[i1].im, in[s0].im, in[s0].re, exp[i0].im, exp[i0].re);
> +    }
> +}
> +
>  static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
>                           ptrdiff_t stride)
>  {
> @@ -226,15 +264,7 @@ static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
>          s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
>  
>      /* Reindex again, apply twiddles and output */
> -    for (i = 0; i < len8; i++) {
> -        const int i0 = len8 + i, i1 = len8 - i - 1;
> -        const int s0 = s->pfa_postreindex[i0], s1 = s->pfa_postreindex[i1];
> -
> -        CMUL(z[i1].re, z[i0].im, s->tmp[s1].im, s->tmp[s1].re,
> -             s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
> -        CMUL(z[i0].re, z[i1].im, s->tmp[s0].im, s->tmp[s0].re,
> -             s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
> -    }
> +    s->postreindex(z, s->tmp, s->twiddle_exptab, s->pfa_postreindex, len8);
>  }
>  
>  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
> @@ -253,13 +283,14 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
>      if (!s)
>          return AVERROR(ENOMEM);
>  
> -    s->fft_n      = N - 1;
> -    s->len4       = len2 / 2;
> -    s->len2       = len2;
> -    s->inverse    = inverse;
> -    s->fft15      = fft15_c;
> -    s->mdct       = mdct15;
> -    s->imdct_half = imdct15_half;
> +    s->fft_n       = N - 1;
> +    s->len4        = len2 / 2;
> +    s->len2        = len2;
> +    s->inverse     = inverse;
> +    s->fft15       = fft15_c;
> +    s->mdct        = mdct15;
> +    s->imdct_half  = imdct15_half;
> +    s->postreindex = postrotate_c;
>  
>      if (ff_fft_init(&s->ptwo_fft, N - 1, s->inverse) < 0)
>          goto fail;
> @@ -267,11 +298,11 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
>      if (init_pfa_reindex_tabs(s))
>          goto fail;
>  
> -    s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
> +    OVERALLOC(s->tmp, len * 2 * sizeof(*s->tmp));
>      if (!s->tmp)
>          goto fail;
>  
> -    s->twiddle_exptab  = av_malloc_array(s->len4, sizeof(*s->twiddle_exptab));
> +    OVERALLOC(s->twiddle_exptab, s->len4 * sizeof(*s->twiddle_exptab));
>      if (!s->twiddle_exptab)
>          goto fail;
>  
> diff --git a/libavcodec/mdct15.h b/libavcodec/mdct15.h
> index 1c2149d436..301b238ec8 100644
> --- a/libavcodec/mdct15.h
> +++ b/libavcodec/mdct15.h
> @@ -30,8 +30,8 @@ typedef struct MDCT15Context {
>      int len2;
>      int len4;
>      int inverse;
> -    int *pfa_prereindex;
> -    int *pfa_postreindex;
> +    uint32_t *pfa_prereindex;
> +    uint32_t *pfa_postreindex;

What's the point changing these? It has no effect on hand written asm
and the C version still uses const int when loading single values from them.

>  
>      FFTContext ptwo_fft;
>      FFTComplex *tmp;
> @@ -42,6 +42,9 @@ typedef struct MDCT15Context {
>      /* 15-point FFT */
>      void (*fft15)(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
>  
> +    /* PFA postrotate and exptab */
> +    void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);

ptrdiff_t, not int64_t.
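
i.e.:

void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, ptrdiff_t len8);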

> +
>      /* Calculate a full 2N -> N MDCT */
>      void (*mdct)(struct MDCT15Context *s, float *dst, const float *src, ptrdiff_t stride);
>  
> diff --git a/libavcodec/x86/mdct15.asm b/libavcodec/x86/mdct15.asm
> index f8b895944d..c5ea56c2c3 100644
> --- a/libavcodec/x86/mdct15.asm
> +++ b/libavcodec/x86/mdct15.asm
> @@ -28,6 +28,10 @@ SECTION_RODATA
>  
>  sign_adjust_5: dd 0x00000000, 0x80000000, 0x80000000, 0x00000000
>  
> +sign_adjust_r: times 4 dd 0x80000000, 0x00000000
> +perm_neg: dd 0x2, 0x5, 0x3, 0x4, 0x6, 0x1, 0x7, 0x0
> +perm_pos: dd 0x0, 0x7, 0x1, 0x6, 0x4, 0x3, 0x5, 0x2

These are index values, so IMO use decimal. It's more readable that way.

Also, SECTION_RODATA needs to be 32-byte aligned now.

> +
>  SECTION .text
>  
>  %macro FFT5 3 ; %1 - in_offset, %2 - dst1 (64bit used), %3 - dst2
> @@ -138,4 +142,72 @@ cglobal fft15, 4, 6, 14, out, in, exptab, stride, stride3, stride5
>  
>      RET
>  
> +%macro LUT_LOAD_4D 4
> +    mov r7d, [lutq + %4q*4 +  0]
> +    movsd  %2, [inq + r7q*8]
> +    mov r7d, [lutq + %4q*4 +  4]
> +    movhps %2, [inq + r7q*8]
> +    mov r7d, [lutq + %4q*4 +  8]
> +    movsd  %3, [inq + r7q*8]
> +    mov r7d, [lutq + %4q*4 + 12]
> +    movhps %3, [inq + r7q*8]
> +    vinsertf128 %1, %1, %3, 1
> +%endmacro
> +
> +;***************************************************************************************************************
> +;void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
> +;***************************************************************************************************************
> +INIT_YMM avx2
> +cglobal mdct15_postreindex, 5, 8, 10, out, in, exp, lut, len8, offset_p, offset_n

You're using seven gprs, not 8.

> +
> +    mova m7, [perm_pos]
> +    mova m8, [perm_neg]
> +    mova m9, [sign_adjust_r]
> +
> +    mov offset_pq, len8q
> +    lea offset_nq, [len8q - 4]
> +
> +    shl len8q, 1
> +
> +    movu m10, [outq - mmsize]           ; backup from start - mmsize to start
> +    movu m11, [outq + len8q*8]          ; backup from end to end + mmsize

And you seem to be using 12 xmm/ymm regs, not 10.

> +
> +.loop:
> +    movups m0, [expq + offset_pq*8]     ; exp[p0].re, exp[p0].im, exp[p1].re, exp[p1].im, exp[p2].re, exp[p2].im, exp[p3].re, exp[p3].im
> +    movups m1, [expq + offset_nq*8]     ; exp[n3].re, exp[n3].im, exp[n2].re, exp[n2].im, exp[n1].re, exp[n1].im, exp[n0].re, exp[n0].im
> +
> +    LUT_LOAD_4D m3, xm3, xm4, offset_p  ; in[p0].re, in[p0].im, in[p1].re, in[p1].im, in[p2].re, in[p2].im, in[p3].re, in[p3].im
> +    LUT_LOAD_4D m4, xm4, xm5, offset_n  ; in[n3].re, in[n3].im, in[n2].re, in[n2].im, in[n1].re, in[n1].im, in[n0].re, in[n0].im     
> +
> +    mulps m5, m3, m0                    ; in[p].reim * exp[p].reim
> +    mulps m6, m4, m1                    ; in[n].reim * exp[n].reim
> +
> +    xorps m5, m9                        ; in[p].re *= -1, in[p].im *= 1
> +    xorps m6, m9                        ; in[n].re *= -1, in[n].im *= 1
> +
> +    shufps m3, m3, m3, q2301            ; in[p].imre
> +    shufps m4, m4, m4, q2301            ; in[n].imre
> +
> +    mulps m3, m0                        ; in[p].imre * exp[p].reim
> +    mulps m4, m1                        ; in[n].imre * exp[n].reim
> +
> +    haddps m5, m4                       ; out[p0].re, out[p1].re, out[p3].im, out[p2].im, out[p2].re, out[p3].re, out[p1].im, out[p0].im
> +    haddps m3, m6                       ; out[n0].im, out[n1].im, out[n3].re, out[n2].re, out[n2].im, out[n3].im, out[n1].re, out[n0].re
> +
> +    vpermps m5, m7, m5                  ; out[p0].re, out[p0].im, out[p1].re, out[p1].im, out[p2].re, out[p2].im, out[p3].re, out[p3].im
> +    vpermps m3, m8, m3                  ; out[n3].im, out[n3].re, out[n2].im, out[n2].re, out[n1].im, out[n1].re, out[n0].im, out[n0].re
> +
> +    movups [outq + offset_pq*8], m5
> +    movups [outq + offset_nq*8], m3
> +
> +    sub offset_nq, 4
> +    add offset_pq, 4
> +    cmp offset_pq, len8q
> +    jl .loop
> +
> +    movu [outq - mmsize], m10
> +    movu [outq + len8q*8], m11
> +
> +    RET
> +
>  %endif
> diff --git a/libavcodec/x86/mdct15_init.c b/libavcodec/x86/mdct15_init.c
> index ba3d94c2ec..60d47e71ce 100644
> --- a/libavcodec/x86/mdct15_init.c
> +++ b/libavcodec/x86/mdct15_init.c
> @@ -25,6 +25,7 @@
>  #include "libavutil/x86/cpu.h"
>  #include "libavcodec/mdct15.h"
>  
> +void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
>  void ff_fft15_avx(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
>  
>  static void perm_twiddles(MDCT15Context *s)
> @@ -90,6 +91,9 @@ av_cold void ff_mdct15_init_x86(MDCT15Context *s)
>          adjust_twiddles = 1;
>      }
>  
> +    if (ARCH_X86_64 && EXTERNAL_AVX2(cpu_flags))
> +        s->postreindex = ff_mdct15_postreindex_avx2;
> +

Why didn't you write an SSE3 version of this function? Baseline for any
function should be SSE/SSE3 if float, SSE2/SSE4 if integer.

>      if (adjust_twiddles)
>          perm_twiddles(s);
>  }
>
Rostislav Pehlivanov July 29, 2017, 9:55 p.m. UTC | #2
On 29 July 2017 at 22:37, James Almer <jamrial@gmail.com> wrote:

> On 7/29/2017 5:38 PM, Rostislav Pehlivanov wrote:
> > Speeds up decoding by 8% in total.
> >
> > 20ms frames:
> > Before: 17774 decicycles in postrotate,  262065 runs,     79 skips
> > After:   7169 decicycles in postrotate,  262104 runs,     40 skips
> >
> > 10ms frames:
> > Before: 9058 decicycles in postrotate,  524209 runs,     79 skips
> > After:  3915 decicycles in postrotate,  524236 runs,     52 skips
> >
> > 5ms frames:
> > Before: 4764 decicycles in postrotate, 1048466 runs,    110 skips
> > After:  2161 decicycles in postrotate, 1048515 runs,     61 skips
> >
> > 2.5ms frames:
> > Before: 2608 decicycles in postrotate, 2097030 runs,    122 skips
> > After:  1377 decicycles in postrotate, 2097097 runs,     55 skips
> >
> > Needs to overwrite the start of some buffers as well as the
> > end of them, hence the OVERALLOC stuff.
> >
> > Signed-off-by: Rostislav Pehlivanov <atomnuker@gmail.com>
> > ---
> >  libavcodec/mdct15.c          | 75 ++++++++++++++++++++++++++++++
> +-------------
> >  libavcodec/mdct15.h          |  7 +++--
> >  libavcodec/x86/mdct15.asm    | 72 ++++++++++++++++++++++++++++++
> ++++++++++++
> >  libavcodec/x86/mdct15_init.c |  4 +++
> >  4 files changed, 134 insertions(+), 24 deletions(-)
> >
> > diff --git a/libavcodec/mdct15.c b/libavcodec/mdct15.c
> > index d68372c344..0a6c0069db 100644
> > --- a/libavcodec/mdct15.c
> > +++ b/libavcodec/mdct15.c
> > @@ -40,6 +40,29 @@
> >
> >  #define CMUL3(c, a, b) CMUL((c).re, (c).im, (a).re, (a).im, (b).re,
> (b).im)
> >
> > +#define OVERALLOC_AMOUNT 32 /* 1 ymm register */
>
> Use AV_INPUT_BUFFER_PADDING_SIZE
>
> > +
> > +#define OVERALLOC(val, size)
>     \
> > +    {
>      \
> > +        (val) = NULL;
>      \
> > +        uint8_t *temp = av_malloc((size) + 2*OVERALLOC_AMOUNT);
>      \
>
> Why two times?
>
>
Start and end.


> > +        if (temp) {
>      \
> > +            memset(temp, 0, OVERALLOC_AMOUNT);
>     \
> > +            memset(temp + (size) + OVERALLOC_AMOUNT, 0,
> OVERALLOC_AMOUNT);  \
> > +            (val) = (void *)(temp + OVERALLOC_AMOUNT);
>     \
>
> Can't you just keep the zero padding at the end like in every other
> padded buffer used by simd functions?
> This seems pretty unconventional, especially with the custom free code
> below.
>
>
No, the function starts in the middle and reads/writes towards both the start
and the end in each loop iteration.
The iteration count is not a multiple of 8, so I have to pad both the start
and the end.
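
Roughly, the bounds are (pseudo-C of just the indices, len8 counted in
complexes, not the actual code):

for (p = len8, n = len8 - 4; p < 2*len8; p += 4, n -= 4) {
    /* reads exp[p..p+3], exp[n..n+3], in[lut[p..p+3]], in[lut[n..n+3]];
       writes out[p..p+3] and out[n..n+3] */
}

so in the worst case the last iteration goes down to n = -3 and up to
p + 3 = 2*len8 + 2, i.e. up to 24 bytes outside each end. The asm backs up and
restores a full register's worth around the loop, but that memory still has to
be readable and writable, hence the padding on both sides.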


> Also, you could just use av_mallocz*.
>
> > +        }
>      \
> > +    }
> > +
> > +#define OVERFREEP(val)
>     \
> > +    {
>      \
> > +        uint8_t *temp = (uint8_t *)(val);
>      \
> > +        if (temp) {
>      \
> > +            temp -= OVERALLOC_AMOUNT;
>      \
> > +            av_free(temp);
>     \
> > +        }
>      \
> > +        val = NULL;
>      \
> > +    }
> > +
> >  av_cold void ff_mdct15_uninit(MDCT15Context **ps)
> >  {
> >      MDCT15Context *s = *ps;
> > @@ -50,9 +73,9 @@ av_cold void ff_mdct15_uninit(MDCT15Context **ps)
> >      ff_fft_end(&s->ptwo_fft);
> >
> >      av_freep(&s->pfa_prereindex);
> > -    av_freep(&s->pfa_postreindex);
> > -    av_freep(&s->twiddle_exptab);
> > -    av_freep(&s->tmp);
> > +    OVERFREEP(s->pfa_postreindex);
> > +    OVERFREEP(s->twiddle_exptab);
> > +    OVERFREEP(s->tmp);
> >
> >      av_freep(ps);
> >  }
> > @@ -69,7 +92,7 @@ static inline int init_pfa_reindex_tabs(MDCT15Context
> *s)
> >      if (!s->pfa_prereindex)
> >          return 1;
> >
> > -    s->pfa_postreindex = av_malloc(15 * l_ptwo *
> sizeof(*s->pfa_postreindex));
> > +    OVERALLOC(s->pfa_postreindex, 15 * l_ptwo *
> sizeof(*s->pfa_postreindex));
>
> Not strictly related to this patch, but this should probably be using
> av_malloc_array().
>
>

Fixed in the OVERALLOC macro.


> >      if (!s->pfa_postreindex)
> >          return 1;
> >
> > @@ -203,6 +226,21 @@ static void mdct15(MDCT15Context *s, float *dst,
> const float *src, ptrdiff_t str
> >      }
> >  }
> >
> > +static void postrotate_c(FFTComplex *out, FFTComplex *in, FFTComplex
> *exp,
> > +                         uint32_t *lut, int64_t len8)
> > +{
> > +    int i;
> > +
> > +    /* Reindex again, apply twiddles and output */
> > +    for (i = 0; i < len8; i++) {
> > +        const int i0 = len8 + i, i1 = len8 - i - 1;
> > +        const int s0 = lut[i0], s1 = lut[i1];
> > +
> > +        CMUL(out[i1].re, out[i0].im, in[s1].im, in[s1].re, exp[i1].im,
> exp[i1].re);
> > +        CMUL(out[i0].re, out[i1].im, in[s0].im, in[s0].re, exp[i0].im,
> exp[i0].re);
> > +    }
> > +}
> > +
> >  static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
> >                           ptrdiff_t stride)
> >  {
> > @@ -226,15 +264,7 @@ static void imdct15_half(MDCT15Context *s, float
> *dst, const float *src,
> >          s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
> >
> >      /* Reindex again, apply twiddles and output */
> > -    for (i = 0; i < len8; i++) {
> > -        const int i0 = len8 + i, i1 = len8 - i - 1;
> > -        const int s0 = s->pfa_postreindex[i0], s1 =
> s->pfa_postreindex[i1];
> > -
> > -        CMUL(z[i1].re, z[i0].im, s->tmp[s1].im, s->tmp[s1].re,
> > -             s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
> > -        CMUL(z[i0].re, z[i1].im, s->tmp[s0].im, s->tmp[s0].re,
> > -             s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
> > -    }
> > +    s->postreindex(z, s->tmp, s->twiddle_exptab, s->pfa_postreindex,
> len8);
> >  }
> >
> >  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N,
> double scale)
> > @@ -253,13 +283,14 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int
> inverse, int N, double scale)
> >      if (!s)
> >          return AVERROR(ENOMEM);
> >
> > -    s->fft_n      = N - 1;
> > -    s->len4       = len2 / 2;
> > -    s->len2       = len2;
> > -    s->inverse    = inverse;
> > -    s->fft15      = fft15_c;
> > -    s->mdct       = mdct15;
> > -    s->imdct_half = imdct15_half;
> > +    s->fft_n       = N - 1;
> > +    s->len4        = len2 / 2;
> > +    s->len2        = len2;
> > +    s->inverse     = inverse;
> > +    s->fft15       = fft15_c;
> > +    s->mdct        = mdct15;
> > +    s->imdct_half  = imdct15_half;
> > +    s->postreindex = postrotate_c;
> >
> >      if (ff_fft_init(&s->ptwo_fft, N - 1, s->inverse) < 0)
> >          goto fail;
> > @@ -267,11 +298,11 @@ av_cold int ff_mdct15_init(MDCT15Context **ps, int
> inverse, int N, double scale)
> >      if (init_pfa_reindex_tabs(s))
> >          goto fail;
> >
> > -    s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
> > +    OVERALLOC(s->tmp, len * 2 * sizeof(*s->tmp));
> >      if (!s->tmp)
> >          goto fail;
> >
> > -    s->twiddle_exptab  = av_malloc_array(s->len4,
> sizeof(*s->twiddle_exptab));
> > +    OVERALLOC(s->twiddle_exptab, s->len4 * sizeof(*s->twiddle_exptab));
> >      if (!s->twiddle_exptab)
> >          goto fail;
> >
> > diff --git a/libavcodec/mdct15.h b/libavcodec/mdct15.h
> > index 1c2149d436..301b238ec8 100644
> > --- a/libavcodec/mdct15.h
> > +++ b/libavcodec/mdct15.h
> > @@ -30,8 +30,8 @@ typedef struct MDCT15Context {
> >      int len2;
> >      int len4;
> >      int inverse;
> > -    int *pfa_prereindex;
> > -    int *pfa_postreindex;
> > +    uint32_t *pfa_prereindex;
> > +    uint32_t *pfa_postreindex;
>
> What's the point changing these? It has no effect on hand written asm
> and the C version still uses const int when loading single values from
> them.
>
>
Wasn't sure if we require ints to be strictly 32 bits. Fixed.


> >
> >      FFTContext ptwo_fft;
> >      FFTComplex *tmp;
> > @@ -42,6 +42,9 @@ typedef struct MDCT15Context {
> >      /* 15-point FFT */
> >      void (*fft15)(FFTComplex *out, FFTComplex *in, FFTComplex *exptab,
> ptrdiff_t stride);
> >
> > +    /* PFA postrotate and exptab */
> > +    void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex
> *exp, uint32_t *lut, int64_t len8);
>
> ptrdiff_t, not int64_t.
>

I thought that because it's not a stride it should be int64_t; fixed.


>
> > +
> >      /* Calculate a full 2N -> N MDCT */
> >      void (*mdct)(struct MDCT15Context *s, float *dst, const float *src,
> ptrdiff_t stride);
> >
> > diff --git a/libavcodec/x86/mdct15.asm b/libavcodec/x86/mdct15.asm
> > index f8b895944d..c5ea56c2c3 100644
> > --- a/libavcodec/x86/mdct15.asm
> > +++ b/libavcodec/x86/mdct15.asm
> > @@ -28,6 +28,10 @@ SECTION_RODATA
> >
> >  sign_adjust_5: dd 0x00000000, 0x80000000, 0x80000000, 0x00000000
> >
> > +sign_adjust_r: times 4 dd 0x80000000, 0x00000000
> > +perm_neg: dd 0x2, 0x5, 0x3, 0x4, 0x6, 0x1, 0x7, 0x0
> > +perm_pos: dd 0x0, 0x7, 0x1, 0x6, 0x4, 0x3, 0x5, 0x2
>
> These are index values, so IMO use decimal. It's more readable that way.
>
> Also, SECTION_RODATA needs to be 32 byte aligned now.
>
>
Fixed both.


> > +
> >  SECTION .text
> >
> >  %macro FFT5 3 ; %1 - in_offset, %2 - dst1 (64bit used), %3 - dst2
> > @@ -138,4 +142,72 @@ cglobal fft15, 4, 6, 14, out, in, exptab, stride,
> stride3, stride5
> >
> >      RET
> >
> > +%macro LUT_LOAD_4D 4
> > +    mov r7d, [lutq + %4q*4 +  0]
> > +    movsd  %2, [inq + r7q*8]
> > +    mov r7d, [lutq + %4q*4 +  4]
> > +    movhps %2, [inq + r7q*8]
> > +    mov r7d, [lutq + %4q*4 +  8]
> > +    movsd  %3, [inq + r7q*8]
> > +    mov r7d, [lutq + %4q*4 + 12]
> > +    movhps %3, [inq + r7q*8]
> > +    vinsertf128 %1, %1, %3, 1
> > +%endmacro
> > +
> > +;**********************************************************
> *****************************************************
> > +;void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in,
> FFTComplex *exp, uint32_t *lut, int64_t len8);
> > +;**********************************************************
> *****************************************************
> > +INIT_YMM avx2
> > +cglobal mdct15_postreindex, 5, 8, 10, out, in, exp, lut, len8,
> offset_p, offset_n
>
> You're using seven gprs, not 8.
>
>
No, look in LUT_LOAD_4D: I'm using r7 there, so that's 8 gprs.


> > +
> > +    mova m7, [perm_pos]
> > +    mova m8, [perm_neg]
> > +    mova m9, [sign_adjust_r]
> > +
> > +    mov offset_pq, len8q
> > +    lea offset_nq, [len8q - 4]
> > +
> > +    shl len8q, 1
> > +
> > +    movu m10, [outq - mmsize]           ; backup from start - mmsize to
> start
> > +    movu m11, [outq + len8q*8]          ; backup from end to end +
> mmsize
>
> And you seem to be using 12 xmm/ymm regs, not 10.
>

Fixed.


>
> > +
> > +.loop:
> > +    movups m0, [expq + offset_pq*8]     ; exp[p0].re, exp[p0].im,
> exp[p1].re, exp[p1].im, exp[p2].re, exp[p2].im, exp[p3].re, exp[p3].im
> > +    movups m1, [expq + offset_nq*8]     ; exp[n3].re, exp[n3].im,
> exp[n2].re, exp[n2].im, exp[n1].re, exp[n1].im, exp[n0].re, exp[n0].im
> > +
> > +    LUT_LOAD_4D m3, xm3, xm4, offset_p  ; in[p0].re, in[p0].im,
> in[p1].re, in[p1].im, in[p2].re, in[p2].im, in[p3].re, in[p3].im
> > +    LUT_LOAD_4D m4, xm4, xm5, offset_n  ; in[n3].re, in[n3].im,
> in[n2].re, in[n2].im, in[n1].re, in[n1].im, in[n0].re, in[n0].im
> > +
> > +    mulps m5, m3, m0                    ; in[p].reim * exp[p].reim
> > +    mulps m6, m4, m1                    ; in[n].reim * exp[n].reim
> > +
> > +    xorps m5, m9                        ; in[p].re *= -1, in[p].im *= 1
> > +    xorps m6, m9                        ; in[n].re *= -1, in[n].im *= 1
> > +
> > +    shufps m3, m3, m3, q2301            ; in[p].imre
> > +    shufps m4, m4, m4, q2301            ; in[n].imre
> > +
> > +    mulps m3, m0                        ; in[p].imre * exp[p].reim
> > +    mulps m4, m1                        ; in[n].imre * exp[n].reim
> > +
> > +    haddps m5, m4                       ; out[p0].re, out[p1].re,
> out[p3].im, out[p2].im, out[p2].re, out[p3].re, out[p1].im, out[p0].im
> > +    haddps m3, m6                       ; out[n0].im, out[n1].im,
> out[n3].re, out[n2].re, out[n2].im, out[n3].im, out[n1].re, out[n0].re
> > +
> > +    vpermps m5, m7, m5                  ; out[p0].re, out[p0].im,
> out[p1].re, out[p1].im, out[p2].re, out[p2].im, out[p3].re, out[p3].im
> > +    vpermps m3, m8, m3                  ; out[n3].im, out[n3].re,
> out[n2].im, out[n2].re, out[n1].im, out[n1].re, out[n0].im, out[n0].re
> > +
> > +    movups [outq + offset_pq*8], m5
> > +    movups [outq + offset_nq*8], m3
> > +
> > +    sub offset_nq, 4
> > +    add offset_pq, 4
> > +    cmp offset_pq, len8q
> > +    jl .loop
> > +
> > +    movu [outq - mmsize], m10
> > +    movu [outq + len8q*8], m11
> > +
> > +    RET
> > +
> >  %endif
> > diff --git a/libavcodec/x86/mdct15_init.c b/libavcodec/x86/mdct15_init.c
> > index ba3d94c2ec..60d47e71ce 100644
> > --- a/libavcodec/x86/mdct15_init.c
> > +++ b/libavcodec/x86/mdct15_init.c
> > @@ -25,6 +25,7 @@
> >  #include "libavutil/x86/cpu.h"
> >  #include "libavcodec/mdct15.h"
> >
> > +void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in,
> FFTComplex *exp, uint32_t *lut, int64_t len8);
> >  void ff_fft15_avx(FFTComplex *out, FFTComplex *in, FFTComplex *exptab,
> ptrdiff_t stride);
> >
> >  static void perm_twiddles(MDCT15Context *s)
> > @@ -90,6 +91,9 @@ av_cold void ff_mdct15_init_x86(MDCT15Context *s)
> >          adjust_twiddles = 1;
> >      }
> >
> > +    if (ARCH_X86_64 && EXTERNAL_AVX2(cpu_flags))
> > +        s->postreindex = ff_mdct15_postreindex_avx2;
> > +
>
> Why didn't you write an SSE3 version of this function? Baseline for any
> function should be SSE/SSE3 if float, SSE2/SSE4 if integer.
>
>
I or someone else could do it later. The whole thing started as a test to see
how slow vgather is, so I had to use AVX2; the gather turned out to be slower,
so I dropped it but kept the rest of the function as AVX2.


> >      if (adjust_twiddles)
> >          perm_twiddles(s);
> >  }
> >
>

Patch

diff --git a/libavcodec/mdct15.c b/libavcodec/mdct15.c
index d68372c344..0a6c0069db 100644
--- a/libavcodec/mdct15.c
+++ b/libavcodec/mdct15.c
@@ -40,6 +40,29 @@ 
 
 #define CMUL3(c, a, b) CMUL((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
 
+#define OVERALLOC_AMOUNT 32 /* 1 ymm register */
+
+#define OVERALLOC(val, size)                                                \
+    {                                                                       \
+        (val) = NULL;                                                       \
+        uint8_t *temp = av_malloc((size) + 2*OVERALLOC_AMOUNT);             \
+        if (temp) {                                                         \
+            memset(temp, 0, OVERALLOC_AMOUNT);                              \
+            memset(temp + (size) + OVERALLOC_AMOUNT, 0, OVERALLOC_AMOUNT);  \
+            (val) = (void *)(temp + OVERALLOC_AMOUNT);                      \
+        }                                                                   \
+    }
+
+#define OVERFREEP(val)                                                      \
+    {                                                                       \
+        uint8_t *temp = (uint8_t *)(val);                                   \
+        if (temp) {                                                         \
+            temp -= OVERALLOC_AMOUNT;                                       \
+            av_free(temp);                                                  \
+        }                                                                   \
+        val = NULL;                                                         \
+    }
+
 av_cold void ff_mdct15_uninit(MDCT15Context **ps)
 {
     MDCT15Context *s = *ps;
@@ -50,9 +73,9 @@  av_cold void ff_mdct15_uninit(MDCT15Context **ps)
     ff_fft_end(&s->ptwo_fft);
 
     av_freep(&s->pfa_prereindex);
-    av_freep(&s->pfa_postreindex);
-    av_freep(&s->twiddle_exptab);
-    av_freep(&s->tmp);
+    OVERFREEP(s->pfa_postreindex);
+    OVERFREEP(s->twiddle_exptab);
+    OVERFREEP(s->tmp);
 
     av_freep(ps);
 }
@@ -69,7 +92,7 @@  static inline int init_pfa_reindex_tabs(MDCT15Context *s)
     if (!s->pfa_prereindex)
         return 1;
 
-    s->pfa_postreindex = av_malloc(15 * l_ptwo * sizeof(*s->pfa_postreindex));
+    OVERALLOC(s->pfa_postreindex, 15 * l_ptwo * sizeof(*s->pfa_postreindex));
     if (!s->pfa_postreindex)
         return 1;
 
@@ -203,6 +226,21 @@  static void mdct15(MDCT15Context *s, float *dst, const float *src, ptrdiff_t str
     }
 }
 
+static void postrotate_c(FFTComplex *out, FFTComplex *in, FFTComplex *exp,
+                         uint32_t *lut, int64_t len8)
+{
+    int i;
+
+    /* Reindex again, apply twiddles and output */
+    for (i = 0; i < len8; i++) {
+        const int i0 = len8 + i, i1 = len8 - i - 1;
+        const int s0 = lut[i0], s1 = lut[i1];
+
+        CMUL(out[i1].re, out[i0].im, in[s1].im, in[s1].re, exp[i1].im, exp[i1].re);
+        CMUL(out[i0].re, out[i1].im, in[s0].im, in[s0].re, exp[i0].im, exp[i0].re);
+    }
+}
+
 static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
                          ptrdiff_t stride)
 {
@@ -226,15 +264,7 @@  static void imdct15_half(MDCT15Context *s, float *dst, const float *src,
         s->ptwo_fft.fft_calc(&s->ptwo_fft, s->tmp + l_ptwo*i);
 
     /* Reindex again, apply twiddles and output */
-    for (i = 0; i < len8; i++) {
-        const int i0 = len8 + i, i1 = len8 - i - 1;
-        const int s0 = s->pfa_postreindex[i0], s1 = s->pfa_postreindex[i1];
-
-        CMUL(z[i1].re, z[i0].im, s->tmp[s1].im, s->tmp[s1].re,
-             s->twiddle_exptab[i1].im, s->twiddle_exptab[i1].re);
-        CMUL(z[i0].re, z[i1].im, s->tmp[s0].im, s->tmp[s0].re,
-             s->twiddle_exptab[i0].im, s->twiddle_exptab[i0].re);
-    }
+    s->postreindex(z, s->tmp, s->twiddle_exptab, s->pfa_postreindex, len8);
 }
 
 av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
@@ -253,13 +283,14 @@  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
     if (!s)
         return AVERROR(ENOMEM);
 
-    s->fft_n      = N - 1;
-    s->len4       = len2 / 2;
-    s->len2       = len2;
-    s->inverse    = inverse;
-    s->fft15      = fft15_c;
-    s->mdct       = mdct15;
-    s->imdct_half = imdct15_half;
+    s->fft_n       = N - 1;
+    s->len4        = len2 / 2;
+    s->len2        = len2;
+    s->inverse     = inverse;
+    s->fft15       = fft15_c;
+    s->mdct        = mdct15;
+    s->imdct_half  = imdct15_half;
+    s->postreindex = postrotate_c;
 
     if (ff_fft_init(&s->ptwo_fft, N - 1, s->inverse) < 0)
         goto fail;
@@ -267,11 +298,11 @@  av_cold int ff_mdct15_init(MDCT15Context **ps, int inverse, int N, double scale)
     if (init_pfa_reindex_tabs(s))
         goto fail;
 
-    s->tmp  = av_malloc_array(len, 2 * sizeof(*s->tmp));
+    OVERALLOC(s->tmp, len * 2 * sizeof(*s->tmp));
     if (!s->tmp)
         goto fail;
 
-    s->twiddle_exptab  = av_malloc_array(s->len4, sizeof(*s->twiddle_exptab));
+    OVERALLOC(s->twiddle_exptab, s->len4 * sizeof(*s->twiddle_exptab));
     if (!s->twiddle_exptab)
         goto fail;
 
diff --git a/libavcodec/mdct15.h b/libavcodec/mdct15.h
index 1c2149d436..301b238ec8 100644
--- a/libavcodec/mdct15.h
+++ b/libavcodec/mdct15.h
@@ -30,8 +30,8 @@  typedef struct MDCT15Context {
     int len2;
     int len4;
     int inverse;
-    int *pfa_prereindex;
-    int *pfa_postreindex;
+    uint32_t *pfa_prereindex;
+    uint32_t *pfa_postreindex;
 
     FFTContext ptwo_fft;
     FFTComplex *tmp;
@@ -42,6 +42,9 @@  typedef struct MDCT15Context {
     /* 15-point FFT */
     void (*fft15)(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
 
+    /* PFA postrotate and exptab */
+    void (*postreindex)(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
+
     /* Calculate a full 2N -> N MDCT */
     void (*mdct)(struct MDCT15Context *s, float *dst, const float *src, ptrdiff_t stride);
 
diff --git a/libavcodec/x86/mdct15.asm b/libavcodec/x86/mdct15.asm
index f8b895944d..c5ea56c2c3 100644
--- a/libavcodec/x86/mdct15.asm
+++ b/libavcodec/x86/mdct15.asm
@@ -28,6 +28,10 @@  SECTION_RODATA
 
 sign_adjust_5: dd 0x00000000, 0x80000000, 0x80000000, 0x00000000
 
+sign_adjust_r: times 4 dd 0x80000000, 0x00000000
+perm_neg: dd 0x2, 0x5, 0x3, 0x4, 0x6, 0x1, 0x7, 0x0
+perm_pos: dd 0x0, 0x7, 0x1, 0x6, 0x4, 0x3, 0x5, 0x2
+
 SECTION .text
 
 %macro FFT5 3 ; %1 - in_offset, %2 - dst1 (64bit used), %3 - dst2
@@ -138,4 +142,72 @@  cglobal fft15, 4, 6, 14, out, in, exptab, stride, stride3, stride5
 
     RET
 
+%macro LUT_LOAD_4D 4
+    mov r7d, [lutq + %4q*4 +  0]
+    movsd  %2, [inq + r7q*8]
+    mov r7d, [lutq + %4q*4 +  4]
+    movhps %2, [inq + r7q*8]
+    mov r7d, [lutq + %4q*4 +  8]
+    movsd  %3, [inq + r7q*8]
+    mov r7d, [lutq + %4q*4 + 12]
+    movhps %3, [inq + r7q*8]
+    vinsertf128 %1, %1, %3, 1
+%endmacro
+
+;***************************************************************************************************************
+;void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
+;***************************************************************************************************************
+INIT_YMM avx2
+cglobal mdct15_postreindex, 5, 8, 10, out, in, exp, lut, len8, offset_p, offset_n
+
+    mova m7, [perm_pos]
+    mova m8, [perm_neg]
+    mova m9, [sign_adjust_r]
+
+    mov offset_pq, len8q
+    lea offset_nq, [len8q - 4]
+
+    shl len8q, 1
+
+    movu m10, [outq - mmsize]           ; backup from start - mmsize to start
+    movu m11, [outq + len8q*8]          ; backup from end to end + mmsize
+
+.loop:
+    movups m0, [expq + offset_pq*8]     ; exp[p0].re, exp[p0].im, exp[p1].re, exp[p1].im, exp[p2].re, exp[p2].im, exp[p3].re, exp[p3].im
+    movups m1, [expq + offset_nq*8]     ; exp[n3].re, exp[n3].im, exp[n2].re, exp[n2].im, exp[n1].re, exp[n1].im, exp[n0].re, exp[n0].im
+
+    LUT_LOAD_4D m3, xm3, xm4, offset_p  ; in[p0].re, in[p0].im, in[p1].re, in[p1].im, in[p2].re, in[p2].im, in[p3].re, in[p3].im
+    LUT_LOAD_4D m4, xm4, xm5, offset_n  ; in[n3].re, in[n3].im, in[n2].re, in[n2].im, in[n1].re, in[n1].im, in[n0].re, in[n0].im     
+
+    mulps m5, m3, m0                    ; in[p].reim * exp[p].reim
+    mulps m6, m4, m1                    ; in[n].reim * exp[n].reim
+
+    xorps m5, m9                        ; in[p].re *= -1, in[p].im *= 1
+    xorps m6, m9                        ; in[n].re *= -1, in[n].im *= 1
+
+    shufps m3, m3, m3, q2301            ; in[p].imre
+    shufps m4, m4, m4, q2301            ; in[n].imre
+
+    mulps m3, m0                        ; in[p].imre * exp[p].reim
+    mulps m4, m1                        ; in[n].imre * exp[n].reim
+
+    haddps m5, m4                       ; out[p0].re, out[p1].re, out[p3].im, out[p2].im, out[p2].re, out[p3].re, out[p1].im, out[p0].im
+    haddps m3, m6                       ; out[n0].im, out[n1].im, out[n3].re, out[n2].re, out[n2].im, out[n3].im, out[n1].re, out[n0].re
+
+    vpermps m5, m7, m5                  ; out[p0].re, out[p0].im, out[p1].re, out[p1].im, out[p2].re, out[p2].im, out[p3].re, out[p3].im
+    vpermps m3, m8, m3                  ; out[n3].im, out[n3].re, out[n2].im, out[n2].re, out[n1].im, out[n1].re, out[n0].im, out[n0].re
+
+    movups [outq + offset_pq*8], m5
+    movups [outq + offset_nq*8], m3
+
+    sub offset_nq, 4
+    add offset_pq, 4
+    cmp offset_pq, len8q
+    jl .loop
+
+    movu [outq - mmsize], m10
+    movu [outq + len8q*8], m11
+
+    RET
+
 %endif
diff --git a/libavcodec/x86/mdct15_init.c b/libavcodec/x86/mdct15_init.c
index ba3d94c2ec..60d47e71ce 100644
--- a/libavcodec/x86/mdct15_init.c
+++ b/libavcodec/x86/mdct15_init.c
@@ -25,6 +25,7 @@ 
 #include "libavutil/x86/cpu.h"
 #include "libavcodec/mdct15.h"
 
+void ff_mdct15_postreindex_avx2(FFTComplex *out, FFTComplex *in, FFTComplex *exp, uint32_t *lut, int64_t len8);
 void ff_fft15_avx(FFTComplex *out, FFTComplex *in, FFTComplex *exptab, ptrdiff_t stride);
 
 static void perm_twiddles(MDCT15Context *s)
@@ -90,6 +91,9 @@  av_cold void ff_mdct15_init_x86(MDCT15Context *s)
         adjust_twiddles = 1;
     }
 
+    if (ARCH_X86_64 && EXTERNAL_AVX2(cpu_flags))
+        s->postreindex = ff_mdct15_postreindex_avx2;
+
     if (adjust_twiddles)
         perm_twiddles(s);
 }