diff mbox series

[FFmpeg-devel,3/3,GSoC] Add x86-avx2 optimization for dnn_execute_layer_conv2d

Message ID 20200831170341.879003-3-xujunzz@sjtu.edu.cn
State New
Headers show
Series [FFmpeg-devel,1/3,GSoC] Add mutithread function for dnn_backend_native_layer_conv2d.c
Related show

Checks

Context Check Description
andriy/default pending
andriy/make success Make finished
andriy/make_fate success Make fate finished

Commit Message

Xu Jun Aug. 31, 2020, 5:03 p.m. UTC
From: Xu Jun <xujunzz@sjtu.edu.cn>

Can be tested with command "./ffmpeg_g -i test_1s.mp4 -vf \
format=yuvj420p,dnn_processing=dnn_backend=native:model= \
espcn.model:input=x:output=y -y sr_native.mp4 -benchmark"

before patch: utime=826.044s stime=0.550s rtime=39.680s
after patch:  utime=545.137s stime=0.467s rtime=27.113s

Signed-off-by: Xu Jun <xujunzz@sjtu.edu.cn>
---
 .../dnn/dnn_backend_native_layer_conv2d.c     |  10 +-
 .../dnn_backend_native_layer_conv2d_x86.asm   | 121 ++++++++++++++++++
 2 files changed, 130 insertions(+), 1 deletion(-)
diff mbox series

Patch

diff --git a/libavfilter/dnn/dnn_backend_native_layer_conv2d.c b/libavfilter/dnn/dnn_backend_native_layer_conv2d.c
index 92cc5313dc..089f724156 100644
--- a/libavfilter/dnn/dnn_backend_native_layer_conv2d.c
+++ b/libavfilter/dnn/dnn_backend_native_layer_conv2d.c
@@ -46,6 +46,7 @@  typedef struct execute_data{
     float *kernel;
 } execute_data;
 
+void ff_dnn_execute_layer_conv2d_avx2(execute_data *execute_data);
 void ff_dnn_execute_layer_conv2d_sse4(execute_data *execute_data);
 void ff_dnn_execute_layer_conv2d_c(execute_data *execute_data);
 
@@ -243,7 +244,12 @@  static void * dnn_execute_layer_conv2d_thread(void *threadarg)
     execute_data->filter_size = filter_size;
     execute_data->filter_linesize = filter_linesize;
     if ((thread_data->step >= 4) && (conv_params->input_num >= 4)) {
-        ff_dnn_execute_layer_conv2d_sse4(execute_data);
+        if ((thread_data->step == 8) && (conv_params->input_num >= 8)) {
+            ff_dnn_execute_layer_conv2d_avx2(execute_data);
+        }
+        else {
+            ff_dnn_execute_layer_conv2d_sse4(execute_data);
+        }
     }
     else {
         ff_dnn_execute_layer_conv2d_c(execute_data);
@@ -305,6 +311,8 @@  int dnn_execute_layer_conv2d(DnnOperand *operands, const int32_t *input_operand_
         int cpu_flags = av_get_cpu_flags();
         if (EXTERNAL_SSE4(cpu_flags))
             thread_data->step = 4;
+        if (EXTERNAL_AVX2(cpu_flags))
+            thread_data->step = 8;
     #endif
 
     //create threads
diff --git a/libavfilter/dnn/dnn_backend_native_layer_conv2d_x86.asm b/libavfilter/dnn/dnn_backend_native_layer_conv2d_x86.asm
index dc781d42e5..7c7285c4c5 100644
--- a/libavfilter/dnn/dnn_backend_native_layer_conv2d_x86.asm
+++ b/libavfilter/dnn/dnn_backend_native_layer_conv2d_x86.asm
@@ -210,5 +210,126 @@  cglobal dnn_execute_layer_conv2d, 8, 15, 3, execute_data,\
     cmp yd, tmp1d
     jl .loop_y
 
+    RET
+
+; void ff_dnn_execute_layer_conv2d_avx4(execute_data *execute_data);
+
+INIT_YMM avx2
+cglobal dnn_execute_layer_conv2d, 8, 15, 3, execute_data,\
+    x, y, n_filter, cha, kernel_x, kernel_y, x_pos, y_pos, kernel_pos,\
+    input, output, kernel, tmp1, tmp2
+
+%define thread_start [execute_dataq]
+%define thread_end [execute_dataq + 1 * 4]
+%define input_num [execute_dataq + 2 * 4]
+%define output_num [execute_dataq + 3 * 4]
+%define kernel_size [execute_dataq + 4 * 4]
+%define padding_method [execute_dataq + 5 * 4]
+%define dilation [execute_dataq + 6 * 4]
+%define pad_size [execute_dataq + 7 * 4]
+%define width [execute_dataq + 8 * 4]
+%define height [execute_dataq + 9 * 4]
+%define radius [execute_dataq + 10 * 4]
+%define src_linesize [execute_dataq + 11 * 4]
+%define filter_size [execute_dataq + 12 * 4]
+%define filter_linesize [execute_dataq + 13 * 4]
+%define SAME_CLAMP_TO_EDGE 2
+
+    mov inputq, [execute_dataq + 14 * 4]
+    mov outputq, [execute_dataq + 14 * 4 + 8]
+    mov kernelq, [execute_dataq + 14 * 4 + 2 * 8]
+
+    mov yd, thread_start
+.loop_y:
+    mov xd, pad_size
+    .loop_x:
+        xor n_filterd, n_filterd
+        xor kernel_posq, kernel_posq
+        .loop_filter:
+            xorps m2, m2
+            xor kernel_yd, kernel_yd
+
+            mov tmp1d, kernel_yd
+            sub tmp1d, radius
+            mov y_posd, dilation
+            imul y_posd, tmp1d
+            add y_posd, yd
+
+            .loop_kery:
+                xor kernel_xd, kernel_xd
+
+                mov tmp1d, kernel_xd
+                sub tmp1d, radius
+                mov x_posd, dilation
+                imul x_posd, tmp1d
+                add x_posd, xd
+
+                .loop_kerx:
+                    COUNT_INPUT
+                    xor chad, chad
+                    .loop_ch:
+                        cmp tmp1d, -1
+                        je .out
+
+                        movsxdifnidn tmp1q, tmp1d
+                        movups m0, [inputq + tmp1q * 4]
+                        add tmp1d, 8
+                        jmp .load_end
+
+                        .out:
+                        xorps m0, m0
+
+                        .load_end:
+
+                        movups m1, [kernelq + kernel_posq * 4]
+                        add kernel_posq, 8
+
+                        mulps m0, m1
+                        addps m2, m0
+
+                        add chad, 8
+                        mov tmp2d, input_num
+                        cmp chad, tmp2d
+                        jl .loop_ch
+
+                    add x_posd, dilation
+                    add kernel_xd, 1
+                    mov tmp1d, kernel_size
+                    cmp kernel_xd, tmp1d
+                    jl .loop_kerx
+
+                add y_posd, dilation
+                add kernel_yd, 1
+                mov tmp1d, kernel_size
+                cmp kernel_yd, tmp1d
+                jl .loop_kery
+
+            vperm2f128 m1, m2, m2, 1
+            addps m2, m1
+            haddps m2, m2
+            haddps m2, m2
+            movsxdifnidn n_filterq, n_filterd
+            movss [outputq + n_filterq * 4], xm2
+
+            add n_filterd, 1
+            mov tmp1d, output_num
+            cmp n_filterd, tmp1d
+            jl .loop_filter
+
+        mov tmp1d, output_num
+        movsxdifnidn tmp1q, tmp1d
+        shl tmp1d, 2
+        add outputq, tmp1q
+        add xd, 1
+        mov tmp2d, width
+        sub tmp2d, pad_size
+        cmp xd, tmp2d
+        jl .loop_x
+
+    add yd, 1
+    mov tmp1d, thread_end
+    cmp yd, tmp1d
+    jl .loop_y
+
     RET
 %endif