diff mbox series

[FFmpeg-devel,v3,3/3] swscale/aarch64: add vscale specializations

Message ID 70629b7632564b30a44c71bf6a903b26@amazon.com
State Accepted
Headers show
Series checkasm: updated tests for sw_scale | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Swinney, Jonathan Aug. 13, 2022, 8:56 p.m. UTC
This commit adds new code paths for vscale when filterSize is 2, 4, or
8. By using specialized code with unrolling to match the filterSize we
can improve performance.

On AWS c7g (Graviton 3, Neoverse V1) instances:
                                 before   after
yuv2yuvX_2_0_512_accurate_neon:  558.8    268.9
yuv2yuvX_4_0_512_accurate_neon:  637.5    434.9
yuv2yuvX_8_0_512_accurate_neon:  1144.8   806.2
yuv2yuvX_16_0_512_accurate_neon: 2080.5   1853.7

Signed-off-by: Jonathan Swinney <jswinney@amazon.com>
---
 libswscale/aarch64/output.S  | 177 +++++++++++++++++++++++++++++++++++
 libswscale/aarch64/swscale.c |  12 +++
 2 files changed, 189 insertions(+)
diff mbox series

Patch

diff --git a/libswscale/aarch64/output.S b/libswscale/aarch64/output.S
index 991750cf31..b8a2818c9b 100644
--- a/libswscale/aarch64/output.S
+++ b/libswscale/aarch64/output.S
@@ -21,13 +21,33 @@ 
 #include "libavutil/aarch64/asm.S"
 
 function ff_yuv2planeX_8_neon, export=1
+// x0 - const int16_t *filter,
+// x1 - int filterSize,
+// x2 - const int16_t **src,
+// x3 - uint8_t *dest,
+// w4 - int dstW,
+// x5 - const uint8_t *dither,
+// w6 - int offset
+
         ld1                 {v0.8B}, [x5]                   // load 8x8-bit dither
+        and                 w6, w6, #7
         cbz                 w6, 1f                          // check if offsetting present
         ext                 v0.8B, v0.8B, v0.8B, #3         // honor offsetting which can be 0 or 3 only
 1:      uxtl                v0.8H, v0.8B                    // extend dither to 16-bit
         ushll               v1.4S, v0.4H, #12               // extend dither to 32-bit with left shift by 12 (part 1)
         ushll2              v2.4S, v0.8H, #12               // extend dither to 32-bit with left shift by 12 (part 2)
+        cmp                 w1, #8                          // if filterSize == 8, branch to specialized version
+        b.eq                6f
+        cmp                 w1, #4                          // if filterSize == 4, branch to specialized version
+        b.eq                8f
+        cmp                 w1, #2                          // if filterSize == 2, branch to specialized version
+        b.eq                10f
+
+// The filter size does not match of the of specialized implementations. It is either even or odd. If it is even
+// then use the first section below.
         mov                 x7, #0                          // i = 0
+        tbnz                w1, #0, 4f                      // if filterSize % 2 != 0 branch to specialized version
+// fs % 2 == 0
 2:      mov                 v3.16B, v1.16B                  // initialize accumulator part 1 with dithering value
         mov                 v4.16B, v2.16B                  // initialize accumulator part 2 with dithering value
         mov                 w8, w1                          // tmpfilterSize = filterSize
@@ -54,4 +74,161 @@  function ff_yuv2planeX_8_neon, export=1
         add                 x7, x7, #8                      // i += 8
         b.gt                2b                              // loop until width consumed
         ret
+
+// If filter size is odd (most likely == 1), then use this section.
+// fs % 2 != 0
+4:      mov                 v3.16B, v1.16B                  // initialize accumulator part 1 with dithering value
+        mov                 v4.16B, v2.16B                  // initialize accumulator part 2 with dithering value
+        mov                 w8, w1                          // tmpfilterSize = filterSize
+        mov                 x9, x2                          // srcp    = src
+        mov                 x10, x0                         // filterp = filter
+5:      ldr                 x11, [x9], #8                   // get 1 pointer: src[j]
+        ldr                 h6, [x10], #2                   // read 1 16 bit coeff X at filter[j]
+        add                 x11, x11, x7, lsl #1            // &src[j  ][i]
+        ld1                 {v5.8H}, [x11]                  // read 8x16-bit @ src[j  ][i + {0..7}]: A,B,C,D,E,F,G,H
+        smlal               v3.4S, v5.4H, v6.H[0]           // val0 += {A,B,C,D} * X
+        smlal2              v4.4S, v5.8H, v6.H[0]           // val1 += {E,F,G,H} * X
+        subs                w8, w8, #1                      // tmpfilterSize -= 2
+        b.gt                5b                              // loop until filterSize consumed
+
+        sqshrun             v3.4h, v3.4s, #16               // clip16(val0>>16)
+        sqshrun2            v3.8h, v4.4s, #16               // clip16(val1>>16)
+        uqshrn              v3.8b, v3.8h, #3                // clip8(val>>19)
+        st1                 {v3.8b}, [x3], #8               // write to destination
+        subs                w4, w4, #8                      // dstW -= 8
+        add                 x7, x7, #8                      // i += 8
+        b.gt                4b                              // loop until width consumed
+        ret
+
+6:      // fs=8
+        ldp                 x5, x6, [x2]                    // load 2 pointers: src[j  ] and src[j+1]
+        ldp                 x7, x9, [x2, #16]               // load 2 pointers: src[j+2] and src[j+3]
+        ldp                 x10, x11, [x2, #32]             // load 2 pointers: src[j+4] and src[j+5]
+        ldp                 x12, x13, [x2, #48]             // load 2 pointers: src[j+6] and src[j+7]
+
+        // load 8x16-bit values for filter[j], where j=0..7
+        ld1                 {v6.8H}, [x0]
+7:
+        mov                 v3.16B, v1.16B                  // initialize accumulator part 1 with dithering value
+        mov                 v4.16B, v2.16B                  // initialize accumulator part 2 with dithering value
+
+        ld1                 {v24.8H}, [x5], #16             // load 8x16-bit values for src[j + 0][i + {0..7}]
+        ld1                 {v25.8H}, [x6], #16             // load 8x16-bit values for src[j + 1][i + {0..7}]
+        ld1                 {v26.8H}, [x7], #16             // load 8x16-bit values for src[j + 2][i + {0..7}]
+        ld1                 {v27.8H}, [x9], #16             // load 8x16-bit values for src[j + 3][i + {0..7}]
+        ld1                 {v28.8H}, [x10], #16            // load 8x16-bit values for src[j + 4][i + {0..7}]
+        ld1                 {v29.8H}, [x11], #16            // load 8x16-bit values for src[j + 5][i + {0..7}]
+        ld1                 {v30.8H}, [x12], #16            // load 8x16-bit values for src[j + 6][i + {0..7}]
+        ld1                 {v31.8H}, [x13], #16            // load 8x16-bit values for src[j + 7][i + {0..7}]
+
+        smlal               v3.4S, v24.4H, v6.H[0]          // val0 += src[0][i + {0..3}] * filter[0]
+        smlal2              v4.4S, v24.8H, v6.H[0]          // val1 += src[0][i + {4..7}] * filter[0]
+        smlal               v3.4S, v25.4H, v6.H[1]          // val0 += src[1][i + {0..3}] * filter[1]
+        smlal2              v4.4S, v25.8H, v6.H[1]          // val1 += src[1][i + {4..7}] * filter[1]
+        smlal               v3.4S, v26.4H, v6.H[2]          // val0 += src[2][i + {0..3}] * filter[2]
+        smlal2              v4.4S, v26.8H, v6.H[2]          // val1 += src[2][i + {4..7}] * filter[2]
+        smlal               v3.4S, v27.4H, v6.H[3]          // val0 += src[3][i + {0..3}] * filter[3]
+        smlal2              v4.4S, v27.8H, v6.H[3]          // val1 += src[3][i + {4..7}] * filter[3]
+        smlal               v3.4S, v28.4H, v6.H[4]          // val0 += src[4][i + {0..3}] * filter[4]
+        smlal2              v4.4S, v28.8H, v6.H[4]          // val1 += src[4][i + {4..7}] * filter[4]
+        smlal               v3.4S, v29.4H, v6.H[5]          // val0 += src[5][i + {0..3}] * filter[5]
+        smlal2              v4.4S, v29.8H, v6.H[5]          // val1 += src[5][i + {4..7}] * filter[5]
+        smlal               v3.4S, v30.4H, v6.H[6]          // val0 += src[6][i + {0..3}] * filter[6]
+        smlal2              v4.4S, v30.8H, v6.H[6]          // val1 += src[6][i + {4..7}] * filter[6]
+        smlal               v3.4S, v31.4H, v6.H[7]          // val0 += src[7][i + {0..3}] * filter[7]
+        smlal2              v4.4S, v31.8H, v6.H[7]          // val1 += src[7][i + {4..7}] * filter[7]
+
+        sqshrun             v3.4h, v3.4s, #16               // clip16(val0>>16)
+        sqshrun2            v3.8h, v4.4s, #16               // clip16(val1>>16)
+        uqshrn              v3.8b, v3.8h, #3                // clip8(val>>19)
+        subs                w4, w4, #8                      // dstW -= 8
+        st1                 {v3.8b}, [x3], #8               // write to destination
+        b.gt                7b                              // loop until width consumed
+        ret
+
+8:      // fs=4
+        ldp                 x5, x6, [x2]                    // load 2 pointers: src[j  ] and src[j+1]
+        ldp                 x7, x9, [x2, #16]               // load 2 pointers: src[j+2] and src[j+3]
+
+        // load 4x16-bit values for filter[j], where j=0..3 and replicated across lanes
+        ld1                 {v6.4H}, [x0]
+9:
+        mov                 v3.16B, v1.16B                  // initialize accumulator part 1 with dithering value
+        mov                 v4.16B, v2.16B                  // initialize accumulator part 2 with dithering value
+
+        ld1                 {v24.8H}, [x5], #16             // load 8x16-bit values for src[j + 0][i + {0..7}]
+        ld1                 {v25.8H}, [x6], #16             // load 8x16-bit values for src[j + 1][i + {0..7}]
+        ld1                 {v26.8H}, [x7], #16             // load 8x16-bit values for src[j + 2][i + {0..7}]
+        ld1                 {v27.8H}, [x9], #16             // load 8x16-bit values for src[j + 3][i + {0..7}]
+
+        smlal               v3.4S, v24.4H, v6.H[0]          // val0 += src[0][i + {0..3}] * filter[0]
+        smlal2              v4.4S, v24.8H, v6.H[0]          // val1 += src[0][i + {4..7}] * filter[0]
+        smlal               v3.4S, v25.4H, v6.H[1]          // val0 += src[1][i + {0..3}] * filter[1]
+        smlal2              v4.4S, v25.8H, v6.H[1]          // val1 += src[1][i + {4..7}] * filter[1]
+        smlal               v3.4S, v26.4H, v6.H[2]          // val0 += src[2][i + {0..3}] * filter[2]
+        smlal2              v4.4S, v26.8H, v6.H[2]          // val1 += src[2][i + {4..7}] * filter[2]
+        smlal               v3.4S, v27.4H, v6.H[3]          // val0 += src[3][i + {0..3}] * filter[3]
+        smlal2              v4.4S, v27.8H, v6.H[3]          // val1 += src[3][i + {4..7}] * filter[3]
+
+        sqshrun             v3.4h, v3.4s, #16               // clip16(val0>>16)
+        sqshrun2            v3.8h, v4.4s, #16               // clip16(val1>>16)
+        uqshrn              v3.8b, v3.8h, #3                // clip8(val>>19)
+        st1                 {v3.8b}, [x3], #8               // write to destination
+        subs                w4, w4, #8                      // dstW -= 8
+        b.gt                9b                              // loop until width consumed
+        ret
+
+10:     // fs=2
+        ldp                 x5, x6, [x2]                    // load 2 pointers: src[j  ] and src[j+1]
+
+        // load 2x16-bit values for filter[j], where j=0..1 and replicated across lanes
+        ldr                 s6, [x0]
+11:
+        mov                 v3.16B, v1.16B                  // initialize accumulator part 1 with dithering value
+        mov                 v4.16B, v2.16B                  // initialize accumulator part 2 with dithering value
+
+        ld1                 {v24.8H}, [x5], #16             // load 8x16-bit values for src[j + 0][i + {0..7}]
+        ld1                 {v25.8H}, [x6], #16             // load 8x16-bit values for src[j + 1][i + {0..7}]
+
+        smlal               v3.4S, v24.4H, v6.H[0]          // val0 += src[0][i + {0..3}] * filter[0]
+        smlal2              v4.4S, v24.8H, v6.H[0]          // val1 += src[0][i + {4..7}] * filter[0]
+        smlal               v3.4S, v25.4H, v6.H[1]          // val0 += src[1][i + {0..3}] * filter[1]
+        smlal2              v4.4S, v25.8H, v6.H[1]          // val1 += src[1][i + {4..7}] * filter[1]
+
+        sqshrun             v3.4h, v3.4s, #16               // clip16(val0>>16)
+        sqshrun2            v3.8h, v4.4s, #16               // clip16(val1>>16)
+        uqshrn              v3.8b, v3.8h, #3                // clip8(val>>19)
+        st1                 {v3.8b}, [x3], #8               // write to destination
+        subs                w4, w4, #8                      // dstW -= 8
+        b.gt                11b                             // loop until width consumed
+        ret
+endfunc
+
+function ff_yuv2plane1_8_neon, export=1
+// x0 - const int16_t *src,
+// x1 - uint8_t *dest,
+// w2 - int dstW,
+// x3 - const uint8_t *dither,
+// w4 - int offset
+        ld1                 {v0.8B}, [x3]                   // load 8x8-bit dither
+        and                 w4, w4, #7
+        cbz                 w4, 1f                          // check if offsetting present
+        ext                 v0.8B, v0.8B, v0.8B, #3         // honor offsetting which can be 0 or 3 only
+1:      uxtl                v0.8H, v0.8B                    // extend dither to 32-bit
+        uxtl                v1.4s, v0.4h
+        uxtl2               v2.4s, v0.8h
+2:
+        ld1                 {v3.8h}, [x0], #16              // read 8x16-bit @ src[j  ][i + {0..7}]: A,B,C,D,E,F,G,H
+        sxtl                v4.4s, v3.4h
+        sxtl2               v5.4s, v3.8h
+        add                 v4.4s, v4.4s, v1.4s
+        add                 v5.4s, v5.4s, v2.4s
+        sqshrun             v4.4h, v4.4s, #6
+        sqshrun2            v4.8h, v5.4s, #6
+
+        uqshrn              v3.8b, v4.8h, #1                // clip8(val>>7)
+        subs                w2, w2, #8                      // dstW -= 8
+        st1                 {v3.8b}, [x1], #8               // write to destination
+        b.gt                2b                              // loop until width consumed
+        ret
 endfunc
diff --git a/libswscale/aarch64/swscale.c b/libswscale/aarch64/swscale.c
index ab28be4da6..321d1f844e 100644
--- a/libswscale/aarch64/swscale.c
+++ b/libswscale/aarch64/swscale.c
@@ -39,6 +39,12 @@  ALL_SCALE_FUNCS(neon);
 void ff_yuv2planeX_8_neon(const int16_t *filter, int filterSize,
                           const int16_t **src, uint8_t *dest, int dstW,
                           const uint8_t *dither, int offset);
+void ff_yuv2plane1_8_neon(
+        const int16_t *src,
+        uint8_t *dest,
+        int dstW,
+        const uint8_t *dither,
+        int offset);
 
 #define ASSIGN_SCALE_FUNC2(hscalefn, filtersize, opt) do {              \
     if (c->srcBpc == 8 && c->dstBpc <= 14) {                            \
@@ -54,6 +60,11 @@  void ff_yuv2planeX_8_neon(const int16_t *filter, int filterSize,
                ASSIGN_SCALE_FUNC2(hscalefn, X8, opt);                   \
            break;                                                       \
   }
+#define ASSIGN_VSCALE_FUNC(vscalefn, opt)                               \
+    switch (c->dstBpc) {                                                \
+    case 8: vscalefn = ff_yuv2plane1_8_  ## opt;  break;                \
+    default: break;                                                     \
+    }
 
 av_cold void ff_sws_init_swscale_aarch64(SwsContext *c)
 {
@@ -62,6 +73,7 @@  av_cold void ff_sws_init_swscale_aarch64(SwsContext *c)
     if (have_neon(cpu_flags)) {
         ASSIGN_SCALE_FUNC(c->hyScale, c->hLumFilterSize, neon);
         ASSIGN_SCALE_FUNC(c->hcScale, c->hChrFilterSize, neon);
+        ASSIGN_VSCALE_FUNC(c->yuv2plane1, neon);
         if (c->dstBpc == 8) {
             c->yuv2planeX = ff_yuv2planeX_8_neon;
         }