diff mbox series

[FFmpeg-devel,1/5] avcodec/vvc_mc: split the SAD dsp prototype into one function per blocksize width

Message ID 20240523122716.2158-1-jamrial@gmail.com
State New
Headers show
Series [FFmpeg-devel,1/5] avcodec/vvc_mc: split the SAD dsp prototype into one function per blocksize width | expand

Checks

Context Check Description
yinshiyou/make_loongarch64 success Make finished
yinshiyou/make_fate_loongarch64 success Make fate finished
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished

Commit Message

James Almer May 23, 2024, 12:27 p.m. UTC
Signed-off-by: James Almer <jamrial@gmail.com>
---
 libavcodec/vvc/dsp.h             |  2 +-
 libavcodec/vvc/inter.c           |  6 ++++--
 libavcodec/vvc/inter_template.c  |  6 +++++-
 libavcodec/x86/vvc/vvc_sad.asm   | 32 ++++++++++++++++++++++++++------
 libavcodec/x86/vvc/vvcdsp_init.c | 22 +++++++++++++++++-----
 tests/checkasm/vvc_mc.c          |  3 ++-
 6 files changed, 55 insertions(+), 16 deletions(-)

Comments

Andreas Rheinhardt May 23, 2024, 12:35 p.m. UTC | #1
James Almer:
> Signed-off-by: James Almer <jamrial@gmail.com>
> ---

The commit message should explain what the advantage of this is.
In particular, what is the advantage of selecting an appropriate function
in the generic code (even when these functions all turn out to be the
same, as is the case for the C version) over jumping inside the function
based upon the blocksize?

>  libavcodec/vvc/dsp.h             |  2 +-
>  libavcodec/vvc/inter.c           |  6 ++++--
>  libavcodec/vvc/inter_template.c  |  6 +++++-
>  libavcodec/x86/vvc/vvc_sad.asm   | 32 ++++++++++++++++++++++++++------
>  libavcodec/x86/vvc/vvcdsp_init.c | 22 +++++++++++++++++-----
>  tests/checkasm/vvc_mc.c          |  3 ++-
>  6 files changed, 55 insertions(+), 16 deletions(-)
> 
> diff --git a/libavcodec/vvc/dsp.h b/libavcodec/vvc/dsp.h
> index 1f14096c41..55c4c81f53 100644
> --- a/libavcodec/vvc/dsp.h
> +++ b/libavcodec/vvc/dsp.h
> @@ -99,7 +99,7 @@ typedef struct VVCInterDSPContext {
>  
>      void (*apply_bdof)(uint8_t *dst, ptrdiff_t dst_stride, int16_t *src0, int16_t *src1, int block_w, int block_h);
>  
> -    int (*sad)(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
> +    int (*sad[5])(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
>      void (*dmvr[2][2])(int16_t *dst, const uint8_t *src, ptrdiff_t src_stride, int height,
>          intptr_t mx, intptr_t my, int width);
>  } VVCInterDSPContext;
> diff --git a/libavcodec/vvc/inter.c b/libavcodec/vvc/inter.c
> index e1011b4fa1..0214e46634 100644
> --- a/libavcodec/vvc/inter.c
> +++ b/libavcodec/vvc/inter.c
> @@ -740,6 +740,8 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
>      const AVFrame *ref0, const AVFrame *ref1, const int x_off, const int y_off, const int block_w, const int block_h)
>  {
>      const VVCFrameContext *fc   = lc->fc;
> +    static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
> +    const int tab               = sad_tab[(FFALIGN(block_w, 8) >> 3) - 1];
>      const int sr_range          = 2;
>      const AVFrame *ref[]        = { ref0, ref1 };
>      int16_t *tmp[]              = { lc->tmp, lc->tmp1 };
> @@ -763,7 +765,7 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
>          fc->vvcdsp.inter.dmvr[!!my][!!mx](tmp[i], src, src_stride, pred_h, mx, my, pred_w);
>      }
>  
> -    min_sad = fc->vvcdsp.inter.sad(tmp[L0], tmp[L1], dx, dy, block_w, block_h);
> +    min_sad = fc->vvcdsp.inter.sad[tab](tmp[L0], tmp[L1], dx, dy, block_w, block_h);
>      min_sad -= min_sad >> 2;
>      sad[dy][dx] = min_sad;
>  
> @@ -773,7 +775,7 @@ static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
>          for (dy = 0; dy < SAD_ARRAY_SIZE; dy++) {
>              for (dx = 0; dx < SAD_ARRAY_SIZE; dx++) {
>                  if (dx != sr_range || dy != sr_range) {
> -                    sad[dy][dx] = fc->vvcdsp.inter.sad(lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
> +                    sad[dy][dx] = fc->vvcdsp.inter.sad[tab](lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
>                      if (sad[dy][dx] < min_sad) {
>                          min_sad = sad[dy][dx];
>                          min_dx = dx;
> diff --git a/libavcodec/vvc/inter_template.c b/libavcodec/vvc/inter_template.c
> index a8068f4ba8..34485321d3 100644
> --- a/libavcodec/vvc/inter_template.c
> +++ b/libavcodec/vvc/inter_template.c
> @@ -626,7 +626,11 @@ static void FUNC(ff_vvc_inter_dsp_init)(VVCInterDSPContext *const inter)
>      inter->apply_prof_uni_w     = FUNC(apply_prof_uni_w);
>      inter->apply_bdof           = FUNC(apply_bdof);
>      inter->prof_grad_filter     = FUNC(prof_grad_filter);
> -    inter->sad                  = vvc_sad;
> +    inter->sad[0]               =
> +    inter->sad[1]               =
> +    inter->sad[2]               =
> +    inter->sad[3]               =
> +    inter->sad[4]               = vvc_sad;
>  }
>  
>  #undef FUNCS
> diff --git a/libavcodec/x86/vvc/vvc_sad.asm b/libavcodec/x86/vvc/vvc_sad.asm
> index b468d89ac2..a20818530f 100644
> --- a/libavcodec/x86/vvc/vvc_sad.asm
> +++ b/libavcodec/x86/vvc/vvc_sad.asm
> @@ -51,7 +51,7 @@ SECTION .text
>  
>  INIT_YMM avx2
>  
> -cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
> +cglobal vvc_sad_8, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
>      movsxdifnidn    dxq, dxd
>      movsxdifnidn    dyq, dyd
>  
> @@ -76,10 +76,6 @@ cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
>      pxor               m3, m3
>      vpbroadcastd       m4, [pw_1]
>  
> -    cmp          block_wd, 16
> -    jge    vvc_sad_16_128
> -
> -    vvc_sad_8:
>          .loop_height:
>          movu              xm0, [src1q]
>          vinserti128        m0, m0, [src1q + MAX_PB_SIZE * ROWS * 2], 1
> @@ -100,7 +96,31 @@ cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
>          movd          eax, xm0
>      RET
>  
> -    vvc_sad_16_128:
> +cglobal vvc_sad_16, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
> +    movsxdifnidn    dxq, dxd
> +    movsxdifnidn    dyq, dyd
> +
> +    sub             dxq, 2
> +    sub             dyq, 2
> +
> +    mov             off1q, 2
> +    mov             off2q, 2
> +
> +    add             off1q, dyq
> +    sub             off2q, dyq
> +
> +    shl             off1q, 7
> +    shl             off2q, 7
> +
> +    add             off1q, dxq
> +    sub             off2q, dxq
> +
> +    lea             src1q, [src1q + off1q * 2 + 2 * 2]
> +    lea             src2q, [src2q + off2q * 2 + 2 * 2]
> +
> +    pxor               m3, m3
> +    vpbroadcastd       m4, [pw_1]
> +
>          sar      block_wd, 4
>          .loop_height:
>          mov         off1q, src1q
> diff --git a/libavcodec/x86/vvc/vvcdsp_init.c b/libavcodec/x86/vvc/vvcdsp_init.c
> index 4b4a2aa937..bd60963432 100644
> --- a/libavcodec/x86/vvc/vvcdsp_init.c
> +++ b/libavcodec/x86/vvc/vvcdsp_init.c
> @@ -312,8 +312,20 @@ ALF_FUNCS(16, 12, avx2)
>      c->alf.classify       = ff_vvc_alf_classify_##bd##_avx2;         \
>  } while (0)
>  
> -int ff_vvc_sad_avx2(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
> -#define SAD_INIT() c->inter.sad = ff_vvc_sad_avx2
> +#define SAD_PROTOTYPE(w, opt)                                        \
> +int bf(ff_vvc_sad, w, opt)(const int16_t *src0, const int16_t *src1, \
> +                           int dx, int dy, int block_w, int block_h) \
> +
> +SAD_PROTOTYPE(8,   avx2);
> +SAD_PROTOTYPE(16,  avx2);
> +
> +#define SAD_INIT(opt) do {                   \
> +    c->inter.sad[0] = ff_vvc_sad_8_##opt;    \
> +    c->inter.sad[1] =                        \
> +    c->inter.sad[2] =                        \
> +    c->inter.sad[3] =                        \
> +    c->inter.sad[4] = ff_vvc_sad_16_##opt;   \
> +} while (0)
>  #endif
>  
>  void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
> @@ -330,7 +342,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
>              ALF_INIT(8);
>              AVG_INIT(8, avx2);
>              MC_LINKS_AVX2(8);
> -            SAD_INIT();
> +            SAD_INIT(avx2);
>          }
>          break;
>      case 10:
> @@ -342,7 +354,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
>              AVG_INIT(10, avx2);
>              MC_LINKS_AVX2(10);
>              MC_LINKS_16BPC_AVX2(10);
> -            SAD_INIT();
> +            SAD_INIT(avx2);
>          }
>          break;
>      case 12:
> @@ -354,7 +366,7 @@ void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
>              AVG_INIT(12, avx2);
>              MC_LINKS_AVX2(12);
>              MC_LINKS_16BPC_AVX2(12);
> -            SAD_INIT();
> +            SAD_INIT(avx2);
>          }
>          break;
>      default:
> diff --git a/tests/checkasm/vvc_mc.c b/tests/checkasm/vvc_mc.c
> index 1e889e2cff..deae1014d2 100644
> --- a/tests/checkasm/vvc_mc.c
> +++ b/tests/checkasm/vvc_mc.c
> @@ -327,6 +327,7 @@ static void check_avg(void)
>  static void check_vvc_sad(void)
>  {
>      const int bit_depth = 10;
> +    static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
>      VVCDSPContext c;
>      LOCAL_ALIGNED_32(uint16_t, src0, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
>      LOCAL_ALIGNED_32(uint16_t, src1, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
> @@ -341,7 +342,7 @@ static void check_vvc_sad(void)
>          for (int w = 8; w <= MAX_CTU_SIZE; w *= 2) {
>              for(int offy = 0; offy <= 4; offy++) {
>                  for(int offx = 0; offx <= 4; offx++) {
> -                    if(check_func(c.inter.sad, "sad_%dx%d", w, h)) {
> +                    if(check_func(c.inter.sad[sad_tab[(w >> 3) - 1]], "sad_%dx%d", w, h)) {
>                          int result0;
>                          int result1;
>
diff mbox series

Patch

diff --git a/libavcodec/vvc/dsp.h b/libavcodec/vvc/dsp.h
index 1f14096c41..55c4c81f53 100644
--- a/libavcodec/vvc/dsp.h
+++ b/libavcodec/vvc/dsp.h
@@ -99,7 +99,7 @@  typedef struct VVCInterDSPContext {
 
     void (*apply_bdof)(uint8_t *dst, ptrdiff_t dst_stride, int16_t *src0, int16_t *src1, int block_w, int block_h);
 
-    int (*sad)(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
+    int (*sad[5])(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
     void (*dmvr[2][2])(int16_t *dst, const uint8_t *src, ptrdiff_t src_stride, int height,
         intptr_t mx, intptr_t my, int width);
 } VVCInterDSPContext;
diff --git a/libavcodec/vvc/inter.c b/libavcodec/vvc/inter.c
index e1011b4fa1..0214e46634 100644
--- a/libavcodec/vvc/inter.c
+++ b/libavcodec/vvc/inter.c
@@ -740,6 +740,8 @@  static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
     const AVFrame *ref0, const AVFrame *ref1, const int x_off, const int y_off, const int block_w, const int block_h)
 {
     const VVCFrameContext *fc   = lc->fc;
+    static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
+    const int tab               = sad_tab[(FFALIGN(block_w, 8) >> 3) - 1];
     const int sr_range          = 2;
     const AVFrame *ref[]        = { ref0, ref1 };
     int16_t *tmp[]              = { lc->tmp, lc->tmp1 };
@@ -763,7 +765,7 @@  static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
         fc->vvcdsp.inter.dmvr[!!my][!!mx](tmp[i], src, src_stride, pred_h, mx, my, pred_w);
     }
 
-    min_sad = fc->vvcdsp.inter.sad(tmp[L0], tmp[L1], dx, dy, block_w, block_h);
+    min_sad = fc->vvcdsp.inter.sad[tab](tmp[L0], tmp[L1], dx, dy, block_w, block_h);
     min_sad -= min_sad >> 2;
     sad[dy][dx] = min_sad;
 
@@ -773,7 +775,7 @@  static void dmvr_mv_refine(VVCLocalContext *lc, MvField *mvf, MvField *orig_mv,
         for (dy = 0; dy < SAD_ARRAY_SIZE; dy++) {
             for (dx = 0; dx < SAD_ARRAY_SIZE; dx++) {
                 if (dx != sr_range || dy != sr_range) {
-                    sad[dy][dx] = fc->vvcdsp.inter.sad(lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
+                    sad[dy][dx] = fc->vvcdsp.inter.sad[tab](lc->tmp, lc->tmp1, dx, dy, block_w, block_h);
                     if (sad[dy][dx] < min_sad) {
                         min_sad = sad[dy][dx];
                         min_dx = dx;
diff --git a/libavcodec/vvc/inter_template.c b/libavcodec/vvc/inter_template.c
index a8068f4ba8..34485321d3 100644
--- a/libavcodec/vvc/inter_template.c
+++ b/libavcodec/vvc/inter_template.c
@@ -626,7 +626,11 @@  static void FUNC(ff_vvc_inter_dsp_init)(VVCInterDSPContext *const inter)
     inter->apply_prof_uni_w     = FUNC(apply_prof_uni_w);
     inter->apply_bdof           = FUNC(apply_bdof);
     inter->prof_grad_filter     = FUNC(prof_grad_filter);
-    inter->sad                  = vvc_sad;
+    inter->sad[0]               =
+    inter->sad[1]               =
+    inter->sad[2]               =
+    inter->sad[3]               =
+    inter->sad[4]               = vvc_sad;
 }
 
 #undef FUNCS
diff --git a/libavcodec/x86/vvc/vvc_sad.asm b/libavcodec/x86/vvc/vvc_sad.asm
index b468d89ac2..a20818530f 100644
--- a/libavcodec/x86/vvc/vvc_sad.asm
+++ b/libavcodec/x86/vvc/vvc_sad.asm
@@ -51,7 +51,7 @@  SECTION .text
 
 INIT_YMM avx2
 
-cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
+cglobal vvc_sad_8, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
     movsxdifnidn    dxq, dxd
     movsxdifnidn    dyq, dyd
 
@@ -76,10 +76,6 @@  cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
     pxor               m3, m3
     vpbroadcastd       m4, [pw_1]
 
-    cmp          block_wd, 16
-    jge    vvc_sad_16_128
-
-    vvc_sad_8:
         .loop_height:
         movu              xm0, [src1q]
         vinserti128        m0, m0, [src1q + MAX_PB_SIZE * ROWS * 2], 1
@@ -100,7 +96,31 @@  cglobal vvc_sad, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_
         movd          eax, xm0
     RET
 
-    vvc_sad_16_128:
+cglobal vvc_sad_16, 6, 9, 5, src1, src2, dx, dy, block_w, block_h, off1, off2, row_idx
+    movsxdifnidn    dxq, dxd
+    movsxdifnidn    dyq, dyd
+
+    sub             dxq, 2
+    sub             dyq, 2
+
+    mov             off1q, 2
+    mov             off2q, 2
+
+    add             off1q, dyq
+    sub             off2q, dyq
+
+    shl             off1q, 7
+    shl             off2q, 7
+
+    add             off1q, dxq
+    sub             off2q, dxq
+
+    lea             src1q, [src1q + off1q * 2 + 2 * 2]
+    lea             src2q, [src2q + off2q * 2 + 2 * 2]
+
+    pxor               m3, m3
+    vpbroadcastd       m4, [pw_1]
+
         sar      block_wd, 4
         .loop_height:
         mov         off1q, src1q
diff --git a/libavcodec/x86/vvc/vvcdsp_init.c b/libavcodec/x86/vvc/vvcdsp_init.c
index 4b4a2aa937..bd60963432 100644
--- a/libavcodec/x86/vvc/vvcdsp_init.c
+++ b/libavcodec/x86/vvc/vvcdsp_init.c
@@ -312,8 +312,20 @@  ALF_FUNCS(16, 12, avx2)
     c->alf.classify       = ff_vvc_alf_classify_##bd##_avx2;         \
 } while (0)
 
-int ff_vvc_sad_avx2(const int16_t *src0, const int16_t *src1, int dx, int dy, int block_w, int block_h);
-#define SAD_INIT() c->inter.sad = ff_vvc_sad_avx2
+#define SAD_PROTOTYPE(w, opt)                                        \
+int bf(ff_vvc_sad, w, opt)(const int16_t *src0, const int16_t *src1, \
+                           int dx, int dy, int block_w, int block_h) \
+
+SAD_PROTOTYPE(8,   avx2);
+SAD_PROTOTYPE(16,  avx2);
+
+#define SAD_INIT(opt) do {                   \
+    c->inter.sad[0] = ff_vvc_sad_8_##opt;    \
+    c->inter.sad[1] =                        \
+    c->inter.sad[2] =                        \
+    c->inter.sad[3] =                        \
+    c->inter.sad[4] = ff_vvc_sad_16_##opt;   \
+} while (0)
 #endif
 
 void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
@@ -330,7 +342,7 @@  void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
             ALF_INIT(8);
             AVG_INIT(8, avx2);
             MC_LINKS_AVX2(8);
-            SAD_INIT();
+            SAD_INIT(avx2);
         }
         break;
     case 10:
@@ -342,7 +354,7 @@  void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
             AVG_INIT(10, avx2);
             MC_LINKS_AVX2(10);
             MC_LINKS_16BPC_AVX2(10);
-            SAD_INIT();
+            SAD_INIT(avx2);
         }
         break;
     case 12:
@@ -354,7 +366,7 @@  void ff_vvc_dsp_init_x86(VVCDSPContext *const c, const int bd)
             AVG_INIT(12, avx2);
             MC_LINKS_AVX2(12);
             MC_LINKS_16BPC_AVX2(12);
-            SAD_INIT();
+            SAD_INIT(avx2);
         }
         break;
     default:
diff --git a/tests/checkasm/vvc_mc.c b/tests/checkasm/vvc_mc.c
index 1e889e2cff..deae1014d2 100644
--- a/tests/checkasm/vvc_mc.c
+++ b/tests/checkasm/vvc_mc.c
@@ -327,6 +327,7 @@  static void check_avg(void)
 static void check_vvc_sad(void)
 {
     const int bit_depth = 10;
+    static const uint8_t sad_tab[16] = { 0, 1, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4 };
     VVCDSPContext c;
     LOCAL_ALIGNED_32(uint16_t, src0, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
     LOCAL_ALIGNED_32(uint16_t, src1, [MAX_CTU_SIZE * MAX_CTU_SIZE * 4]);
@@ -341,7 +342,7 @@  static void check_vvc_sad(void)
         for (int w = 8; w <= MAX_CTU_SIZE; w *= 2) {
             for(int offy = 0; offy <= 4; offy++) {
                 for(int offx = 0; offx <= 4; offx++) {
-                    if(check_func(c.inter.sad, "sad_%dx%d", w, h)) {
+                    if(check_func(c.inter.sad[sad_tab[(w >> 3) - 1]], "sad_%dx%d", w, h)) {
                         int result0;
                         int result1;