diff mbox series

[FFmpeg-devel,v2,2/3] aarch64/vvc: Add dmvr_hv

Message ID tencent_0E6C401A9AC09FDEAB064E0870AE2D869A08@qq.com
State New
Headers show
Series [FFmpeg-devel,v2,1/3] aarch64/vvc: Add w_avg | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Zhao Zhili Sept. 23, 2024, 9:05 a.m. UTC
From: Zhao Zhili <zhilizhao@tencent.com>

dmvr_hv_8_12x20_c:                                       8.0 ( 1.00x)
dmvr_hv_8_12x20_neon:                                    1.2 ( 6.62x)
dmvr_hv_8_20x12_c:                                       8.0 ( 1.00x)
dmvr_hv_8_20x12_neon:                                    0.9 ( 8.37x)
dmvr_hv_8_20x20_c:                                      12.9 ( 1.00x)
dmvr_hv_8_20x20_neon:                                    1.7 ( 7.62x)
dmvr_hv_10_12x20_c:                                      7.0 ( 1.00x)
dmvr_hv_10_12x20_neon:                                   1.7 ( 4.09x)
dmvr_hv_10_20x12_c:                                      7.0 ( 1.00x)
dmvr_hv_10_20x12_neon:                                   1.7 ( 4.09x)
dmvr_hv_10_20x20_c:                                     11.2 ( 1.00x)
dmvr_hv_10_20x20_neon:                                   2.7 ( 4.15x)
dmvr_hv_12_12x20_c:                                      6.5 ( 1.00x)
dmvr_hv_12_12x20_neon:                                   1.7 ( 3.79x)
dmvr_hv_12_20x12_c:                                      6.5 ( 1.00x)
dmvr_hv_12_20x12_neon:                                   1.7 ( 3.79x)
dmvr_hv_12_20x20_c:                                     10.2 ( 1.00x)
dmvr_hv_12_20x20_neon:                                   2.2 ( 4.64x)
---
 libavcodec/aarch64/vvc/dsp_init.c |  12 ++
 libavcodec/aarch64/vvc/inter.S    | 307 ++++++++++++++++++++++++++++++
 2 files changed, 319 insertions(+)

Comments

Martin Storsjö Sept. 26, 2024, 11:36 a.m. UTC | #1
On Mon, 23 Sep 2024, Zhao Zhili wrote:

> From: Zhao Zhili <zhilizhao@tencent.com>
>
> dmvr_hv_8_12x20_c:                                       8.0 ( 1.00x)
> dmvr_hv_8_12x20_neon:                                    1.2 ( 6.62x)
> dmvr_hv_8_20x12_c:                                       8.0 ( 1.00x)
> dmvr_hv_8_20x12_neon:                                    0.9 ( 8.37x)
> dmvr_hv_8_20x20_c:                                      12.9 ( 1.00x)
> dmvr_hv_8_20x20_neon:                                    1.7 ( 7.62x)
> dmvr_hv_10_12x20_c:                                      7.0 ( 1.00x)
> dmvr_hv_10_12x20_neon:                                   1.7 ( 4.09x)
> dmvr_hv_10_20x12_c:                                      7.0 ( 1.00x)
> dmvr_hv_10_20x12_neon:                                   1.7 ( 4.09x)
> dmvr_hv_10_20x20_c:                                     11.2 ( 1.00x)
> dmvr_hv_10_20x20_neon:                                   2.7 ( 4.15x)
> dmvr_hv_12_12x20_c:                                      6.5 ( 1.00x)
> dmvr_hv_12_12x20_neon:                                   1.7 ( 3.79x)
> dmvr_hv_12_20x12_c:                                      6.5 ( 1.00x)
> dmvr_hv_12_20x12_neon:                                   1.7 ( 3.79x)
> dmvr_hv_12_20x20_c:                                     10.2 ( 1.00x)
> dmvr_hv_12_20x20_neon:                                   2.2 ( 4.64x)
> ---
> libavcodec/aarch64/vvc/dsp_init.c |  12 ++
> libavcodec/aarch64/vvc/inter.S    | 307 ++++++++++++++++++++++++++++++
> 2 files changed, 319 insertions(+)
>
> diff --git a/libavcodec/aarch64/vvc/dsp_init.c b/libavcodec/aarch64/vvc/dsp_init.c
> index b39ebb83fc..995e26d163 100644
> --- a/libavcodec/aarch64/vvc/dsp_init.c
> +++ b/libavcodec/aarch64/vvc/dsp_init.c
> @@ -83,6 +83,15 @@ W_AVG_FUN(8)
> W_AVG_FUN(10)
> W_AVG_FUN(12)
>
> +#define DMVR_FUN(fn, bd) \
> +    void ff_vvc_dmvr_ ## fn ## bd ## _neon(int16_t *dst, \
> +        const uint8_t *_src, const ptrdiff_t _src_stride, const int height, \
> +        const intptr_t mx, const intptr_t my, const int width);

Unnecessary const on scalar parameters

> +
> +DMVR_FUN(hv_, 8)
> +DMVR_FUN(hv_, 10)
> +DMVR_FUN(hv_, 12)
> +
> void ff_vvc_dsp_init_aarch64(VVCDSPContext *const c, const int bd)
> {
>     int cpu_flags = av_get_cpu_flags();
> @@ -155,6 +164,7 @@ void ff_vvc_dsp_init_aarch64(VVCDSPContext *const c, const int bd)
>
>         c->inter.avg = ff_vvc_avg_8_neon;
>         c->inter.w_avg = vvc_w_avg_8;
> +        c->inter.dmvr[1][1] = ff_vvc_dmvr_hv_8_neon;
>
>         for (int i = 0; i < FF_ARRAY_ELEMS(c->sao.band_filter); i++)
>             c->sao.band_filter[i] = ff_h26x_sao_band_filter_8x8_8_neon;
> @@ -196,12 +206,14 @@ void ff_vvc_dsp_init_aarch64(VVCDSPContext *const c, const int bd)
>     } else if (bd == 10) {
>         c->inter.avg = ff_vvc_avg_10_neon;
>         c->inter.w_avg = vvc_w_avg_10;
> +        c->inter.dmvr[1][1] = ff_vvc_dmvr_hv_10_neon;
>
>         c->alf.filter[LUMA] = alf_filter_luma_10_neon;
>         c->alf.filter[CHROMA] = alf_filter_chroma_10_neon;
>     } else if (bd == 12) {
>         c->inter.avg = ff_vvc_avg_12_neon;
>         c->inter.w_avg = vvc_w_avg_12;
> +        c->inter.dmvr[1][1] = ff_vvc_dmvr_hv_12_neon;
>
>         c->alf.filter[LUMA] = alf_filter_luma_12_neon;
>         c->alf.filter[CHROMA] = alf_filter_chroma_12_neon;
> diff --git a/libavcodec/aarch64/vvc/inter.S b/libavcodec/aarch64/vvc/inter.S
> index c4c6ab1a72..a0bb356f07 100644
> --- a/libavcodec/aarch64/vvc/inter.S
> +++ b/libavcodec/aarch64/vvc/inter.S
> @@ -226,3 +226,310 @@ vvc_avg avg, 12
> vvc_avg w_avg, 8
> vvc_avg w_avg, 10
> vvc_avg w_avg, 12
> +
> +/* x0: int16_t *dst
> + * x1: const uint8_t *_src
> + * x2: const ptrdiff_t _src_stride
> + * w3: const int height
> + * x4: const intptr_t mx
> + * x5: const intptr_t my
> + * w6: const int width

Unnecessary const

> + */
> +function ff_vvc_dmvr_hv_8_neon, export=1
> +        dst             .req x0
> +        src             .req x1
> +        src_stride      .req x2
> +        height          .req w3
> +        mx              .req x4
> +        my              .req x5
> +        width           .req w6
> +        tmp0            .req x7
> +        tmp1            .req x8
> +
> +        sub             sp, sp, #(VVC_MAX_PB_SIZE * 4)
> +
> +        movrel          x9, X(ff_vvc_inter_luma_dmvr_filters)
> +        add             x12, x9, mx, lsl #1
> +        ldrb            w10, [x12]
> +        ldrb            w11, [x12, #1]
> +        mov             tmp0, sp
> +        add             tmp1, tmp0, #(VVC_MAX_PB_SIZE * 2)
> +        // We know the value are positive
> +        dup             v0.8h, w10                  // filter_x[0]
> +        dup             v1.8h, w11                  // filter_x[1]

If we don't need these values in GPRs, we could also just do ld1r, 
although that requires incrementing the pointer (which probably can be 
done with a post-increment, [x12], #1) between the loads. Then again, I 
see you load 8 bits but you want them in 16 bit elements, so that would 
require a separate uxtl. So then I guess this use of GPRs for loading is 
reasonable.

All in all, the patch seems fine, except for the unnecessary consts.

// Martin
diff mbox series

Patch

diff --git a/libavcodec/aarch64/vvc/dsp_init.c b/libavcodec/aarch64/vvc/dsp_init.c
index b39ebb83fc..995e26d163 100644
--- a/libavcodec/aarch64/vvc/dsp_init.c
+++ b/libavcodec/aarch64/vvc/dsp_init.c
@@ -83,6 +83,15 @@  W_AVG_FUN(8)
 W_AVG_FUN(10)
 W_AVG_FUN(12)
 
+#define DMVR_FUN(fn, bd) \
+    void ff_vvc_dmvr_ ## fn ## bd ## _neon(int16_t *dst, \
+        const uint8_t *_src, const ptrdiff_t _src_stride, const int height, \
+        const intptr_t mx, const intptr_t my, const int width);
+
+DMVR_FUN(hv_, 8)
+DMVR_FUN(hv_, 10)
+DMVR_FUN(hv_, 12)
+
 void ff_vvc_dsp_init_aarch64(VVCDSPContext *const c, const int bd)
 {
     int cpu_flags = av_get_cpu_flags();
@@ -155,6 +164,7 @@  void ff_vvc_dsp_init_aarch64(VVCDSPContext *const c, const int bd)
 
         c->inter.avg = ff_vvc_avg_8_neon;
         c->inter.w_avg = vvc_w_avg_8;
+        c->inter.dmvr[1][1] = ff_vvc_dmvr_hv_8_neon;
 
         for (int i = 0; i < FF_ARRAY_ELEMS(c->sao.band_filter); i++)
             c->sao.band_filter[i] = ff_h26x_sao_band_filter_8x8_8_neon;
@@ -196,12 +206,14 @@  void ff_vvc_dsp_init_aarch64(VVCDSPContext *const c, const int bd)
     } else if (bd == 10) {
         c->inter.avg = ff_vvc_avg_10_neon;
         c->inter.w_avg = vvc_w_avg_10;
+        c->inter.dmvr[1][1] = ff_vvc_dmvr_hv_10_neon;
 
         c->alf.filter[LUMA] = alf_filter_luma_10_neon;
         c->alf.filter[CHROMA] = alf_filter_chroma_10_neon;
     } else if (bd == 12) {
         c->inter.avg = ff_vvc_avg_12_neon;
         c->inter.w_avg = vvc_w_avg_12;
+        c->inter.dmvr[1][1] = ff_vvc_dmvr_hv_12_neon;
 
         c->alf.filter[LUMA] = alf_filter_luma_12_neon;
         c->alf.filter[CHROMA] = alf_filter_chroma_12_neon;
diff --git a/libavcodec/aarch64/vvc/inter.S b/libavcodec/aarch64/vvc/inter.S
index c4c6ab1a72..a0bb356f07 100644
--- a/libavcodec/aarch64/vvc/inter.S
+++ b/libavcodec/aarch64/vvc/inter.S
@@ -226,3 +226,310 @@  vvc_avg avg, 12
 vvc_avg w_avg, 8
 vvc_avg w_avg, 10
 vvc_avg w_avg, 12
+
+/* x0: int16_t *dst
+ * x1: const uint8_t *_src
+ * x2: const ptrdiff_t _src_stride
+ * w3: const int height
+ * x4: const intptr_t mx
+ * x5: const intptr_t my
+ * w6: const int width
+ */
+function ff_vvc_dmvr_hv_8_neon, export=1
+        dst             .req x0
+        src             .req x1
+        src_stride      .req x2
+        height          .req w3
+        mx              .req x4
+        my              .req x5
+        width           .req w6
+        tmp0            .req x7
+        tmp1            .req x8
+
+        sub             sp, sp, #(VVC_MAX_PB_SIZE * 4)
+
+        movrel          x9, X(ff_vvc_inter_luma_dmvr_filters)
+        add             x12, x9, mx, lsl #1
+        ldrb            w10, [x12]
+        ldrb            w11, [x12, #1]
+        mov             tmp0, sp
+        add             tmp1, tmp0, #(VVC_MAX_PB_SIZE * 2)
+        // We know the value are positive
+        dup             v0.8h, w10                  // filter_x[0]
+        dup             v1.8h, w11                  // filter_x[1]
+
+        add             x12, x9, my, lsl #1
+        ldrb            w10, [x12]
+        ldrb            w11, [x12, #1]
+        sxtw            x6, w6
+        movi            v30.8h, #(1 << (8 - 7))     // offset1
+        movi            v31.8h, #8                  // offset2
+        dup             v2.8h, w10                  // filter_y[0]
+        dup             v3.8h, w11                  // filter_y[1]
+
+        // Valid value for width can only be 8 + 4, 16 + 4
+        cmp             width, #16
+        mov             w10, #0                     // start filter_y or not
+        add             height, height, #1
+        sub             dst, dst, #(VVC_MAX_PB_SIZE * 2)
+        sub             src_stride, src_stride, x6
+        cset            w15, gt                     // width > 16
+1:
+        mov             x12, tmp0
+        mov             x13, tmp1
+        mov             x14, dst
+        cbz             w15, 2f
+
+        // width > 16
+        ldur            q5, [src, #1]
+        ldr             q4, [src], #16
+        uxtl            v7.8h, v5.8b
+        uxtl2           v17.8h, v5.16b
+        uxtl            v6.8h, v4.8b
+        uxtl2           v16.8h, v4.16b
+        mul             v6.8h, v6.8h, v0.8h
+        mul             v16.8h, v16.8h, v0.8h
+        mla             v6.8h, v7.8h, v1.8h
+        mla             v16.8h, v17.8h, v1.8h
+        add             v6.8h, v6.8h, v30.8h
+        add             v16.8h, v16.8h, v30.8h
+        ushr            v6.8h, v6.8h, #(8 - 6)
+        ushr            v7.8h, v16.8h, #(8 - 6)
+        stp             q6, q7, [x13], #32
+
+        cbz             w10, 3f
+
+        ldp             q16, q17, [x12], #32
+        mul             v16.8h, v16.8h, v2.8h
+        mul             v17.8h, v17.8h, v2.8h
+        mla             v16.8h, v6.8h, v3.8h
+        mla             v17.8h, v7.8h, v3.8h
+        add             v16.8h, v16.8h, v31.8h
+        add             v17.8h, v17.8h, v31.8h
+        ushr            v16.8h, v16.8h, #4
+        ushr            v17.8h, v17.8h, #4
+        stp             q16, q17, [x14], #32
+        b               3f
+2:
+        // width > 8
+        ldur            d5, [src, #1]
+        ldr             d4, [src], #8
+        uxtl            v7.8h, v5.8b
+        uxtl            v6.8h, v4.8b
+        mul             v6.8h, v6.8h, v0.8h
+        mla             v6.8h, v7.8h, v1.8h
+        add             v6.8h, v6.8h, v30.8h
+        ushr            v6.8h, v6.8h, #(8 - 6)
+        str             q6, [x13], #16
+
+        cbz             w10, 3f
+
+        ldr             q16, [x12], #16
+        mul             v16.8h, v16.8h, v2.8h
+        mla             v16.8h, v6.8h, v3.8h
+        add             v16.8h, v16.8h, v31.8h
+        ushr            v16.8h, v16.8h, #4
+        str             q16, [x14], #16
+3:
+        ldr             s5, [src, #1]
+        ldr             s4, [src], #4
+        uxtl            v7.8h, v5.8b
+        uxtl            v6.8h, v4.8b
+        mul             v6.4h, v6.4h, v0.4h
+        mla             v6.4h, v7.4h, v1.4h
+        add             v6.4h, v6.4h, v30.4h
+        ushr            v6.4h, v6.4h, #(8 - 6)
+        str             d6, [x13], #8
+
+        cbz             w10, 4f
+
+        ldr             d16, [x12], #8
+        mul             v16.4h, v16.4h, v2.4h
+        mla             v16.4h, v6.4h, v3.4h
+        add             v16.4h, v16.4h, v31.4h
+        ushr            v16.4h, v16.4h, #4
+        str             d16, [x14], #8
+4:
+        subs            height, height, #1
+        mov             w10, #1
+        add             src, src, src_stride
+        add             dst, dst, #(VVC_MAX_PB_SIZE * 2)
+        eor             tmp0, tmp0, tmp1
+        eor             tmp1, tmp0, tmp1
+        eor             tmp0, tmp0, tmp1
+        b.ne            1b
+
+        add             sp, sp, #(VVC_MAX_PB_SIZE * 4)
+        ret
+endfunc
+
+function ff_vvc_dmvr_hv_12_neon, export=1
+        movi            v29.4s, #(12 - 6)
+        movi            v30.4s, #(1 << (12 - 7))    // offset1
+        b               0f
+endfunc
+
+function ff_vvc_dmvr_hv_10_neon, export=1
+        movi            v29.4s, #(10 - 6)
+        movi            v30.4s, #(1 << (10 - 7))    // offset1
+0:
+        movi            v31.4s, #8                  // offset2
+        neg             v29.4s, v29.4s
+
+        sub             sp, sp, #(VVC_MAX_PB_SIZE * 4)
+
+        movrel          x9, X(ff_vvc_inter_luma_dmvr_filters)
+        add             x12, x9, mx, lsl #1
+        ldrb            w10, [x12]
+        ldrb            w11, [x12, #1]
+        mov             tmp0, sp
+        add             tmp1, tmp0, #(VVC_MAX_PB_SIZE * 2)
+        // We know the value are positive
+        dup             v0.8h, w10                  // filter_x[0]
+        dup             v1.8h, w11                  // filter_x[1]
+
+        add             x12, x9, my, lsl #1
+        ldrb            w10, [x12]
+        ldrb            w11, [x12, #1]
+        sxtw            x6, w6
+        dup             v2.8h, w10                  // filter_y[0]
+        dup             v3.8h, w11                  // filter_y[1]
+
+        // Valid value for width can only be 8 + 4, 16 + 4
+        cmp             width, #16
+        mov             w10, #0                     // start filter_y or not
+        add             height, height, #1
+        sub             dst, dst, #(VVC_MAX_PB_SIZE * 2)
+        sub             src_stride, src_stride, x6, lsl #1
+        cset            w15, gt                     // width > 16
+1:
+        mov             x12, tmp0
+        mov             x13, tmp1
+        mov             x14, dst
+        cbz             w15, 2f
+
+        // width > 16
+        add             x16, src, #2
+        ldp             q6, q16, [src], #32
+        ldp             q7, q17, [x16]
+        umull           v4.4s, v6.4h, v0.4h
+        umull2          v5.4s, v6.8h, v0.8h
+        umull           v18.4s, v16.4h, v0.4h
+        umull2          v19.4s, v16.8h, v0.8h
+        umlal           v4.4s, v7.4h, v1.4h
+        umlal2          v5.4s, v7.8h, v1.8h
+        umlal           v18.4s, v17.4h, v1.4h
+        umlal2          v19.4s, v17.8h, v1.8h
+
+        add             v4.4s, v4.4s, v30.4s
+        add             v5.4s, v5.4s, v30.4s
+        add             v18.4s, v18.4s, v30.4s
+        add             v19.4s, v19.4s, v30.4s
+        ushl            v4.4s, v4.4s, v29.4s
+        ushl            v5.4s, v5.4s, v29.4s
+        ushl            v18.4s, v18.4s, v29.4s
+        ushl            v19.4s, v19.4s, v29.4s
+        uqxtn           v6.4h, v4.4s
+        uqxtn2          v6.8h, v5.4s
+        uqxtn           v7.4h, v18.4s
+        uqxtn2          v7.8h, v19.4s
+        stp             q6, q7, [x13], #32
+
+        cbz             w10, 3f
+
+        ldp             q4, q5, [x12], #32
+        umull           v17.4s, v4.4h, v2.4h
+        umull2          v18.4s, v4.8h, v2.8h
+        umull           v19.4s, v5.4h, v2.4h
+        umull2          v20.4s, v5.8h, v2.8h
+        umlal           v17.4s, v6.4h, v3.4h
+        umlal2          v18.4s, v6.8h, v3.8h
+        umlal           v19.4s, v7.4h, v3.4h
+        umlal2          v20.4s, v7.8h, v3.8h
+        add             v17.4s, v17.4s, v31.4s
+        add             v18.4s, v18.4s, v31.4s
+        add             v19.4s, v19.4s, v31.4s
+        add             v20.4s, v20.4s, v31.4s
+        ushr            v17.4s, v17.4s, #4
+        ushr            v18.4s, v18.4s, #4
+        ushr            v19.4s, v19.4s, #4
+        ushr            v20.4s, v20.4s, #4
+        uqxtn           v6.4h, v17.4s
+        uqxtn2          v6.8h, v18.4s
+        uqxtn           v7.4h, v19.4s
+        uqxtn2          v7.8h, v20.4s
+        stp             q6, q7, [x14], #32
+        b               3f
+2:
+        // width > 8
+        ldur            q7, [src, #2]
+        ldr             q6, [src], #16
+        umull           v4.4s, v6.4h, v0.4h
+        umull2          v5.4s, v6.8h, v0.8h
+        umlal           v4.4s, v7.4h, v1.4h
+        umlal2          v5.4s, v7.8h, v1.8h
+
+        add             v4.4s, v4.4s, v30.4s
+        add             v5.4s, v5.4s, v30.4s
+        ushl            v4.4s, v4.4s, v29.4s
+        ushl            v5.4s, v5.4s, v29.4s
+        uqxtn           v6.4h, v4.4s
+        uqxtn2          v6.8h, v5.4s
+        str             q6, [x13], #16
+
+        cbz             w10, 3f
+
+        ldr             q16, [x12], #16
+        umull           v17.4s, v16.4h, v2.4h
+        umull2          v18.4s, v16.8h, v2.8h
+        umlal           v17.4s, v6.4h, v3.4h
+        umlal2          v18.4s, v6.8h, v3.8h
+        add             v17.4s, v17.4s, v31.4s
+        add             v18.4s, v18.4s, v31.4s
+        ushr            v17.4s, v17.4s, #4
+        ushr            v18.4s, v18.4s, #4
+        uqxtn           v16.4h, v17.4s
+        uqxtn2          v16.8h, v18.4s
+        str             q16, [x14], #16
+3:
+        ldr             d7, [src, #2]
+        ldr             d6, [src], #8
+        umull           v4.4s, v7.4h, v1.4h
+        umlal           v4.4s, v6.4h, v0.4h
+        add             v4.4s, v4.4s, v30.4s
+        ushl            v4.4s, v4.4s, v29.4s
+        uqxtn           v6.4h, v4.4s
+        str             d6, [x13], #8
+
+        cbz             w10, 4f
+
+        ldr             d16, [x12], #8
+        umull           v17.4s, v16.4h, v2.4h
+        umlal           v17.4s, v6.4h, v3.4h
+        add             v17.4s, v17.4s, v31.4s
+        ushr            v17.4s, v17.4s, #4
+        uqxtn           v16.4h, v17.4s
+        str             d16, [x14], #8
+4:
+        subs            height, height, #1
+        mov             w10, #1
+        add             src, src, src_stride
+        add             dst, dst, #(VVC_MAX_PB_SIZE * 2)
+        eor             tmp0, tmp0, tmp1
+        eor             tmp1, tmp0, tmp1
+        eor             tmp0, tmp0, tmp1
+        b.ne            1b
+
+        add             sp, sp, #(VVC_MAX_PB_SIZE * 4)
+        ret
+
+.unreq dst
+.unreq src
+.unreq src_stride
+.unreq height
+.unreq mx
+.unreq my
+.unreq width
+.unreq tmp0
+.unreq tmp1
+endfunc