diff mbox series

[FFmpeg-devel] lavc/sbrdsp: R-V V autocorrelate

Message ID 20231108203041.51648-1-remi@remlab.net
State Accepted
Commit cd7b352c534c097c8009d4022daff4027655d207
Headers show
Series [FFmpeg-devel] lavc/sbrdsp: R-V V autocorrelate | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Rémi Denis-Courmont Nov. 8, 2023, 8:30 p.m. UTC
With 5 accumulator vectors and 6 inputs, this can only use LMUL=2.
Also the number of vector loop iterations is small, just 5 on 128-bit
vector hardware.

The vector loop is somewhat unusual in that it processes data in
descending memory order, in order to save on vector slides:
in descending order, we can extract elements to carry over to the next
iteration from the bottom of the vectors directly. With ascending order
(see in the Opus postfilter function), there are no ways to get the top
elements directly. On the downside, this requires the use of separate
shift and sub (the would-be SH3SUB instruction does not exist), with
a small pipeline stall on the vector load address.

The edge cases in scalar are done in scalar as this saves on loads
and remains significantly faster than C.

autocorrelate_c: 669.2
autocorrelate_rvv_f32: 421.0
---
 libavcodec/riscv/sbrdsp_init.c | 12 +++--
 libavcodec/riscv/sbrdsp_rvv.S  | 89 ++++++++++++++++++++++++++++++++++
 2 files changed, 97 insertions(+), 4 deletions(-)
diff mbox series

Patch

diff --git a/libavcodec/riscv/sbrdsp_init.c b/libavcodec/riscv/sbrdsp_init.c
index 71de681185..c1ed5b639c 100644
--- a/libavcodec/riscv/sbrdsp_init.c
+++ b/libavcodec/riscv/sbrdsp_init.c
@@ -26,6 +26,7 @@ 
 void ff_sbr_sum64x5_rvv(float *z);
 float ff_sbr_sum_square_rvv(float (*x)[2], int n);
 void ff_sbr_neg_odd_64_rvv(float *x);
+void ff_sbr_autocorrelate_rvv(const float x[40][2], float phi[3][2][2]);
 void ff_sbr_hf_g_filt_rvv(float (*Y)[2], const float (*X_high)[40][2],
                           const float *g_filt, int m_max, intptr_t ixh);
 
@@ -34,10 +35,13 @@  av_cold void ff_sbrdsp_init_riscv(SBRDSPContext *c)
 #if HAVE_RVV
     int flags = av_get_cpu_flags();
 
-    if ((flags & AV_CPU_FLAG_RVV_F32) && (flags & AV_CPU_FLAG_RVB_ADDR)) {
-        c->sum64x5 = ff_sbr_sum64x5_rvv;
-        c->sum_square = ff_sbr_sum_square_rvv;
-        c->hf_g_filt = ff_sbr_hf_g_filt_rvv;
+    if (flags & AV_CPU_FLAG_RVV_F32) {
+        if (flags & AV_CPU_FLAG_RVB_ADDR) {
+            c->sum64x5 = ff_sbr_sum64x5_rvv;
+            c->sum_square = ff_sbr_sum_square_rvv;
+            c->hf_g_filt = ff_sbr_hf_g_filt_rvv;
+        }
+        c->autocorrelate = ff_sbr_autocorrelate_rvv;
     }
 #if __riscv_xlen >= 64
     if ((flags & AV_CPU_FLAG_RVV_I64) && (flags & AV_CPU_FLAG_RVB_ADDR))
diff --git a/libavcodec/riscv/sbrdsp_rvv.S b/libavcodec/riscv/sbrdsp_rvv.S
index 932a5dd7d1..2f3a0969d7 100644
--- a/libavcodec/riscv/sbrdsp_rvv.S
+++ b/libavcodec/riscv/sbrdsp_rvv.S
@@ -85,6 +85,95 @@  func ff_sbr_neg_odd_64_rvv, zve64x
 endfunc
 #endif
 
+func ff_sbr_autocorrelate_rvv, zve32f
+        vsetvli t0, zero, e32, m4, ta, ma
+        vmv.v.x v0, zero
+        flw     fa0,   (a0)
+        vmv.v.x v4, zero
+        flw     fa1,  4(a0)
+        vmv.v.x v8, zero
+        flw     fa2,  8(a0)
+        li      a2, 37
+        flw     fa3, 12(a0)
+        fmul.s  ft10, fa0, fa0
+        flw     fa4, 16(a0)
+        fmul.s  ft6, fa0, fa2
+        flw     fa5, 20(a0)
+        addi    a0, a0, 38 * 8
+        fmul.s  ft7, fa0, fa3
+        fmul.s  ft2, fa0, fa4
+        fmul.s  ft3, fa0, fa5
+        flw     fa0,   (a0)
+        fmadd.s ft10, fa1, fa1, ft10
+        fmadd.s ft6, fa1, fa3, ft6
+        flw     fa3, 12(a0)
+        fnmsub.s ft7, fa1, fa2, ft7
+        flw     fa2,  8(a0)
+        fmadd.s ft2, fa1, fa5, ft2
+        fnmsub.s ft3, fa1, fa4, ft3
+        flw     fa1,  4(a0)
+        fmul.s  ft4, fa0, fa0
+        fmul.s  ft0, fa0, fa2
+        fmul.s  ft1, fa0, fa3
+        fmadd.s ft4, fa1, fa1, ft4
+        fmadd.s ft0, fa1, fa3, ft0
+        fnmsub.s ft1, fa1, fa2, ft1
+1:
+        vsetvli t0, a2, e32, m2, tu, ma
+        slli    t1, t0, 3
+        sub     a0, a0, t1
+        vlseg2e32.v v16, (a0)
+        sub     a2, a2, t0
+        vfmacc.vv v0, v16, v16
+        vfslide1down.vf v20, v16, fa0
+        vfmacc.vv v4, v16, v20
+        vfslide1down.vf v22, v18, fa1
+        vfmacc.vv v0, v18, v18
+        vfslide1down.vf v24, v20, fa2
+        vfmacc.vv v4, v18, v22
+        vfslide1down.vf v26, v22, fa3
+        vfmacc.vv v6, v16, v22
+        vfmv.f.s fa0, v16
+        vfmacc.vv v8, v16, v24
+        vfmv.f.s fa1, v18
+        vfmacc.vv v10, v16, v26
+        vfmv.f.s fa2, v20
+        vfnmsac.vv v6, v18, v20
+        vfmv.f.s fa3, v22
+        vfmacc.vv v8, v18, v26
+        vfnmsac.vv v10, v18, v24
+        bnez    a2, 1b
+
+        vsetvli t0, zero, e32, m2, ta, ma
+        vfredusum.vs v0, v0, v2
+        vfredusum.vs v4, v4, v2
+        vfmv.f.s fa0, v0
+        vfredusum.vs v6, v6, v2
+        vfmv.f.s fa2, v4
+        fadd.s   ft4, ft4, fa0
+        vfredusum.vs v8, v8, v2
+        vfmv.f.s fa3, v6
+        fadd.s   ft0, ft0, fa2
+        vfredusum.vs v10, v10, v2
+        vfmv.f.s fa4, v8
+        fadd.s   ft1, ft1, fa3
+        vfmv.f.s fa5, v10
+        fsw     ft0,   (a1)
+        fadd.s  ft2, ft2, fa4
+        fsw     ft1,  4(a1)
+        fadd.s  ft3, ft3, fa5
+        fsw     ft2,  8(a1)
+        fadd.s  ft6, ft6, fa2
+        fsw     ft3, 12(a1)
+        fadd.s  ft7, ft7, fa3
+        fsw     ft4, 16(a1)
+        fadd.s  ft10, ft10, fa0
+        fsw     ft6, 24(a1)
+        fsw     ft7, 28(a1)
+        fsw     ft10, 40(a1)
+        ret
+endfunc
+
 func ff_sbr_hf_g_filt_rvv, zve32f
         li      t1, 40 * 2 * 4
         sh3add  a1, a4, a1