
[FFmpeg-devel,1/2] swscale/aarch64: add hscale specializations

Message ID 199f4223693645eda3fdd8257c2e6355@EX13D07UWB004.ant.amazon.com
State Superseded
Series [FFmpeg-devel,1/2] swscale/aarch64: add hscale specializations

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

Swinney, Jonathan April 15, 2022, 9:36 p.m. UTC
This patch adds specializations of hscale for filterSize == 4 and 8 and
converts the existing implementation into the X8 version, used for any
filterSize that is a multiple of 8. For that old code, now used as the X8
version, it improves the efficiency of the final summations by reducing them
from 11 instructions to 7.
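
For reference, the scalar loop that all of these functions implement is
essentially the following (a simplified sketch of the C code path, cf.
hScale8To15_c in libswscale/swscale.c; not the exact source):

    #include <stdint.h>

    // Simplified sketch of the 8-bit to 15-bit horizontal scaler that the
    // NEON functions specialize (illustrative only).
    static void hscale8to15_ref(int16_t *dst, int dstW, const uint8_t *src,
                                const int16_t *filter,
                                const int32_t *filterPos, int filterSize)
    {
        for (int i = 0; i < dstW; i++) {
            int val = 0;
            for (int j = 0; j < filterSize; j++)
                val += (int)src[filterPos[i] + j] * filter[filterSize * i + j];
            // the filter is 14-bit: drop 7 fractional bits, clip to 15 bits
            int v = val >> 7;
            dst[i] = v > 0x7FFF ? 0x7FFF : v;
        }
    }

The filterSize == 4 and == 8 specializations fully unroll the inner loop over
j; the X8 version keeps it as a loop taken filterSize/8 times.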

ff_hscale8to15_8_neon is mostly unchanged from the original, with a few
exceptions:
 - The loads of the filter data were consolidated into a single 64-byte ld1
   instruction.
 - The final summations were improved.
 - The inner loop over filterSize was removed entirely.

ff_hscale8to15_4_neon is a complete rewrite. Since the main bottleneck here is
loading the data from src, this data is loaded a whole block ahead and stored
back to the stack to be loaded again with ld4. This arranges the data for the
most efficient use of the vector instructions and removes the need for
completion adds at the end. The number of C-loop iterations handled per
iteration of the assembly is increased from 4 to 8, but because of the
prefetching, the main loop can only be used when dstW >= 16.
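
In C terms, one main-loop iteration of the 4-tap version computes roughly the
following for a block of 8 output pixels. This is an illustrative sketch only
(names are made up); in the assembly the transpose is performed by storing the
gathered words to the stack and reloading them with ld4:

    #include <stdint.h>

    static void hscale4_block_of_8(int16_t *dst, const uint8_t *src,
                                   const int16_t *filter,   /* 8 * 4 coeffs */
                                   const int32_t *filterPos)
    {
        // gather: 4 contiguous source bytes per output pixel, 8 pixels at once
        uint8_t gathered[8][4];
        for (int i = 0; i < 8; i++)
            for (int j = 0; j < 4; j++)
                gathered[i][j] = src[filterPos[i] + j];

        // after the transpose, "column" j holds src[filterPos[i] + j] for all
        // 8 pixels, so each filter tap j can be applied to the whole block
        // with vector multiply-accumulates and no horizontal completion adds
        int32_t acc[8] = { 0 };
        for (int j = 0; j < 4; j++)
            for (int i = 0; i < 8; i++)
                acc[i] += (int32_t)gathered[i][j] * filter[4 * i + j];

        for (int i = 0; i < 8; i++) {
            int v = acc[i] >> 7;
            dst[i] = v > 0x7FFF ? 0x7FFF : v;   // sqshrn also saturates negatives
        }
    }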

This improves speed by 26% on Graviton 2 (Neoverse N1):
ffmpeg -nostats -f lavfi -i testsrc2=4k:d=2 -vf bench=start,scale=1024x1024,bench=stop -f null -
before: t:0.001796 avg:0.001839 max:0.002756 min:0.001733
after:  t:0.001690 avg:0.001352 max:0.002171 min:0.001292

In direct microbenchmarks I wrote, the benefit is more dramatic when filterSize == 4.

| c6g (seconds) | filterSize 4 | filterSize 8 |
| ------------- | ------------ | ------------ |
| original      | 7.554        | 7.621        |
| optimized     | 3.736        | 7.054        |
| improvement   | 50.5%        | 7.44%        |
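
Those microbenchmarks are not part of this patch; a minimal sketch of the kind
of standalone harness that could produce numbers like the above (all names,
sizes and iteration counts below are illustrative assumptions, and it has to
be linked against the object built from libswscale/aarch64/hscale.S) might
look like this:

    #include <stdint.h>
    #include <stdio.h>
    #include <stdlib.h>
    #include <time.h>

    // the first argument (SwsContext *) is unused by the assembly, so pass NULL
    void ff_hscale8to15_4_neon(void *c, int16_t *dst, int dstW,
                               const uint8_t *src, const int16_t *filter,
                               const int32_t *filterPos, int filterSize);

    int main(void)
    {
        enum { DSTW = 512, FILTER_SIZE = 4, ITERS = 1000000 };
        static uint8_t src[2 * DSTW + 8];
        static int16_t dst[DSTW];
        static int16_t filter[DSTW * FILTER_SIZE];
        static int32_t filterPos[DSTW];

        for (int i = 0; i < DSTW; i++) {
            filterPos[i] = i;                           // simple unit stride
            for (int j = 0; j < FILTER_SIZE; j++)
                filter[i * FILTER_SIZE + j] = 1 << 12;  // taps sum to 1 << 14
        }
        for (size_t i = 0; i < sizeof(src); i++)
            src[i] = (uint8_t)rand();

        struct timespec t0, t1;
        clock_gettime(CLOCK_MONOTONIC, &t0);
        for (int k = 0; k < ITERS; k++)
            ff_hscale8to15_4_neon(NULL, dst, DSTW, src, filter, filterPos,
                                  FILTER_SIZE);
        clock_gettime(CLOCK_MONOTONIC, &t1);

        printf("%.3f s\n", (t1.tv_sec - t0.tv_sec) +
                           (t1.tv_nsec - t0.tv_nsec) / 1e9);
        return 0;
    }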

Signed-off-by: Jonathan Swinney <jswinney@amazon.com>
---
 libswscale/aarch64/hscale.S  | 263 +++++++++++++++++++++++++++++++++--
 libswscale/aarch64/swscale.c |  41 ++++--
 libswscale/utils.c           |   2 +-
 3 files changed, 284 insertions(+), 22 deletions(-)

-- 
2.32.0

Comments

Martin Storsjö April 16, 2022, 9:22 p.m. UTC | #1
On Fri, 15 Apr 2022, Swinney, Jonathan wrote:

> This patch adds specializations for hscale for filterSize == 4 and 8 and
> converts the existing implementation for the X8 version. For the old code, now
> used for the X8 version, it improves the efficiency of the final summations by
> reducing 11 instructions to 7.
>
> ff_hscale8to15_8_neon is mostly unchanged from the original except for a few
> changes.
> - The loads for the filter data were consolidated into a single 64 byte ld1
>   instruction.

Couldn't you do this optimization on the existing function too?

> - The final summations were improved.
> - The inner loop on filterSize was completely removed

I presume that this is the only differing factor which affects whether 
it's worthwhile to keep a separate width=8 function or not. At least from 
the checkasm benchmark numbers, the difference is notable but not huge (in 
the range of 4-10%, while the summation improvements gain even more).

Given a fully optimized function that has an inner loop (which is only 
taken once for the width=8 case), is the separate function without an 
inner loop really necessary?

>
> ff_hscale8to15_4_neon is a complete rewrite. Since the main bottleneck here is
> loading the data from src, this data is loaded a whole block ahead and stored
> back to the stack to be loaded again with ld4. This arranges the data for most
> efficient use of the vector instructions and removes the need for completion
> adds at the end. The number of iterations of the C per iteration of the assembly
> is increased from 4 to 8, but because of the prefetching, it can only be used
> when dstW is >= 16.
>
> This improves speed by 26% on Graviton 2 (Neoverse N1)
> ffmpeg -nostats -f lavfi -i testsrc2=4k:d=2 -vf bench=start,scale=1024x1024,bench=stop -f null -
> before: t:0.001796 avg:0.001839 max:0.002756 min:0.001733
> after:  t:0.001690 avg:0.001352 max:0.002171 min:0.001292
>
> In direct micro benchmarks I wrote the benefit is more dramatic when filterSize == 4.
>
> | (seconds)   | c6g   |       |
> | ----------- | ----- | ----- |
> | filterSize  | 4     | 8     |
> | original    | 7.554 | 7.621 |
> | optimized   | 3.736 | 7.054 |
> | improvement | 50.5% | 7.44% |

This function does already have a checkasm test, so it'd be useful to 
include those numbers too!

FWIW that test (runnable with "checkasm --test=sw_scale", benchmarkable 
with "checkasm --bench=hscale --test=sw_scale") is a bit lacking - it only 
ever tests with one dstW, 512. As this patch adds special handling of 
smaller dstW, it'd probably be good to e.g. randomly test a couple of widths 
in the range 1-512, or test one width < 16 and a couple of bigger ones. Or 
just exhaustively test the full range, although that's quite uncommon among 
tests.

Additionally, that testcase only tests with a steady 1 pixel stride in 
filterPos for each destination pixel - I guess it'd be good for testcase 
coverage to actually test more realistic scaling scenarios too.

So if you'd be able to improve that testcase while at it, that'd be very 
much appreciated!
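
To be concrete, the kind of parameter setup meant here could look roughly
like the following (an illustrative sketch only; the real test lives in
tests/checkasm/sw_scale.c and the names below are made up):

    #include <stdint.h>
    #include <stdlib.h>

    static void pick_hscale_params(int *dstW, int32_t *filterPos, int srcW)
    {
        // one width below 16 plus a few larger ones, instead of only 512;
        // all multiples of 4, since the width-8/X8 loops store 4 output
        // pixels per iteration
        static const int widths[] = { 8, 12, 96, 128, 512 };
        *dstW = widths[rand() % (int)(sizeof(widths) / sizeof(widths[0]))];

        // emulate a real downscale: step filterPos by roughly srcW/dstW per
        // output pixel instead of a steady 1-pixel stride (the src buffer
        // then needs filterSize bytes of headroom past filterPos[dstW - 1])
        for (int i = 0; i < *dstW; i++)
            filterPos[i] = (int32_t)(((int64_t)i * srcW) / *dstW);
    }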

>
> Signed-off-by: Jonathan Swinney <jswinney@amazon.com>
> ---
> libswscale/aarch64/hscale.S  | 263 +++++++++++++++++++++++++++++++++--
> libswscale/aarch64/swscale.c |  41 ++++--
> libswscale/utils.c           |   2 +-
> 3 files changed, 284 insertions(+), 22 deletions(-)
>
> diff --git a/libswscale/aarch64/hscale.S b/libswscale/aarch64/hscale.S
> index af55ffe2b7..a934653a46 100644
> --- a/libswscale/aarch64/hscale.S
> +++ b/libswscale/aarch64/hscale.S
> @@ -1,5 +1,7 @@
> /*
>  * Copyright (c) 2016 Clément Bœsch <clement stupeflix.com>
> + * Copyright (c) 2019-2021 Sebastian Pop <spop@amazon.com>
> + * Copyright (c) 2022 Jonathan Swinney <jswinney@amazon.com>
>  *
>  * This file is part of FFmpeg.
>  *
> @@ -20,7 +22,25 @@
>
> #include "libavutil/aarch64/asm.S"
>
> -function ff_hscale_8_to_15_neon, export=1
> +/*
> +;-----------------------------------------------------------------------------
> +; horizontal line scaling
> +;
> +; void hscale<source_width>to<intermediate_nbits>_<filterSize>_<opt>
> +;                               (SwsContext *c, int{16,32}_t *dst,
> +;                                int dstW, const uint{8,16}_t *src,
> +;                                const int16_t *filter,
> +;                                const int32_t *filterPos, int filterSize);
> +;
> +; Scale one horizontal line. Input is either 8-bit width or 16-bit width
> +; ($source_width can be either 8, 9, 10 or 16, difference is whether we have to
> +; downscale before multiplying). Filter is 14 bits. Output is either 15 bits
> +; (in int16_t) or 19 bits (in int32_t), as given in $intermediate_nbits. Each
> +; output pixel is generated from $filterSize input pixels, the position of
> +; the first pixel is given in filterPos[nOutputPixel].
> +;----------------------------------------------------------------------------- */
> +
> +function ff_hscale8to15_X8_neon, export=1
>         sbfiz               x7, x6, #1, #32             // filterSize*2 (*2 because int16)
> 1:      ldr                 w8, [x5], #4                // filterPos[idx]
>         ldr                 w0, [x5], #4                // filterPos[idx + 1]
> @@ -61,20 +81,239 @@ function ff_hscale_8_to_15_neon, export=1
>         smlal               v3.4S, v18.4H, v19.4H       // v3 accumulates srcp[filterPos[3] + {0..3}] * filter[{0..3}]
>         smlal2              v3.4S, v18.8H, v19.8H       // v3 accumulates srcp[filterPos[3] + {4..7}] * filter[{4..7}]
>         b.gt                2b                          // inner loop if filterSize not consumed completely
> -        addp                v0.4S, v0.4S, v0.4S         // part0 horizontal pair adding
> -        addp                v1.4S, v1.4S, v1.4S         // part1 horizontal pair adding
> -        addp                v2.4S, v2.4S, v2.4S         // part2 horizontal pair adding
> -        addp                v3.4S, v3.4S, v3.4S         // part3 horizontal pair adding
> -        addp                v0.4S, v0.4S, v0.4S         // part0 horizontal pair adding
> -        addp                v1.4S, v1.4S, v1.4S         // part1 horizontal pair adding
> -        addp                v2.4S, v2.4S, v2.4S         // part2 horizontal pair adding
> -        addp                v3.4S, v3.4S, v3.4S         // part3 horizontal pair adding
> -        zip1                v0.4S, v0.4S, v1.4S         // part01 = zip values from part0 and part1
> -        zip1                v2.4S, v2.4S, v3.4S         // part23 = zip values from part2 and part3
> -        mov                 v0.d[1], v2.d[0]            // part0123 = zip values from part01 and part23
> +        uzp1                v4.4S, v0.4S, v1.4S         // unzip low parts 0 and 1
> +        uzp2                v5.4S, v0.4S, v1.4S         // unzip high parts 0 and 1
> +        uzp1                v6.4S, v2.4S, v3.4S         // unzip low parts 2 and 3
> +        uzp2                v7.4S, v2.4S, v3.4S         // unzip high parts 2 and 3
> +        add                 v16.4S, v4.4S, v5.4S        // add half of each of part 0 and 1
> +        add                 v17.4S, v6.4S, v7.4S        // add half of each of part 2 and 3
> +        addp                v0.4S, v16.4S, v17.4S       // pairwise add to complete half adds in earlier steps

This change in itself makes it better, but unless I'm missing something, 
it can be simplified even further.

With the current checkasm test, the original function gives the following 
runtimes (with the raw cycle counter registers, i.e. configured with 
--disable-linux-perf):

                             Cortex A53      A72       A73
hscale_8_to_15_width8_neon:     8273.0   4604.0    4271.7

With your suggested modification, it's sped up to this:

hscale_8_to_15_width8_neon:     8017.0   4420.0   3684.0

But your 7 instructions can be replaced with just these three 
instructions:

         addp                v0.4S, v0.4S, v1.4S         // part0 half-sum x2, part1 half-sum x2
         addp                v1.4S, v2.4S, v3.4S         // part2 half-sum x2, part3 half-sum x2
         addp                v0.4S, v0.4S, v1.4S         // sums of part0-3

With that in place, I get these benchmark numbers:

hscale_8_to_15_width8_neon:     7633.0   3980.5   3348.0

Which is yet quite a fair bit faster.


>         subs                w2, w2, #4                  // dstW -= 4
>         sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
>         st1                 {v0.4H}, [x1], #8           // write to destination part0123
>         b.gt                1b                          // loop until end of line
>         ret
> endfunc
> +
> +
> +function ff_hscale8to15_8_neon, export=1
> +// x0      SwsContext *c (not used)
> +// x1      int16_t *dst
> +// x2      int dstW
> +// x3      const uint8_t *src
> +// x4      const int16_t *filter
> +// x5      const int32_t *filterPos
> +// x6      int filterSize
> +// x8-x11  filterPos values
> +
> +// v0-v3   multiply add accumulators
> +// v4-v7   filter data, temp for final horizontal sum
> +// v16-v19 src data
> +1:
> +        ld1                 {v4.8H, v5.8H, v6.8H, v7.8H}, [x4], #64 // load filter[idx=0..3, j=0..7]
> +        ldp                 w8, w9,  [x5]               // filterPos[idx + 0], [idx + 1]
> +        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2], [idx + 3]
> +        movi                v0.2D, #0                   // val sum part 1 (for dst[0])
> +        movi                v1.2D, #0                   // val sum part 2 (for dst[1])
> +        add                 x5, x5, #16                 // increment filterPos
> +
> +        add                 x8, x3, w8, UXTW            // srcp + filterPos[0]
> +        add                 x9,  x3, w9, UXTW           // srcp + filterPos[1]
> +        add                 x10, x3, w10, UXTW          // srcp + filterPos[2]
> +        add                 x11, x3, w11, UXTW          // srcp + filterPos[3]
> +
> +        ld1                 {v16.8B}, [x8], #8          // srcp[filterPos[0] + {0..7}]
> +        ld1                 {v17.8B}, [x9], #8          // srcp[filterPos[1] + {0..7}]
> +
> +        movi                v2.2D, #0                   // val sum part 3 (for dst[2])
> +        movi                v3.2D, #0                   // val sum part 4 (for dst[3])
> +
> +        uxtl                v16.8H, v16.8B              // unpack part 1 to 16-bit
> +        uxtl                v17.8H, v17.8B              // unpack part 2 to 16-bit
> +
> +        smlal               v0.4S, v16.4H, v4.4H        // v0 accumulates srcp[filterPos[0] + {0..3}] * filter[{0..3}]
> +        smlal               v1.4S, v17.4H, v5.4H        // v1 accumulates srcp[filterPos[1] + {0..3}] * filter[{0..3}]
> +
> +        ld1                 {v18.8B}, [x10], #8         // srcp[filterPos[2] + {0..7}]
> +        ld1                 {v19.8B}, [x11], #8         // srcp[filterPos[3] + {0..7}]
> +
> +        smlal2              v0.4S, v16.8H, v4.8H        // v0 accumulates srcp[filterPos[0] + {4..7}] * filter[{4..7}]
> +        smlal2              v1.4S, v17.8H, v5.8H        // v1 accumulates srcp[filterPos[1] + {4..7}] * filter[{4..7}]
> +
> +        uxtl                v18.8H, v18.8B              // unpack part 3 to 16-bit
> +        uxtl                v19.8H, v19.8B              // unpack part 4 to 16-bit
> +
> +        smlal               v2.4S, v18.4H, v6.4H        // v2 accumulates srcp[filterPos[2] + {0..3}] * filter[{0..3}]
> +        smlal               v3.4S, v19.4H, v7.4H        // v3 accumulates srcp[filterPos[3] + {0..3}] * filter[{0..3}]
> +
> +        smlal2              v2.4S, v18.8H, v6.8H        // v2 accumulates srcp[filterPos[2] + {4..7}] * filter[{4..7}]
> +        smlal2              v3.4S, v19.8H, v7.8H        // v3 accumulates srcp[filterPos[3] + {4..7}] * filter[{4..7}]
> +
> +        uzp1                v4.4S, v0.4S, v1.4S         // unzip low parts 0 and 1
> +        uzp2                v5.4S, v0.4S, v1.4S         // unzip high parts 0 and 1
> +        uzp1                v6.4S, v2.4S, v3.4S         // unzip low parts 2 and 3
> +        uzp2                v7.4S, v2.4S, v3.4S         // unzip high parts 2 and 3
> +
> +        add                 v0.4S, v4.4S, v5.4S         // add half of each of part 0 and 1
> +        add                 v1.4S, v6.4S, v7.4S         // add half of each of part 2 and 3
> +
> +        addp                v4.4S, v0.4S, v1.4S         // pairwise add to complete half adds in earlier steps
> +
> +        subs                w2, w2, #4                  // dstW -= 4
> +        sqshrn              v0.4H, v4.4S, #7            // shift and clip the 2x16-bit final values
> +        st1                 {v0.4H}, [x1], #8           // write to destination part0123
> +        b.gt                1b                          // loop until end of line
> +        ret
> +endfunc
> +
> +function ff_hscale8to15_4_neon, export=1
> +// x0  SwsContext *c (not used)
> +// x1  int16_t *dst
> +// x2  int dstW
> +// x3  const uint8_t *src
> +// x4  const int16_t *filter
> +// x5  const int32_t *filterPos
> +// x6  int filterSize
> +// x8-x15 registers for gathering src data
> +
> +// v0      madd accumulator 4S
> +// v1-v4   filter values (16 bit) 8H
> +// v5      madd accumulator 4S
> +// v16-v19 src values (8 bit) 8B
> +
> +// This implementation has 4 sections:
> +//  1. Prefetch src data
> +//  2. Interleaved prefetching src data and madd
> +//  3. Complete madd
> +//  4. Complete remaining iterations when dstW % 8 != 0
> +
> +        add                 sp, sp, #-32                // allocate 32 bytes on the stack
> +        cmp                 w2, #16                     // if dstW <16, skip to the last block used for wrapping up
> +        b.lt                2f
> +
> +        // load 8 values from filterPos to be used as offsets into src
> +        ldp                 w8, w9,  [x5]               // filterPos[idx + 0], [idx + 1]
> +        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2], [idx + 3]
> +        ldp                 w12, w13, [x5, 16]          // filterPos[idx + 4], [idx + 5]
> +        ldp                 w14, w15, [x5, 24]          // filterPos[idx + 6], [idx + 7]
> +        add                 x5, x5, #32                 // advance filterPos
> +
> +        // gather random access data from src into contiguous memory
> +        ldr                 w8, [x3, w8, UXTW]          // src[filterPos[idx + 0]][0..3]
> +        ldr                 w9, [x3, w9, UXTW]          // src[filterPos[idx + 1]][0..3]
> +        ldr                 w10, [x3, w10, UXTW]        // src[filterPos[idx + 2]][0..3]
> +        ldr                 w11, [x3, w11, UXTW]        // src[filterPos[idx + 3]][0..3]
> +        ldr                 w12, [x3, w12, UXTW]        // src[filterPos[idx + 4]][0..3]
> +        ldr                 w13, [x3, w13, UXTW]        // src[filterPos[idx + 5]][0..3]
> +        ldr                 w14, [x3, w14, UXTW]        // src[filterPos[idx + 6]][0..3]
> +        ldr                 w15, [x3, w15, UXTW]        // src[filterPos[idx + 7]][0..3]
> +        stp                 w8, w9, [sp]                // *scratch_mem = { src[filterPos[idx + 0]][0..3], src[filterPos[idx + 1]][0..3] }
> +        stp                 w10, w11, [sp, 8]           // *scratch_mem = { src[filterPos[idx + 2]][0..3], src[filterPos[idx + 3]][0..3] }
> +        stp                 w12, w13, [sp, 16]          // *scratch_mem = { src[filterPos[idx + 4]][0..3], src[filterPos[idx + 5]][0..3] }
> +        stp                 w14, w15, [sp, 24]          // *scratch_mem = { src[filterPos[idx + 6]][0..3], src[filterPos[idx + 7]][0..3] }
> +
> +1:
> +        ld4                 {v16.8B, v17.8B, v18.8B, v19.8B}, [sp] // transpose 8 bytes each from src into 4 registers
> +
> +        // load 8 values from filterPos to be used as offsets into src
> +        ldp                 w8, w9,  [x5]               // filterPos[idx + 0][0..3], [idx + 1][0..3], next iteration
> +        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2][0..3], [idx + 3][0..3], next iteration
> +        ldp                 w12, w13, [x5, 16]          // filterPos[idx + 4][0..3], [idx + 5][0..3], next iteration
> +        ldp                 w14, w15, [x5, 24]          // filterPos[idx + 6][0..3], [idx + 7][0..3], next iteration
> +
> +        movi                v0.2D, #0                   // Clear madd accumulator for idx 0..3
> +        movi                v5.2D, #0                   // Clear madd accumulator for idx 4..7
> +
> +        ld4                 {v1.8H, v2.8H, v3.8H, v4.8H}, [x4], #64 // load filter idx + 0..7
> +
> +        add                 x5, x5, #32                 // advance filterPos
> +
> +        // interleaved SIMD and prefetching intended to keep ld/st and vector pipelines busy
> +        uxtl                v16.8H, v16.8B              // unsigned extend long, convert src data to 16-bit
> +        uxtl                v17.8H, v17.8B              // unsigned extend long, convert src data to 16-bit
> +        ldr                 w8, [x3, w8, UXTW]          // src[filterPos[idx + 0]], next iteration
> +        ldr                 w9, [x3, w9, UXTW]          // src[filterPos[idx + 1]], next iteration
> +        uxtl                v18.8H, v18.8B              // unsigned extend long, convert src data to 16-bit
> +        uxtl                v19.8H, v19.8B              // unsigned extend long, convert src data to 16-bit
> +        ldr                 w10, [x3, w10, UXTW]        // src[filterPos[idx + 2]], next iteration
> +        ldr                 w11, [x3, w11, UXTW]        // src[filterPos[idx + 3]], next iteration
> +
> +        smlal               v0.4S, v1.4H, v16.4H        // multiply accumulate inner loop j = 0, idx = 0..3
> +        smlal               v0.4S, v2.4H, v17.4H        // multiply accumulate inner loop j = 1, idx = 0..3
> +        ldr                 w12, [x3, w12, UXTW]        // src[filterPos[idx + 4]], next iteration
> +        ldr                 w13, [x3, w13, UXTW]        // src[filterPos[idx + 5]], next iteration
> +        smlal               v0.4S, v3.4H, v18.4H        // multiply accumulate inner loop j = 2, idx = 0..3
> +        smlal               v0.4S, v4.4H, v19.4H        // multiply accumulate inner loop j = 3, idx = 0..3
> +        ldr                 w14, [x3, w14, UXTW]        // src[filterPos[idx + 6]], next iteration
> +        ldr                 w15, [x3, w15, UXTW]        // src[filterPos[idx + 7]], next iteration
> +
> +        smlal2              v5.4S, v1.8H, v16.8H        // multiply accumulate inner loop j = 0, idx = 4..7
> +        smlal2              v5.4S, v2.8H, v17.8H        // multiply accumulate inner loop j = 1, idx = 4..7
> +        stp                 w8, w9, [sp]                // *scratch_mem = { src[filterPos[idx + 0]][0..3], src[filterPos[idx + 1]][0..3] }
> +        stp                 w10, w11, [sp, 8]           // *scratch_mem = { src[filterPos[idx + 2]][0..3], src[filterPos[idx + 3]][0..3] }
> +        smlal2              v5.4S, v3.8H, v18.8H        // multiply accumulate inner loop j = 2, idx = 4..7
> +        smlal2              v5.4S, v4.8H, v19.8H        // multiply accumulate inner loop j = 3, idx = 4..7
> +        stp                 w12, w13, [sp, 16]          // *scratch_mem = { src[filterPos[idx + 4]][0..3], src[filterPos[idx + 5]][0..3] }
> +        stp                 w14, w15, [sp, 24]          // *scratch_mem = { src[filterPos[idx + 6]][0..3], src[filterPos[idx + 7]][0..3] }
> +
> +        sub                 w2, w2, #8                  // dstW -= 8
> +        sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
> +        sqshrn              v1.4H, v5.4S, #7            // shift and clip the 2x16-bit final values
> +        st1                 {v0.4H, v1.4H}, [x1], #16   // write to dst[idx + 0..7]
> +        cmp                 w2, #16                     // continue on main loop if there are at least 16 iterations left
> +        b.ge                1b
> +
> +        // last full iteration
> +        ld4                 {v16.8B, v17.8B, v18.8B, v19.8B}, [sp]
> +        ld4                 {v1.8H, v2.8H, v3.8H, v4.8H}, [x4], #64 // load filter idx + 0..7
> +
> +        movi                v0.2D, #0                   // Clear madd accumulator for idx 0..3
> +        movi                v5.2D, #0                   // Clear madd accumulator for idx 4..7
> +
> +        uxtl                v16.8H, v16.8B              // unsigned extend long, convert src data to 16-bit
> +        uxtl                v17.8H, v17.8B              // unsigned extend long, convert src data to 16-bit
> +        uxtl                v18.8H, v18.8B              // unsigned extend long, convert src data to 16-bit
> +        uxtl                v19.8H, v19.8B              // unsigned extend long, convert src data to 16-bit
> +
> +        smlal               v0.4S, v1.4H, v16.4H        // multiply accumulate inner loop j = 0, idx = 0..3
> +        smlal               v0.4S, v2.4H, v17.4H        // multiply accumulate inner loop j = 1, idx = 0..3
> +        smlal               v0.4S, v3.4H, v18.4H        // multiply accumulate inner loop j = 2, idx = 0..3
> +        smlal               v0.4S, v4.4H, v19.4H        // multiply accumulate inner loop j = 3, idx = 0..3
> +
> +        smlal2              v5.4S, v1.8H, v16.8H        // multiply accumulate inner loop j = 0, idx = 4..7
> +        smlal2              v5.4S, v2.8H, v17.8H        // multiply accumulate inner loop j = 1, idx = 4..7
> +        smlal2              v5.4S, v3.8H, v18.8H        // multiply accumulate inner loop j = 2, idx = 4..7
> +        smlal2              v5.4S, v4.8H, v19.8H        // multiply accumulate inner loop j = 3, idx = 4..7
> +
> +        subs                w2, w2, #8                  // dstW -= 8
> +        sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
> +        sqshrn              v1.4H, v5.4S, #7            // shift and clip the 2x16-bit final values
> +        st1                 {v0.4H, v1.4H}, [x1], #16   // write to dst[idx + 0..7]
> +
> +        cbnz                w2, 2f                      // if >0 iterations remain, jump to the wrap up section
> +
> +        add                 sp, sp, #32                 // clean up stack
> +        ret

FWIW this implementation looks quite clever; I've got nothing to add to 
this.


> +
> +        // finish up when dstW % 8 != 0 or dstW < 16
> +2:
> +        // load src
> +        ldr                 w8, [x5], #4                // filterPos[i]
> +        ldr                 w9, [x3, w8, UXTW]          // src[filterPos[i] + 0..3]
> +        ins                 v5.S[0], w9                 // move to simd register
> +        // load filter
> +        ld1                 {v6.4H}, [x4], #8           // filter[filterSize * i + 0..3]
> +
> +        uxtl                v5.8H, v5.8B                // unsigned extend long, convert src data to 16-bit
> +        smull               v0.4S, v5.4H, v6.4H         // 4 iterations of src[...] * filter[...]
> +        addp                v0.4S, v0.4S, v0.4S         // accumulate the smull results
> +        addp                v0.4S, v0.4S, v0.4S         // accumulate the smull results

Wouldn't it be better with just one addv instead of two addp?

> +        sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
> +        mov                 w10, v0.S[0]                // move back to general register (only one value from simd reg is used)
> +        strh                w10, [x1], #2               // dst[i] = ...

Transfers from SIMD registers to GPRs are usually a big bottleneck and 
ideally should be avoided. Wouldn't this work just as well with a single 
element st1 store? I.e. "st1 {v0.h}[0], [x1], #2".

> +        sub                 w2, w2, #1                  // dstW--
> +        cbnz                w2, 2b
> +
> +        add                 sp, sp, #32                 // clean up stack
> +        ret
> +endfunc
> diff --git a/libswscale/aarch64/swscale.c b/libswscale/aarch64/swscale.c
> index 09d0a7130e..2ea4ccb3a6 100644
> --- a/libswscale/aarch64/swscale.c
> +++ b/libswscale/aarch64/swscale.c
> @@ -22,25 +22,48 @@
> #include "libswscale/swscale_internal.h"
> #include "libavutil/aarch64/cpu.h"
>
> -void ff_hscale_8_to_15_neon(SwsContext *c, int16_t *dst, int dstW,
> -                            const uint8_t *src, const int16_t *filter,
> -                            const int32_t *filterPos, int filterSize);
> +#define SCALE_FUNC(filter_n, from_bpc, to_bpc, opt) \
> +void ff_hscale ## from_bpc ## to ## to_bpc ## _ ## filter_n ## _ ## opt( \
> +                                                SwsContext *c, int16_t *data, \
> +                                                int dstW, const uint8_t *src, \
> +                                                const int16_t *filter, \
> +                                                const int32_t *filterPos, int filterSize)
> +#define SCALE_FUNCS(filter_n, opt) \
> +    SCALE_FUNC(filter_n,  8, 15, opt);
> +#define ALL_SCALE_FUNCS(opt) \
> +    SCALE_FUNCS(4, opt); \
> +    SCALE_FUNCS(8, opt); \
> +    SCALE_FUNCS(X8, opt)
> +
> +ALL_SCALE_FUNCS(neon);
>
> void ff_yuv2planeX_8_neon(const int16_t *filter, int filterSize,
>                           const int16_t **src, uint8_t *dest, int dstW,
>                           const uint8_t *dither, int offset);
>
> +#define ASSIGN_SCALE_FUNC2(hscalefn, filtersize, opt) do {              \
> +    if (c->srcBpc == 8 && c->dstBpc <= 14) {                            \
> +      hscalefn =                                                        \
> +        ff_hscale8to15_ ## filtersize ## _ ## opt;                      \
> +    }                                                                   \
> +} while (0)
> +
> +#define ASSIGN_SCALE_FUNC(hscalefn, filtersize, opt)                    \
> +  switch (filtersize) {                                                 \
> +  case 4:  ASSIGN_SCALE_FUNC2(hscalefn, 4, opt); break;                 \
> +  case 8:  ASSIGN_SCALE_FUNC2(hscalefn, 8, opt); break;                 \
> +  default: if (filtersize % 8 == 0)                                     \
> +               ASSIGN_SCALE_FUNC2(hscalefn, X8, opt);                   \
> +           break;                                                       \
> +  }
> +
> av_cold void ff_sws_init_swscale_aarch64(SwsContext *c)
> {
>     int cpu_flags = av_get_cpu_flags();
>
>     if (have_neon(cpu_flags)) {
> -        if (c->srcBpc == 8 && c->dstBpc <= 14 &&
> -            (c->hLumFilterSize % 8) == 0 &&
> -            (c->hChrFilterSize % 8) == 0)
> -        {
> -            c->hyScale = c->hcScale = ff_hscale_8_to_15_neon;
> -        }
> +        ASSIGN_SCALE_FUNC(c->hyScale, c->hLumFilterSize, neon);
> +        ASSIGN_SCALE_FUNC(c->hcScale, c->hChrFilterSize, neon);

The fact that this now assigns hyScale and hcScale independently based on 
their filter sizes looks like a good, valuable improvement too!

// Martin
Martin Storsjö April 20, 2022, 8:44 a.m. UTC | #2
On Sun, 17 Apr 2022, Martin Storsjö wrote:

> On Fri, 15 Apr 2022, Swinney, Jonathan wrote:
>
>> This patch adds specializations for hscale for filterSize == 4 and 8 and
>> converts the existing implementation for the X8 version. For the old code,
>> now used for the X8 version, it improves the efficiency of the final
>> summations by reducing 11 instructions to 7.
>> 
>> ff_hscale8to15_8_neon is mostly unchanged from the original except for a
>> few changes.
>> - The loads for the filter data were consolidated into a single 64 byte ld1
>>   instruction.
>
> Couldn't you do this optimization on the existing function too?

Sorry, now I realize why this optimization can only be done if you 
operate on a specific, known filter width.

>> - The final summations were improved.
>> - The inner loop on filterSize was completely removed
>
> I presume that this is the only differing factor which affects whether it's 
> worthwhile to keep a separate width=8 function or not. At least from the 
> checkasm benchmark numbers, the difference is notable but not huge (on the 
> range of 4-10%, while the summation improvements gain even more).
>
> Given a fully optimized function that has an inner loop (which is only taken 
> once for the width=8 case), is the separate function without an inner loop 
> really necessary?

With the ideal version of the final summation in both functions, the 
separate filtersize=8 function is 11-19% faster than the generic 
multiple-of-8 function (on Cortex A53 and A72 - on A73 both versions 
are essentially equally fast), so there's probably good reason to go with 
the separate version.

Thus, disregard the review comments above.

// Martin

Patch

diff --git a/libswscale/aarch64/hscale.S b/libswscale/aarch64/hscale.S
index af55ffe2b7..a934653a46 100644
--- a/libswscale/aarch64/hscale.S
+++ b/libswscale/aarch64/hscale.S
@@ -1,5 +1,7 @@ 
 /*
  * Copyright (c) 2016 Clément Bœsch <clement stupeflix.com>
+ * Copyright (c) 2019-2021 Sebastian Pop <spop@amazon.com>
+ * Copyright (c) 2022 Jonathan Swinney <jswinney@amazon.com>
  *
  * This file is part of FFmpeg.
  *
@@ -20,7 +22,25 @@ 
 
 #include "libavutil/aarch64/asm.S"
 
-function ff_hscale_8_to_15_neon, export=1
+/*
+;-----------------------------------------------------------------------------
+; horizontal line scaling
+;
+; void hscale<source_width>to<intermediate_nbits>_<filterSize>_<opt>
+;                               (SwsContext *c, int{16,32}_t *dst,
+;                                int dstW, const uint{8,16}_t *src,
+;                                const int16_t *filter,
+;                                const int32_t *filterPos, int filterSize);
+;
+; Scale one horizontal line. Input is either 8-bit width or 16-bit width
+; ($source_width can be either 8, 9, 10 or 16, difference is whether we have to
+; downscale before multiplying). Filter is 14 bits. Output is either 15 bits
+; (in int16_t) or 19 bits (in int32_t), as given in $intermediate_nbits. Each
+; output pixel is generated from $filterSize input pixels, the position of
+; the first pixel is given in filterPos[nOutputPixel].
+;----------------------------------------------------------------------------- */
+
+function ff_hscale8to15_X8_neon, export=1
         sbfiz               x7, x6, #1, #32             // filterSize*2 (*2 because int16)
 1:      ldr                 w8, [x5], #4                // filterPos[idx]
         ldr                 w0, [x5], #4                // filterPos[idx + 1]
@@ -61,20 +81,239 @@  function ff_hscale_8_to_15_neon, export=1
         smlal               v3.4S, v18.4H, v19.4H       // v3 accumulates srcp[filterPos[3] + {0..3}] * filter[{0..3}]
         smlal2              v3.4S, v18.8H, v19.8H       // v3 accumulates srcp[filterPos[3] + {4..7}] * filter[{4..7}]
         b.gt                2b                          // inner loop if filterSize not consumed completely
-        addp                v0.4S, v0.4S, v0.4S         // part0 horizontal pair adding
-        addp                v1.4S, v1.4S, v1.4S         // part1 horizontal pair adding
-        addp                v2.4S, v2.4S, v2.4S         // part2 horizontal pair adding
-        addp                v3.4S, v3.4S, v3.4S         // part3 horizontal pair adding
-        addp                v0.4S, v0.4S, v0.4S         // part0 horizontal pair adding
-        addp                v1.4S, v1.4S, v1.4S         // part1 horizontal pair adding
-        addp                v2.4S, v2.4S, v2.4S         // part2 horizontal pair adding
-        addp                v3.4S, v3.4S, v3.4S         // part3 horizontal pair adding
-        zip1                v0.4S, v0.4S, v1.4S         // part01 = zip values from part0 and part1
-        zip1                v2.4S, v2.4S, v3.4S         // part23 = zip values from part2 and part3
-        mov                 v0.d[1], v2.d[0]            // part0123 = zip values from part01 and part23
+        uzp1                v4.4S, v0.4S, v1.4S         // unzip low parts 0 and 1
+        uzp2                v5.4S, v0.4S, v1.4S         // unzip high parts 0 and 1
+        uzp1                v6.4S, v2.4S, v3.4S         // unzip low parts 2 and 3
+        uzp2                v7.4S, v2.4S, v3.4S         // unzip high parts 2 and 3
+        add                 v16.4S, v4.4S, v5.4S        // add half of each of part 0 and 1
+        add                 v17.4S, v6.4S, v7.4S        // add half of each of part 2 and 3
+        addp                v0.4S, v16.4S, v17.4S       // pairwise add to complete half adds in earlier steps
         subs                w2, w2, #4                  // dstW -= 4
         sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
         st1                 {v0.4H}, [x1], #8           // write to destination part0123
         b.gt                1b                          // loop until end of line
         ret
 endfunc
+
+
+function ff_hscale8to15_8_neon, export=1
+// x0      SwsContext *c (not used)
+// x1      int16_t *dst
+// x2      int dstW
+// x3      const uint8_t *src
+// x4      const int16_t *filter
+// x5      const int32_t *filterPos
+// x6      int filterSize
+// x8-x11  filterPos values
+
+// v0-v3   multiply add accumulators
+// v4-v7   filter data, temp for final horizontal sum
+// v16-v19 src data
+1:
+        ld1                 {v4.8H, v5.8H, v6.8H, v7.8H}, [x4], #64 // load filter[idx=0..3, j=0..7]
+        ldp                 w8, w9,  [x5]               // filterPos[idx + 0], [idx + 1]
+        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2], [idx + 3]
+        movi                v0.2D, #0                   // val sum part 1 (for dst[0])
+        movi                v1.2D, #0                   // val sum part 2 (for dst[1])
+        add                 x5, x5, #16                 // increment filterPos
+
+        add                 x8, x3, w8, UXTW            // srcp + filterPos[0]
+        add                 x9,  x3, w9, UXTW           // srcp + filterPos[1]
+        add                 x10, x3, w10, UXTW          // srcp + filterPos[2]
+        add                 x11, x3, w11, UXTW          // srcp + filterPos[3]
+
+        ld1                 {v16.8B}, [x8], #8          // srcp[filterPos[0] + {0..7}]
+        ld1                 {v17.8B}, [x9], #8          // srcp[filterPos[1] + {0..7}]
+
+        movi                v2.2D, #0                   // val sum part 3 (for dst[2])
+        movi                v3.2D, #0                   // val sum part 4 (for dst[3])
+
+        uxtl                v16.8H, v16.8B              // unpack part 1 to 16-bit
+        uxtl                v17.8H, v17.8B              // unpack part 2 to 16-bit
+
+        smlal               v0.4S, v16.4H, v4.4H        // v0 accumulates srcp[filterPos[0] + {0..3}] * filter[{0..3}]
+        smlal               v1.4S, v17.4H, v5.4H        // v1 accumulates srcp[filterPos[1] + {0..3}] * filter[{0..3}]
+
+        ld1                 {v18.8B}, [x10], #8         // srcp[filterPos[2] + {0..7}]
+        ld1                 {v19.8B}, [x11], #8         // srcp[filterPos[3] + {0..7}]
+
+        smlal2              v0.4S, v16.8H, v4.8H        // v0 accumulates srcp[filterPos[0] + {4..7}] * filter[{4..7}]
+        smlal2              v1.4S, v17.8H, v5.8H        // v1 accumulates srcp[filterPos[1] + {4..7}] * filter[{4..7}]
+
+        uxtl                v18.8H, v18.8B              // unpack part 3 to 16-bit
+        uxtl                v19.8H, v19.8B              // unpack part 4 to 16-bit
+
+        smlal               v2.4S, v18.4H, v6.4H        // v2 accumulates srcp[filterPos[2] + {0..3}] * filter[{0..3}]
+        smlal               v3.4S, v19.4H, v7.4H        // v3 accumulates srcp[filterPos[3] + {0..3}] * filter[{0..3}]
+
+        smlal2              v2.4S, v18.8H, v6.8H        // v2 accumulates srcp[filterPos[2] + {4..7}] * filter[{4..7}]
+        smlal2              v3.4S, v19.8H, v7.8H        // v3 accumulates srcp[filterPos[3] + {4..7}] * filter[{4..7}]
+
+        uzp1                v4.4S, v0.4S, v1.4S         // unzip low parts 0 and 1
+        uzp2                v5.4S, v0.4S, v1.4S         // unzip high parts 0 and 1
+        uzp1                v6.4S, v2.4S, v3.4S         // unzip low parts 2 and 3
+        uzp2                v7.4S, v2.4S, v3.4S         // unzip high parts 2 and 3
+
+        add                 v0.4S, v4.4S, v5.4S         // add half of each of part 0 and 1
+        add                 v1.4S, v6.4S, v7.4S         // add half of each of part 2 and 3
+
+        addp                v4.4S, v0.4S, v1.4S         // pairwise add to complete half adds in earlier steps
+
+        subs                w2, w2, #4                  // dstW -= 4
+        sqshrn              v0.4H, v4.4S, #7            // shift and clip the 2x16-bit final values
+        st1                 {v0.4H}, [x1], #8           // write to destination part0123
+        b.gt                1b                          // loop until end of line
+        ret
+endfunc
+
+function ff_hscale8to15_4_neon, export=1
+// x0  SwsContext *c (not used)
+// x1  int16_t *dst
+// x2  int dstW
+// x3  const uint8_t *src
+// x4  const int16_t *filter
+// x5  const int32_t *filterPos
+// x6  int filterSize
+// x8-x15 registers for gathering src data
+
+// v0      madd accumulator 4S
+// v1-v4   filter values (16 bit) 8H
+// v5      madd accumulator 4S
+// v16-v19 src values (8 bit) 8B
+
+// This implementation has 4 sections:
+//  1. Prefetch src data
+//  2. Interleaved prefetching src data and madd
+//  3. Complete madd
+//  4. Complete remaining iterations when dstW % 8 != 0
+
+        add                 sp, sp, #-32                // allocate 32 bytes on the stack
+        cmp                 w2, #16                     // if dstW <16, skip to the last block used for wrapping up
+        b.lt                2f
+
+        // load 8 values from filterPos to be used as offsets into src
+        ldp                 w8, w9,  [x5]               // filterPos[idx + 0], [idx + 1]
+        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2], [idx + 3]
+        ldp                 w12, w13, [x5, 16]          // filterPos[idx + 4], [idx + 5]
+        ldp                 w14, w15, [x5, 24]          // filterPos[idx + 6], [idx + 7]
+        add                 x5, x5, #32                 // advance filterPos
+
+        // gather random access data from src into contiguous memory
+        ldr                 w8, [x3, w8, UXTW]          // src[filterPos[idx + 0]][0..3]
+        ldr                 w9, [x3, w9, UXTW]          // src[filterPos[idx + 1]][0..3]
+        ldr                 w10, [x3, w10, UXTW]        // src[filterPos[idx + 2]][0..3]
+        ldr                 w11, [x3, w11, UXTW]        // src[filterPos[idx + 3]][0..3]
+        ldr                 w12, [x3, w12, UXTW]        // src[filterPos[idx + 4]][0..3]
+        ldr                 w13, [x3, w13, UXTW]        // src[filterPos[idx + 5]][0..3]
+        ldr                 w14, [x3, w14, UXTW]        // src[filterPos[idx + 6]][0..3]
+        ldr                 w15, [x3, w15, UXTW]        // src[filterPos[idx + 7]][0..3]
+        stp                 w8, w9, [sp]                // *scratch_mem = { src[filterPos[idx + 0]][0..3], src[filterPos[idx + 1]][0..3] }
+        stp                 w10, w11, [sp, 8]           // *scratch_mem = { src[filterPos[idx + 2]][0..3], src[filterPos[idx + 3]][0..3] }
+        stp                 w12, w13, [sp, 16]          // *scratch_mem = { src[filterPos[idx + 4]][0..3], src[filterPos[idx + 5]][0..3] }
+        stp                 w14, w15, [sp, 24]          // *scratch_mem = { src[filterPos[idx + 6]][0..3], src[filterPos[idx + 7]][0..3] }
+
+1:
+        ld4                 {v16.8B, v17.8B, v18.8B, v19.8B}, [sp] // transpose 8 bytes each from src into 4 registers
+
+        // load 8 values from filterPos to be used as offsets into src
+        ldp                 w8, w9,  [x5]               // filterPos[idx + 0][0..3], [idx + 1][0..3], next iteration
+        ldp                 w10, w11, [x5, 8]           // filterPos[idx + 2][0..3], [idx + 3][0..3], next iteration
+        ldp                 w12, w13, [x5, 16]          // filterPos[idx + 4][0..3], [idx + 5][0..3], next iteration
+        ldp                 w14, w15, [x5, 24]          // filterPos[idx + 6][0..3], [idx + 7][0..3], next iteration
+
+        movi                v0.2D, #0                   // Clear madd accumulator for idx 0..3
+        movi                v5.2D, #0                   // Clear madd accumulator for idx 4..7
+
+        ld4                 {v1.8H, v2.8H, v3.8H, v4.8H}, [x4], #64 // load filter idx + 0..7
+
+        add                 x5, x5, #32                 // advance filterPos
+
+        // interleaved SIMD and prefetching intended to keep ld/st and vector pipelines busy
+        uxtl                v16.8H, v16.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v17.8H, v17.8B              // unsigned extend long, convert src data to 16-bit
+        ldr                 w8, [x3, w8, UXTW]          // src[filterPos[idx + 0]], next iteration
+        ldr                 w9, [x3, w9, UXTW]          // src[filterPos[idx + 1]], next iteration
+        uxtl                v18.8H, v18.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v19.8H, v19.8B              // unsigned extend long, convert src data to 16-bit
+        ldr                 w10, [x3, w10, UXTW]        // src[filterPos[idx + 2]], next iteration
+        ldr                 w11, [x3, w11, UXTW]        // src[filterPos[idx + 3]], next iteration
+
+        smlal               v0.4S, v1.4H, v16.4H        // multiply accumulate inner loop j = 0, idx = 0..3
+        smlal               v0.4S, v2.4H, v17.4H        // multiply accumulate inner loop j = 1, idx = 0..3
+        ldr                 w12, [x3, w12, UXTW]        // src[filterPos[idx + 4]], next iteration
+        ldr                 w13, [x3, w13, UXTW]        // src[filterPos[idx + 5]], next iteration
+        smlal               v0.4S, v3.4H, v18.4H        // multiply accumulate inner loop j = 2, idx = 0..3
+        smlal               v0.4S, v4.4H, v19.4H        // multiply accumulate inner loop j = 3, idx = 0..3
+        ldr                 w14, [x3, w14, UXTW]        // src[filterPos[idx + 6]], next iteration
+        ldr                 w15, [x3, w15, UXTW]        // src[filterPos[idx + 7]], next iteration
+
+        smlal2              v5.4S, v1.8H, v16.8H        // multiply accumulate inner loop j = 0, idx = 4..7
+        smlal2              v5.4S, v2.8H, v17.8H        // multiply accumulate inner loop j = 1, idx = 4..7
+        stp                 w8, w9, [sp]                // *scratch_mem = { src[filterPos[idx + 0]][0..3], src[filterPos[idx + 1]][0..3] }
+        stp                 w10, w11, [sp, 8]           // *scratch_mem = { src[filterPos[idx + 2]][0..3], src[filterPos[idx + 3]][0..3] }
+        smlal2              v5.4S, v3.8H, v18.8H        // multiply accumulate inner loop j = 2, idx = 4..7
+        smlal2              v5.4S, v4.8H, v19.8H        // multiply accumulate inner loop j = 3, idx = 4..7
+        stp                 w12, w13, [sp, 16]          // *scratch_mem = { src[filterPos[idx + 4]][0..3], src[filterPos[idx + 5]][0..3] }
+        stp                 w14, w15, [sp, 24]          // *scratch_mem = { src[filterPos[idx + 6]][0..3], src[filterPos[idx + 7]][0..3] }
+
+        sub                 w2, w2, #8                  // dstW -= 8
+        sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
+        sqshrn              v1.4H, v5.4S, #7            // shift and clip the 2x16-bit final values
+        st1                 {v0.4H, v1.4H}, [x1], #16   // write to dst[idx + 0..7]
+        cmp                 w2, #16                     // continue on main loop if there are at least 16 iterations left
+        b.ge                1b
+
+        // last full iteration
+        ld4                 {v16.8B, v17.8B, v18.8B, v19.8B}, [sp]
+        ld4                 {v1.8H, v2.8H, v3.8H, v4.8H}, [x4], #64 // load filter idx + 0..7
+
+        movi                v0.2D, #0                   // Clear madd accumulator for idx 0..3
+        movi                v5.2D, #0                   // Clear madd accumulator for idx 4..7
+
+        uxtl                v16.8H, v16.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v17.8H, v17.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v18.8H, v18.8B              // unsigned extend long, convert src data to 16-bit
+        uxtl                v19.8H, v19.8B              // unsigned extend long, convert src data to 16-bit
+
+        smlal               v0.4S, v1.4H, v16.4H        // multiply accumulate inner loop j = 0, idx = 0..3
+        smlal               v0.4S, v2.4H, v17.4H        // multiply accumulate inner loop j = 1, idx = 0..3
+        smlal               v0.4S, v3.4H, v18.4H        // multiply accumulate inner loop j = 2, idx = 0..3
+        smlal               v0.4S, v4.4H, v19.4H        // multiply accumulate inner loop j = 3, idx = 0..3
+
+        smlal2              v5.4S, v1.8H, v16.8H        // multiply accumulate inner loop j = 0, idx = 4..7
+        smlal2              v5.4S, v2.8H, v17.8H        // multiply accumulate inner loop j = 1, idx = 4..7
+        smlal2              v5.4S, v3.8H, v18.8H        // multiply accumulate inner loop j = 2, idx = 4..7
+        smlal2              v5.4S, v4.8H, v19.8H        // multiply accumulate inner loop j = 3, idx = 4..7
+
+        subs                w2, w2, #8                  // dstW -= 8
+        sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
+        sqshrn              v1.4H, v5.4S, #7            // shift and clip the 2x16-bit final values
+        st1                 {v0.4H, v1.4H}, [x1], #16   // write to dst[idx + 0..7]
+
+        cbnz                w2, 2f                      // if >0 iterations remain, jump to the wrap up section
+
+        add                 sp, sp, #32                 // clean up stack
+        ret
+
+        // finish up when dstW % 8 != 0 or dstW < 16
+2:
+        // load src
+        ldr                 w8, [x5], #4                // filterPos[i]
+        ldr                 w9, [x3, w8, UXTW]          // src[filterPos[i] + 0..3]
+        ins                 v5.S[0], w9                 // move to simd register
+        // load filter
+        ld1                 {v6.4H}, [x4], #8           // filter[filterSize * i + 0..3]
+
+        uxtl                v5.8H, v5.8B                // unsigned extend long, convert src data to 16-bit
+        smull               v0.4S, v5.4H, v6.4H         // 4 iterations of src[...] * filter[...]
+        addp                v0.4S, v0.4S, v0.4S         // accumulate the smull results
+        addp                v0.4S, v0.4S, v0.4S         // accumulate the smull results
+        sqshrn              v0.4H, v0.4S, #7            // shift and clip the 2x16-bit final values
+        mov                 w10, v0.S[0]                // move back to general register (only one value from simd reg is used)
+        strh                w10, [x1], #2               // dst[i] = ...
+        sub                 w2, w2, #1                  // dstW--
+        cbnz                w2, 2b
+
+        add                 sp, sp, #32                 // clean up stack
+        ret
+endfunc
diff --git a/libswscale/aarch64/swscale.c b/libswscale/aarch64/swscale.c
index 09d0a7130e..2ea4ccb3a6 100644
--- a/libswscale/aarch64/swscale.c
+++ b/libswscale/aarch64/swscale.c
@@ -22,25 +22,48 @@ 
 #include "libswscale/swscale_internal.h"
 #include "libavutil/aarch64/cpu.h"
 
-void ff_hscale_8_to_15_neon(SwsContext *c, int16_t *dst, int dstW,
-                            const uint8_t *src, const int16_t *filter,
-                            const int32_t *filterPos, int filterSize);
+#define SCALE_FUNC(filter_n, from_bpc, to_bpc, opt) \
+void ff_hscale ## from_bpc ## to ## to_bpc ## _ ## filter_n ## _ ## opt( \
+                                                SwsContext *c, int16_t *data, \
+                                                int dstW, const uint8_t *src, \
+                                                const int16_t *filter, \
+                                                const int32_t *filterPos, int filterSize)
+#define SCALE_FUNCS(filter_n, opt) \
+    SCALE_FUNC(filter_n,  8, 15, opt);
+#define ALL_SCALE_FUNCS(opt) \
+    SCALE_FUNCS(4, opt); \
+    SCALE_FUNCS(8, opt); \
+    SCALE_FUNCS(X8, opt)
+
+ALL_SCALE_FUNCS(neon);
 
 void ff_yuv2planeX_8_neon(const int16_t *filter, int filterSize,
                           const int16_t **src, uint8_t *dest, int dstW,
                           const uint8_t *dither, int offset);
 
+#define ASSIGN_SCALE_FUNC2(hscalefn, filtersize, opt) do {              \
+    if (c->srcBpc == 8 && c->dstBpc <= 14) {                            \
+      hscalefn =                                                        \
+        ff_hscale8to15_ ## filtersize ## _ ## opt;                      \
+    }                                                                   \
+} while (0)
+
+#define ASSIGN_SCALE_FUNC(hscalefn, filtersize, opt)                    \
+  switch (filtersize) {                                                 \
+  case 4:  ASSIGN_SCALE_FUNC2(hscalefn, 4, opt); break;                 \
+  case 8:  ASSIGN_SCALE_FUNC2(hscalefn, 8, opt); break;                 \
+  default: if (filtersize % 8 == 0)                                     \
+               ASSIGN_SCALE_FUNC2(hscalefn, X8, opt);                   \
+           break;                                                       \
+  }
+
 av_cold void ff_sws_init_swscale_aarch64(SwsContext *c)
 {
     int cpu_flags = av_get_cpu_flags();
 
     if (have_neon(cpu_flags)) {
-        if (c->srcBpc == 8 && c->dstBpc <= 14 &&
-            (c->hLumFilterSize % 8) == 0 &&
-            (c->hChrFilterSize % 8) == 0)
-        {
-            c->hyScale = c->hcScale = ff_hscale_8_to_15_neon;
-        }
+        ASSIGN_SCALE_FUNC(c->hyScale, c->hLumFilterSize, neon);
+        ASSIGN_SCALE_FUNC(c->hcScale, c->hChrFilterSize, neon);
         if (c->dstBpc == 8) {
             c->yuv2planeX = ff_yuv2planeX_8_neon;
         }
diff --git a/libswscale/utils.c b/libswscale/utils.c
index c5ea8853d5..2f2b8e73a9 100644
--- a/libswscale/utils.c
+++ b/libswscale/utils.c
@@ -1825,7 +1825,7 @@  av_cold int sws_init_context(SwsContext *c, SwsFilter *srcFilter,
         {
             const int filterAlign = X86_MMX(cpu_flags)     ? 4 :
                                     PPC_ALTIVEC(cpu_flags) ? 8 :
-                                    have_neon(cpu_flags)   ? 8 : 1;
+                                    have_neon(cpu_flags)   ? 4 : 1;
 
             if ((ret = initFilter(&c->hLumFilter, &c->hLumFilterPos,
                            &c->hLumFilterSize, c->lumXInc,