diff mbox series

[FFmpeg-devel,28/31] lavc/aacpsdsp: RISC-V V hybrid_analysis

Message ID 20220925142619.67917-28-remi@remlab.net
State New
Headers show
Series [FFmpeg-devel,01/31] lavu/cpu: detect RISC-V base extensions | expand

Commit Message

Rémi Denis-Courmont Sept. 25, 2022, 2:26 p.m. UTC
From: Rémi Denis-Courmont <remi@remlab.net>

This starts with one-time initialisation of the 26 constant factors
like  08edacc248bce3f8946d75e97188d189c74a6de6. That is done with
the scalar instruction set. While the formula can readily be vectored,
the gains would (probably) be more than lost in transfering the results
back to FP registers (or suitably reshuffling them into vector
registers).

Note that the main loop could likely be scheduled sligthly better by
expanding the filter macro and interleaving loads with arithmetic.
It is not clear yet if that would be relevant for vector processing (as
opposed to traditional SIMD).

We could also use fewer vectors, but there is not much point in sparing
them (they are *all* callee-clobbered).
---
 libavcodec/riscv/aacpsdsp_init.c |  3 +
 libavcodec/riscv/aacpsdsp_rvv.S  | 97 ++++++++++++++++++++++++++++++++
 2 files changed, 100 insertions(+)
diff mbox series

Patch

diff --git a/libavcodec/riscv/aacpsdsp_init.c b/libavcodec/riscv/aacpsdsp_init.c
index 90c9c501c3..6222d6f787 100644
--- a/libavcodec/riscv/aacpsdsp_init.c
+++ b/libavcodec/riscv/aacpsdsp_init.c
@@ -27,6 +27,8 @@ 
 void ff_ps_add_squares_rvv(float *dst, const float (*src)[2], int n);
 void ff_ps_mul_pair_single_rvv(float (*dst)[2], float (*src0)[2], float *src1,
                                int n);
+void ff_ps_hybrid_analysis_rvv(float (*out)[2], float (*in)[2],
+                               const float (*filter)[8][2], ptrdiff_t, int n);
 
 av_cold void ff_psdsp_init_riscv(PSDSPContext *c)
 {
@@ -36,6 +38,7 @@  av_cold void ff_psdsp_init_riscv(PSDSPContext *c)
     if (flags & AV_CPU_FLAG_RV_ZVE32F) {
         c->add_squares = ff_ps_add_squares_rvv;
         c->mul_pair_single = ff_ps_mul_pair_single_rvv;
+        c->hybrid_analysis = ff_ps_hybrid_analysis_rvv;
     }
 #endif
 }
diff --git a/libavcodec/riscv/aacpsdsp_rvv.S b/libavcodec/riscv/aacpsdsp_rvv.S
index 70b7b72218..65e5e0be4f 100644
--- a/libavcodec/riscv/aacpsdsp_rvv.S
+++ b/libavcodec/riscv/aacpsdsp_rvv.S
@@ -52,3 +52,100 @@  func ff_ps_mul_pair_single_rvv, zve32f
 
         ret
 endfunc
+
+func ff_ps_hybrid_analysis_rvv, zve32f
+        /* We need 26 FP registers, for 20 scratch ones. Spill fs0-fs5. */
+        addi    sp, sp, -32
+        .irp n, 0, 1, 2, 3, 4, 5
+        fsw     fs\n, (4 * \n)(sp)
+        .endr
+
+        .macro input, j, fd0, fd1, fd2, fd3
+        flw     \fd0, (4 * ((\j * 2) + 0))(a1)
+        flw     fs4, (4 * (((12 - \j) * 2) + 0))(a1)
+        flw     \fd1, (4 * ((\j * 2) + 1))(a1)
+        fsub.s  \fd3, \fd0, fs4
+        flw     fs5, (4 * (((12 - \j) * 2) + 1))(a1)
+        fadd.s  \fd2, \fd1, fs5
+        fadd.s  \fd0, \fd0, fs4
+        fsub.s  \fd1, \fd1, fs5
+        .endm
+
+        //         re0, re1, im0, im1
+        input   0, ft0, ft1, ft2, ft3
+        input   1, ft4, ft5, ft6, ft7
+        input   2, ft8, ft9, ft10, ft11
+        input   3, fa0, fa1, fa2, fa3
+        input   4, fa4, fa5, fa6, fa7
+        input   5, fs0, fs1, fs2, fs3
+        flw     fs4, (4 * ((6 * 2) + 0))(a1)
+        flw     fs5, (4 * ((6 * 2) + 1))(a1)
+
+        add        a2, a2, 6 * 2 * 4 // point to filter[i][6][0]
+        li         t4, 8 * 2 * 4 // filter byte stride
+        slli       a3, a3, 3 // output byte stride
+1:
+        .macro filter, vs0, vs1, fo0, fo1, fo2, fo3
+        vfmacc.vf  v8, \fo0, \vs0
+        vfmacc.vf  v9, \fo2, \vs0
+        vfnmsac.vf v8, \fo1, \vs1
+        vfmacc.vf  v9, \fo3, \vs1
+        .endm
+
+        vsetvli    t0, a4, e32, m1, ta, ma
+        /*
+         * The filter (a2) has 16 segments, of which 13 need to be extracted.
+         * R-V V supports only up to 8 segments, so unrolling is unavoidable.
+         */
+        addi       t1, a2, -48
+        vlse32.v   v22, (a2), t4
+        addi       t2, a2, -44
+        vlse32.v   v16, (t1), t4
+        addi       t1, a2, -40
+        vfmul.vf   v8, v22, fs4
+        vlse32.v   v24, (t2), t4
+        addi       t2, a2, -36
+        vfmul.vf   v9, v22, fs5
+        vlse32.v   v17, (t1), t4
+        addi       t1, a2, -32
+        vlse32.v   v25, (t2), t4
+        addi       t2, a2, -28
+        filter     v16, v24, ft0, ft1, ft2, ft3
+        vlse32.v   v18, (t1), t4
+        addi       t1, a2, -24
+        vlse32.v   v26, (t2), t4
+        addi       t2, a2, -20
+        filter     v17, v25, ft4, ft5, ft6, ft7
+        vlse32.v   v19, (t1), t4
+        addi       t1, a2, -16
+        vlse32.v   v27, (t2), t4
+        addi       t2, a2, -12
+        filter     v18, v26, ft8, ft9, ft10, ft11
+        vlse32.v   v20, (t1), t4
+        addi       t1, a2, -8
+        vlse32.v   v28, (t2), t4
+        addi       t2, a2, -4
+        filter     v19, v27, fa0, fa1, fa2, fa3
+        vlse32.v   v21, (t1), t4
+        sub        a4, a4, t0
+        vlse32.v   v29, (t2), t4
+        slli       t1, t0, 3 + 1 + 2 // ctz(8 * 2 * 4)
+        add        a2, a2, t1
+        filter     v20, v28, fa4, fa5, fa6, fa7
+        filter     v21, v29, fs0, fs1, fs2, fs3
+
+        add        t2, a0, 4
+        vsse32.v   v8, (a0), a3
+        mul        t0, t0, a3
+        vsse32.v   v9, (t2), a3
+        add        a0, a0, t0
+        bnez       a4, 1b
+
+        .irp n, 5, 4, 3, 2, 1, 0
+        flw     fs\n, (4 * \n)(sp)
+        .endr
+        addi    sp, sp, 32
+        ret
+        .purgem input
+        .purgem filter
+endfunc