diff mbox series

[FFmpeg-devel] avfilter/vf_atadenoise: add sigma options

Message ID 20210121175233.20931-1-onemda@gmail.com
State Accepted
Commit 95183d25e8900e7c7cb507a70616def0f5d1abf3
Headers show
Series [FFmpeg-devel] avfilter/vf_atadenoise: add sigma options
Related show

Checks

Context Check Description
andriy/x86_make success Make finished
andriy/x86_make_fate success Make fate finished
andriy/PPC64_make success Make finished
andriy/PPC64_make_fate success Make fate finished

Commit Message

Paul B Mahol Jan. 21, 2021, 5:52 p.m. UTC
Signed-off-by: Paul B Mahol <onemda@gmail.com>
---
 doc/filters.texi                     |   9 ++
 libavfilter/atadenoise.h             |  10 +-
 libavfilter/vf_atadenoise.c          | 148 +++++++++++++++++++++++++--
 libavfilter/x86/vf_atadenoise_init.c |  18 ++--
 4 files changed, 162 insertions(+), 23 deletions(-)
diff mbox series

Patch

diff --git a/doc/filters.texi b/doc/filters.texi
index 3ce6699d7c..22cce1ecc2 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -7096,6 +7096,15 @@  Alternatively can be set to @code{s} serial.
 Parallel can be faster then serial, while other way around is never true.
 Parallel will abort early on first change being greater then thresholds, while serial
 will continue processing other side of frames if they are equal or below thresholds.
+
+@item 0s
+@item 1s
+@item 2s
+Set sigma for 1st plane, 2nd plane or 3rd plane. Default is 32767.
+Valid range is from 0 to 32767.
+This options controls weight for each pixel in radius defined by size.
+Default value means every pixel have same weight.
+Setting this option to 0 effectively disables filtering.
 @end table
 
 @subsection Commands
diff --git a/libavfilter/atadenoise.h b/libavfilter/atadenoise.h
index 26cb20b9c8..7d92ece0d3 100644
--- a/libavfilter/atadenoise.h
+++ b/libavfilter/atadenoise.h
@@ -31,12 +31,12 @@  enum ATAAlgorithm {
 };
 
 typedef struct ATADenoiseDSPContext {
-    void (*filter_row)(const uint8_t *src, uint8_t *dst,
-                       const uint8_t **srcf,
-                       int w, int mid, int size,
-                       int thra, int thrb);
+    void (*filter_row[4])(const uint8_t *src, uint8_t *dst,
+                          const uint8_t **srcf,
+                          int w, int mid, int size,
+                          int thra, int thrb, const float *weight);
 } ATADenoiseDSPContext;
 
-void ff_atadenoise_init_x86(ATADenoiseDSPContext *dsp, int depth, int algorithm);
+void ff_atadenoise_init_x86(ATADenoiseDSPContext *dsp, int depth, int algorithm, const float *sigma);
 
 #endif /* AVFILTER_ATADENOISE_H */
diff --git a/libavfilter/vf_atadenoise.c b/libavfilter/vf_atadenoise.c
index e1a822045f..b543665ebf 100644
--- a/libavfilter/vf_atadenoise.c
+++ b/libavfilter/vf_atadenoise.c
@@ -44,6 +44,7 @@  typedef struct ATADenoiseContext {
     const AVClass *class;
 
     float fthra[4], fthrb[4];
+    float sigma[4];
     int thra[4], thrb[4];
     int algorithm;
 
@@ -55,7 +56,8 @@  typedef struct ATADenoiseContext {
     struct FFBufQueue q;
     void *data[4][SIZE];
     int linesize[4][SIZE];
-    int size, mid;
+    float weights[4][SIZE];
+    int size, mid, radius;
     int available;
 
     int (*filter_slice)(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs);
@@ -79,6 +81,9 @@  static const AVOption atadenoise_options[] = {
     { "a",  "set variant of algorithm",      OFFSET(algorithm),AV_OPT_TYPE_INT,   {.i64=PARALLEL},  0, NB_ATAA-1, FLAGS, "a" },
     { "p",  "parallel",                      0,                AV_OPT_TYPE_CONST, {.i64=PARALLEL},  0, 0,         FLAGS, "a" },
     { "s",  "serial",                        0,                AV_OPT_TYPE_CONST, {.i64=SERIAL},    0, 0,         FLAGS, "a" },
+    { "0s", "set sigma for 1st plane",       OFFSET(sigma[0]), AV_OPT_TYPE_FLOAT, {.dbl=INT16_MAX}, 0, INT16_MAX, FLAGS },
+    { "1s", "set sigma for 2nd plane",       OFFSET(sigma[1]), AV_OPT_TYPE_FLOAT, {.dbl=INT16_MAX}, 0, INT16_MAX, FLAGS },
+    { "2s", "set sigma for 3rd plane",       OFFSET(sigma[2]), AV_OPT_TYPE_FLOAT, {.dbl=INT16_MAX}, 0, INT16_MAX, FLAGS },
     { NULL }
 };
 
@@ -129,7 +134,8 @@  static av_cold int init(AVFilterContext *ctx)
         av_log(ctx, AV_LOG_WARNING, "size %d is invalid. Must be an odd value, setting it to %d.\n", s->size, s->size|1);
         s->size |= 1;
     }
-    s->mid = s->size / 2 + 1;
+    s->radius = s->size / 2;
+    s->mid = s->radius + 1;
 
     return 0;
 }
@@ -138,11 +144,114 @@  typedef struct ThreadData {
     AVFrame *in, *out;
 } ThreadData;
 
+#define WFILTER_ROW(type, name)                                             \
+static void fweight_row##name(const uint8_t *ssrc, uint8_t *ddst,           \
+                              const uint8_t *ssrcf[SIZE],                   \
+                              int w, int mid, int size,                     \
+                              int thra, int thrb, const float *weights)     \
+{                                                                           \
+    const type *src = (const type *)ssrc;                                   \
+    const type **srcf = (const type **)ssrcf;                               \
+    type *dst = (type *)ddst;                                               \
+                                                                            \
+    for (int x = 0; x < w; x++) {                                           \
+       const int srcx = src[x];                                             \
+       unsigned lsumdiff = 0, rsumdiff = 0;                                 \
+       unsigned ldiff, rdiff;                                               \
+       float sum = srcx;                                                    \
+       float wsum = 1.f;                                                    \
+       int l = 0, r = 0;                                                    \
+       int srcjx, srcix;                                                    \
+                                                                            \
+       for (int j = mid - 1, i = mid + 1; j >= 0 && i < size; j--, i++) {   \
+           srcjx = srcf[j][x];                                              \
+                                                                            \
+           ldiff = FFABS(srcx - srcjx);                                     \
+           lsumdiff += ldiff;                                               \
+           if (ldiff > thra ||                                              \
+               lsumdiff > thrb)                                             \
+               break;                                                       \
+           l++;                                                             \
+           sum += srcjx * weights[j];                                       \
+           wsum += weights[j];                                              \
+                                                                            \
+           srcix = srcf[i][x];                                              \
+                                                                            \
+           rdiff = FFABS(srcx - srcix);                                     \
+           rsumdiff += rdiff;                                               \
+           if (rdiff > thra ||                                              \
+               rsumdiff > thrb)                                             \
+               break;                                                       \
+           r++;                                                             \
+           sum += srcix * weights[i];                                       \
+           wsum += weights[i];                                              \
+       }                                                                    \
+                                                                            \
+       dst[x] = lrintf(sum / wsum);                                         \
+   }                                                                        \
+}
+
+WFILTER_ROW(uint8_t, 8)
+WFILTER_ROW(uint16_t, 16)
+
+#define WFILTER_ROW_SERIAL(type, name)                                      \
+static void fweight_row##name##_serial(const uint8_t *ssrc, uint8_t *ddst,  \
+                                       const uint8_t *ssrcf[SIZE],          \
+                                       int w, int mid, int size,            \
+                                       int thra, int thrb,                  \
+                                       const float *weights)                \
+{                                                                           \
+    const type *src = (const type *)ssrc;                                   \
+    const type **srcf = (const type **)ssrcf;                               \
+    type *dst = (type *)ddst;                                               \
+                                                                            \
+    for (int x = 0; x < w; x++) {                                           \
+       const int srcx = src[x];                                             \
+       unsigned lsumdiff = 0, rsumdiff = 0;                                 \
+       unsigned ldiff, rdiff;                                               \
+       float sum = srcx;                                                    \
+       float wsum = 1.f;                                                    \
+       int l = 0, r = 0;                                                    \
+       int srcjx, srcix;                                                    \
+                                                                            \
+       for (int j = mid - 1; j >= 0; j--) {                                 \
+           srcjx = srcf[j][x];                                              \
+                                                                            \
+           ldiff = FFABS(srcx - srcjx);                                     \
+           lsumdiff += ldiff;                                               \
+           if (ldiff > thra ||                                              \
+               lsumdiff > thrb)                                             \
+               break;                                                       \
+           l++;                                                             \
+           sum += srcjx * weights[j];                                       \
+           wsum += weights[j];                                              \
+       }                                                                    \
+                                                                            \
+       for (int i = mid + 1; i < size; i++) {                               \
+           srcix = srcf[i][x];                                              \
+                                                                            \
+           rdiff = FFABS(srcx - srcix);                                     \
+           rsumdiff += rdiff;                                               \
+           if (rdiff > thra ||                                              \
+               rsumdiff > thrb)                                             \
+               break;                                                       \
+           r++;                                                             \
+           sum += srcix * weights[i];                                       \
+           wsum += weights[i];                                              \
+       }                                                                    \
+                                                                            \
+       dst[x] = lrintf(sum / wsum);                                         \
+   }                                                                        \
+}
+
+WFILTER_ROW_SERIAL(uint8_t, 8)
+WFILTER_ROW_SERIAL(uint16_t, 16)
+
 #define FILTER_ROW(type, name)                                              \
 static void filter_row##name(const uint8_t *ssrc, uint8_t *ddst,            \
                              const uint8_t *ssrcf[SIZE],                    \
                              int w, int mid, int size,                      \
-                             int thra, int thrb)                            \
+                             int thra, int thrb, const float *weights)      \
 {                                                                           \
     const type *src = (const type *)ssrc;                                   \
     const type **srcf = (const type **)ssrcf;                               \
@@ -189,7 +298,8 @@  FILTER_ROW(uint16_t, 16)
 static void filter_row##name##_serial(const uint8_t *ssrc, uint8_t *ddst,   \
                                       const uint8_t *ssrcf[SIZE],           \
                                       int w, int mid, int size,             \
-                                      int thra, int thrb)                   \
+                                      int thra, int thrb,                   \
+                                      const float *weights)                 \
 {                                                                           \
     const type *src = (const type *)ssrc;                                   \
     const type **srcf = (const type **)ssrcf;                               \
@@ -245,6 +355,7 @@  static int filter_slice(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
     int p, y, i;
 
     for (p = 0; p < s->nb_planes; p++) {
+        const float *weights = s->weights[p];
         const int h = s->planeheight[p];
         const int w = s->planewidth[p];
         const int slice_start = (h * jobnr) / nb_jobs;
@@ -267,7 +378,7 @@  static int filter_slice(AVFilterContext *ctx, void *arg, int jobnr, int nb_jobs)
             srcf[i] = data[i] + slice_start * linesize[i];
 
         for (y = slice_start; y < slice_end; y++) {
-            s->dsp.filter_row(src, dst, srcf, w, mid, size, thra, thrb);
+            s->dsp.filter_row[p](src, dst, srcf, w, mid, size, thra, thrb, weights);
 
             dst += out->linesize[p];
             src += in->linesize[p];
@@ -296,10 +407,17 @@  static int config_input(AVFilterLink *inlink)
 
     depth = desc->comp[0].depth;
     s->filter_slice = filter_slice;
-    if (depth == 8)
-        s->dsp.filter_row = s->algorithm == PARALLEL ? filter_row8 : filter_row8_serial;
-    else
-        s->dsp.filter_row = s->algorithm == PARALLEL ? filter_row16 : filter_row16_serial;
+
+    for (int p = 0; p < s->nb_planes; p++) {
+        if (depth == 8 && s->sigma[p] == INT16_MAX)
+            s->dsp.filter_row[p] = s->algorithm == PARALLEL ? filter_row8 : filter_row8_serial;
+        else if (s->sigma[p] == INT16_MAX)
+            s->dsp.filter_row[p] = s->algorithm == PARALLEL ? filter_row16 : filter_row16_serial;
+        else if (depth == 8 && s->sigma[p] < INT16_MAX)
+            s->dsp.filter_row[p] = s->algorithm == PARALLEL ? fweight_row8 : fweight_row8_serial;
+        else if (s->sigma[p] < INT16_MAX)
+            s->dsp.filter_row[p] = s->algorithm == PARALLEL ? fweight_row16 : fweight_row16_serial;
+    }
 
     s->thra[0] = s->fthra[0] * (1 << depth) - 1;
     s->thra[1] = s->fthra[1] * (1 << depth) - 1;
@@ -308,8 +426,18 @@  static int config_input(AVFilterLink *inlink)
     s->thrb[1] = s->fthrb[1] * (1 << depth) - 1;
     s->thrb[2] = s->fthrb[2] * (1 << depth) - 1;
 
+    for (int p = 0; p < s->nb_planes; p++) {
+        float sigma = s->radius * s->sigma[p];
+
+        s->weights[p][s->mid] = 1.f;
+        for (int n = 1; n <= s->radius; n++) {
+            s->weights[p][s->radius + n] =
+            s->weights[p][s->radius - n] = expf(-0.5 * (n + 1) * (n + 1) / (sigma * sigma));
+        }
+    }
+
     if (ARCH_X86)
-        ff_atadenoise_init_x86(&s->dsp, depth, s->algorithm);
+        ff_atadenoise_init_x86(&s->dsp, depth, s->algorithm, s->sigma);
 
     return 0;
 }
diff --git a/libavfilter/x86/vf_atadenoise_init.c b/libavfilter/x86/vf_atadenoise_init.c
index 1f69b1af3f..3f87f3c445 100644
--- a/libavfilter/x86/vf_atadenoise_init.c
+++ b/libavfilter/x86/vf_atadenoise_init.c
@@ -28,22 +28,24 @@ 
 void ff_atadenoise_filter_row8_sse4(const uint8_t *src, uint8_t *dst,
                                     const uint8_t **srcf,
                                     int w, int mid, int size,
-                                    int thra, int thrb);
+                                    int thra, int thrb, const float *weights);
 
 void ff_atadenoise_filter_row8_serial_sse4(const uint8_t *src, uint8_t *dst,
                                            const uint8_t **srcf,
                                            int w, int mid, int size,
-                                           int thra, int thrb);
+                                           int thra, int thrb, const float *weights);
 
-av_cold void ff_atadenoise_init_x86(ATADenoiseDSPContext *dsp, int depth, int algorithm)
+av_cold void ff_atadenoise_init_x86(ATADenoiseDSPContext *dsp, int depth, int algorithm, const float *sigma)
 {
     int cpu_flags = av_get_cpu_flags();
 
-    if (ARCH_X86_64 && EXTERNAL_SSE4(cpu_flags) && depth <= 8 && algorithm == PARALLEL) {
-        dsp->filter_row = ff_atadenoise_filter_row8_sse4;
-    }
+    for (int p = 0; p < 4; p++) {
+        if (ARCH_X86_64 && EXTERNAL_SSE4(cpu_flags) && depth <= 8 && algorithm == PARALLEL && sigma[p] == INT16_MAX) {
+            dsp->filter_row[p] = ff_atadenoise_filter_row8_sse4;
+        }
 
-    if (ARCH_X86_64 && EXTERNAL_SSE4(cpu_flags) && depth <= 8 && algorithm == SERIAL) {
-        dsp->filter_row = ff_atadenoise_filter_row8_serial_sse4;
+        if (ARCH_X86_64 && EXTERNAL_SSE4(cpu_flags) && depth <= 8 && algorithm == SERIAL && sigma[p] == INT16_MAX) {
+            dsp->filter_row[p] = ff_atadenoise_filter_row8_serial_sse4;
+        }
     }
 }