diff mbox series

[FFmpeg-devel,v2,1/2] avfilter/palettegen, paletteuse: Extend the palette conversion filters to support palettes with alpha

Message ID MN2PR04MB598196C2F4246B36399EDED0BAA69@MN2PR04MB5981.namprd04.prod.outlook.com
State New
Headers show
Series [FFmpeg-devel,v2,1/2] avfilter/palettegen, paletteuse: Extend the palette conversion filters to support palettes with alpha
Related show

Checks

Context Check Description
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished
andriy/makex86 warning New warnings during build
andriy/make_ppc success Make finished
andriy/make_fate_ppc success Make fate finished

Commit Message

Soft Works Sept. 26, 2021, 5:22 p.m. UTC
Usage example:

ffmpeg -y -loglevel verbose -i "..\fate-suite\apng\o_sample.png" -filter_complex "split[split1][split2];[split1]palettegen=max_colors=254:use_alpha=1[pal1];[split2][pal1]paletteuse=use_alpha=1" -frames:v 1 out.png

Signed-off-by: softworkz <softworkz@hotmail.com>
---
 doc/filters.texi            |   8 ++
 libavfilter/vf_palettegen.c | 136 +++++++++++++++-------
 libavfilter/vf_paletteuse.c | 225 +++++++++++++++++++++---------------
 3 files changed, 231 insertions(+), 138 deletions(-)

Comments

Michael Niedermayer Sept. 27, 2021, 6:16 p.m. UTC | #1
On Sun, Sep 26, 2021 at 05:22:59PM +0000, Soft Works wrote:
> Usage example:
> 
> ffmpeg -y -loglevel verbose -i "..\fate-suite\apng\o_sample.png" -filter_complex "split[split1][split2];[split1]palettegen=max_colors=254:use_alpha=1[pal1];[split2][pal1]paletteuse=use_alpha=1" -frames:v 1 out.png
> 
> Signed-off-by: softworkz <softworkz@hotmail.com>
> ---
>  doc/filters.texi            |   8 ++
>  libavfilter/vf_palettegen.c | 136 +++++++++++++++-------
>  libavfilter/vf_paletteuse.c | 225 +++++++++++++++++++++---------------
>  3 files changed, 231 insertions(+), 138 deletions(-)

patchset LGTM

thx

[...]
diff mbox series

Patch

diff --git a/doc/filters.texi b/doc/filters.texi
index 36113e5c4b..7e4806235c 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -16454,6 +16454,9 @@  Compute new histogram for each frame.
 @end table
 
 Default value is @var{full}.
+@item use_alpha
+Create a palette of colors with alpha components.
+Setting this, will automatically disable 'reserve_transparent'.
 @end table
 
 The filter also exports the frame metadata @code{lavfi.color_quant_ratio}
@@ -16532,6 +16535,11 @@  will be treated as completely opaque, and values below this threshold will be
 treated as completely transparent.
 
 The option must be an integer value in the range [0,255]. Default is @var{128}.
+
+@item use_alpha
+Apply the palette by taking alpha values into account. Only useful with 
+palettes that are containing multiple colors with alpha components.
+Setting this will automatically disable 'alpha_treshold'.
 @end table
 
 @subsection Examples
diff --git a/libavfilter/vf_palettegen.c b/libavfilter/vf_palettegen.c
index 4c2fbd36d7..98dff46fe0 100644
--- a/libavfilter/vf_palettegen.c
+++ b/libavfilter/vf_palettegen.c
@@ -59,7 +59,7 @@  enum {
 };
 
 #define NBITS 5
-#define HIST_SIZE (1<<(3*NBITS))
+#define HIST_SIZE (1<<(4*NBITS))
 
 typedef struct PaletteGenContext {
     const AVClass *class;
@@ -67,6 +67,7 @@  typedef struct PaletteGenContext {
     int max_colors;
     int reserve_transparent;
     int stats_mode;
+    int use_alpha;
 
     AVFrame *prev_frame;                    // previous frame used for the diff stats_mode
     struct hist_node histogram[HIST_SIZE];  // histogram/hashtable of the colors
@@ -88,6 +89,7 @@  static const AVOption palettegen_options[] = {
         { "full", "compute full frame histograms", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_ALL_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
         { "diff", "compute histograms only for the part that differs from previous frame", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_DIFF_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
         { "single", "compute new histogram for each frame", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_SINGLE_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
+    { "use_alpha", "create a palette including alpha values", OFFSET(use_alpha), AV_OPT_TYPE_BOOL, {.i64 = 0}, 0, 1, FLAGS },
     { NULL }
 };
 
@@ -113,15 +115,16 @@  static int cmp_##name(const void *pa, const void *pb)   \
 {                                                       \
     const struct color_ref * const *a = pa;             \
     const struct color_ref * const *b = pb;             \
-    return   (int)((*a)->color >> (8 * (2 - (pos))) & 0xff)  \
-           - (int)((*b)->color >> (8 * (2 - (pos))) & 0xff); \
+    return   (int)((*a)->color >> (8 * (3 - (pos))) & 0xff)  \
+           - (int)((*b)->color >> (8 * (3 - (pos))) & 0xff); \
 }
 
-DECLARE_CMP_FUNC(r, 0)
-DECLARE_CMP_FUNC(g, 1)
-DECLARE_CMP_FUNC(b, 2)
+DECLARE_CMP_FUNC(a, 0)
+DECLARE_CMP_FUNC(r, 1)
+DECLARE_CMP_FUNC(g, 2)
+DECLARE_CMP_FUNC(b, 3)
 
-static const cmp_func cmp_funcs[] = {cmp_r, cmp_g, cmp_b};
+static const cmp_func cmp_funcs[] = {cmp_a, cmp_r, cmp_g, cmp_b};
 
 /**
  * Simple color comparison for sorting the final palette
@@ -143,6 +146,17 @@  static av_always_inline int diff(const uint32_t a, const uint32_t b)
     return dr*dr + dg*dg + db*db;
 }
 
+static av_always_inline int diff_alpha(const uint32_t a, const uint32_t b)
+{
+    const uint8_t c1[] = {a >> 24 & 0xff, a >> 16 & 0xff, a >> 8 & 0xff, a & 0xff};
+    const uint8_t c2[] = {b >> 24 & 0xff, b >> 16 & 0xff, b >> 8 & 0xff, b & 0xff};
+    const int da = c1[0] - c2[0];
+    const int dr = c1[1] - c2[1];
+    const int dg = c1[2] - c2[2];
+    const int db = c1[3] - c2[3];
+    return da*da + dr*dr + dg*dg + db*db;
+}
+
 /**
  * Find the next box to split: pick the one with the highest variance
  */
@@ -164,7 +178,10 @@  static int get_next_box_id_to_split(PaletteGenContext *s)
 
                 for (i = 0; i < box->len; i++) {
                     const struct color_ref *ref = s->refs[box->start + i];
-                    variance += diff(ref->color, box->color) * ref->count;
+                    if (s->use_alpha)
+                        variance += (int64_t)diff_alpha(ref->color, box->color) * ref->count;
+                    else
+                        variance += (int64_t)diff(ref->color, box->color) * ref->count;
                 }
                 box->variance = variance;
             }
@@ -184,24 +201,31 @@  static int get_next_box_id_to_split(PaletteGenContext *s)
  * specified box. Takes into account the weight of each color.
  */
 static uint32_t get_avg_color(struct color_ref * const *refs,
-                              const struct range_box *box)
+                              const struct range_box *box, int use_alpha)
 {
     int i;
     const int n = box->len;
-    uint64_t r = 0, g = 0, b = 0, div = 0;
+    uint64_t a = 0, r = 0, g = 0, b = 0, div = 0;
 
     for (i = 0; i < n; i++) {
         const struct color_ref *ref = refs[box->start + i];
-        r += (ref->color >> 16 & 0xff) * ref->count;
-        g += (ref->color >>  8 & 0xff) * ref->count;
-        b += (ref->color       & 0xff) * ref->count;
+        if (use_alpha)
+            a += (ref->color >> 24 & 0xff) * ref->count;
+        r += (ref->color     >> 16 & 0xff) * ref->count;
+        g += (ref->color     >>  8 & 0xff) * ref->count;
+        b += (ref->color           & 0xff) * ref->count;
         div += ref->count;
     }
 
+    if (use_alpha)
+        a = a / div;
     r = r / div;
     g = g / div;
     b = b / div;
 
+    if (use_alpha)
+        return a<<24 | r<<16 | g<<8 | b;
+
     return 0xffU<<24 | r<<16 | g<<8 | b;
 }
 
@@ -220,8 +244,8 @@  static void split_box(PaletteGenContext *s, struct range_box *box, int n)
     av_assert0(box->len     >= 1);
     av_assert0(new_box->len >= 1);
 
-    box->color     = get_avg_color(s->refs, box);
-    new_box->color = get_avg_color(s->refs, new_box);
+    box->color     = get_avg_color(s->refs, box, s->use_alpha);
+    new_box->color = get_avg_color(s->refs, new_box, s->use_alpha);
     box->variance     = -1;
     new_box->variance = -1;
 }
@@ -251,7 +275,7 @@  static void write_palette(AVFilterContext *ctx, AVFrame *out)
         pal += pal_linesize;
     }
 
-    if (s->reserve_transparent) {
+    if (s->reserve_transparent && !s->use_alpha) {
         av_assert0(s->nb_boxes < 256);
         pal[out->width - pal_linesize - 1] = AV_RB32(&s->transparency_color) >> 8;
     }
@@ -319,40 +343,49 @@  static AVFrame *get_palette_frame(AVFilterContext *ctx)
     box = &s->boxes[box_id];
     box->len = s->nb_refs;
     box->sorted_by = -1;
-    box->color = get_avg_color(s->refs, box);
+    box->color = get_avg_color(s->refs, box, s->use_alpha);
     box->variance = -1;
     s->nb_boxes = 1;
 
     while (box && box->len > 1) {
-        int i, rr, gr, br, longest;
+        int i, ar, rr, gr, br, longest;
         uint64_t median, box_weight = 0;
 
         /* compute the box weight (sum all the weights of the colors in the
          * range) and its boundings */
-        uint8_t min[3] = {0xff, 0xff, 0xff};
-        uint8_t max[3] = {0x00, 0x00, 0x00};
+        uint8_t min[4] = {0xff, 0xff, 0xff, 0xff};
+        uint8_t max[4] = {0x00, 0x00, 0x00, 0x00};
         for (i = box->start; i < box->start + box->len; i++) {
             const struct color_ref *ref = s->refs[i];
             const uint32_t rgb = ref->color;
-            const uint8_t r = rgb >> 16 & 0xff, g = rgb >> 8 & 0xff, b = rgb & 0xff;
-            min[0] = FFMIN(r, min[0]), max[0] = FFMAX(r, max[0]);
-            min[1] = FFMIN(g, min[1]), max[1] = FFMAX(g, max[1]);
-            min[2] = FFMIN(b, min[2]), max[2] = FFMAX(b, max[2]);
+            const uint8_t a = rgb >> 24 & 0xff, r = rgb >> 16 & 0xff, g = rgb >> 8 & 0xff, b = rgb & 0xff;
+            min[0] = FFMIN(a, min[0]); max[0] = FFMAX(a, max[0]);
+            min[1] = FFMIN(r, min[1]); max[1] = FFMAX(r, max[1]);
+            min[2] = FFMIN(g, min[2]); max[2] = FFMAX(g, max[2]);
+            min[3] = FFMIN(b, min[3]); max[3] = FFMAX(b, max[3]);
             box_weight += ref->count;
         }
 
         /* define the axis to sort by according to the widest range of colors */
-        rr = max[0] - min[0];
-        gr = max[1] - min[1];
-        br = max[2] - min[2];
-        longest = 1; // pick green by default (the color the eye is the most sensitive to)
-        if (br >= rr && br >= gr) longest = 2;
-        if (rr >= gr && rr >= br) longest = 0;
-        if (gr >= rr && gr >= br) longest = 1; // prefer green again
-
-        ff_dlog(ctx, "box #%02X [%6d..%-6d] (%6d) w:%-6"PRIu64" ranges:[%2x %2x %2x] sort by %c (already sorted:%c) ",
+        ar = max[0] - min[0];
+        rr = max[1] - min[1];
+        gr = max[2] - min[2];
+        br = max[3] - min[3];
+        longest = 2; // pick green by default (the color the eye is the most sensitive to)
+        if (s->use_alpha) {
+            if (ar >= rr && ar >= br && ar >= gr) longest = 0;
+            if (br >= rr && br >= gr && br >= ar) longest = 3;
+            if (rr >= gr && rr >= br && rr >= ar) longest = 1;
+            if (gr >= rr && gr >= br && gr >= ar) longest = 2; // prefer green again
+        } else {
+            if (br >= rr && br >= gr) longest = 3;
+            if (rr >= gr && rr >= br) longest = 1;
+            if (gr >= rr && gr >= br) longest = 2; // prefer green again
+        }
+
+        ff_dlog(ctx, "box #%02X [%6d..%-6d] (%6d) w:%-6"PRIu64" ranges:[%2x %2x %2x %2x] sort by %c (already sorted:%c) ",
                 box_id, box->start, box->start + box->len - 1, box->len, box_weight,
-                rr, gr, br, "rgb"[longest], box->sorted_by == longest ? 'y':'n');
+                ar, rr, gr, br, "argb"[longest], box->sorted_by == longest ? 'y' : 'n');
 
         /* sort the range by its longest axis if it's not already sorted */
         if (box->sorted_by != longest) {
@@ -394,21 +427,27 @@  static AVFrame *get_palette_frame(AVFilterContext *ctx)
  * It keeps the NBITS least significant bit of each component to make it
  * "random" even if the scene doesn't have much different colors.
  */
-static inline unsigned color_hash(uint32_t color)
+static inline unsigned color_hash(uint32_t color, int use_alpha)
 {
     const uint8_t r = color >> 16 & ((1<<NBITS)-1);
     const uint8_t g = color >>  8 & ((1<<NBITS)-1);
     const uint8_t b = color       & ((1<<NBITS)-1);
+
+    if (use_alpha) {
+        const uint8_t a = color >> 24 & ((1 << NBITS) - 1);
+        return a << (NBITS * 3) | r << (NBITS * 2) | g << NBITS | b;
+    }
+
     return r<<(NBITS*2) | g<<NBITS | b;
 }
 
 /**
  * Locate the color in the hash table and increment its counter.
  */
-static int color_inc(struct hist_node *hist, uint32_t color)
+static int color_inc(struct hist_node *hist, uint32_t color, int use_alpha)
 {
     int i;
-    const unsigned hash = color_hash(color);
+    const unsigned hash = color_hash(color, use_alpha);
     struct hist_node *node = &hist[hash];
     struct color_ref *e;
 
@@ -433,7 +472,7 @@  static int color_inc(struct hist_node *hist, uint32_t color)
  * Update histogram when pixels differ from previous frame.
  */
 static int update_histogram_diff(struct hist_node *hist,
-                                 const AVFrame *f1, const AVFrame *f2)
+                                 const AVFrame *f1, const AVFrame *f2, int use_alpha)
 {
     int x, y, ret, nb_diff_colors = 0;
 
@@ -444,7 +483,7 @@  static int update_histogram_diff(struct hist_node *hist,
         for (x = 0; x < f1->width; x++) {
             if (p[x] == q[x])
                 continue;
-            ret = color_inc(hist, p[x]);
+            ret = color_inc(hist, p[x], use_alpha);
             if (ret < 0)
                 return ret;
             nb_diff_colors += ret;
@@ -456,7 +495,7 @@  static int update_histogram_diff(struct hist_node *hist,
 /**
  * Simple histogram of the frame.
  */
-static int update_histogram_frame(struct hist_node *hist, const AVFrame *f)
+static int update_histogram_frame(struct hist_node *hist, const AVFrame *f, int use_alpha)
 {
     int x, y, ret, nb_diff_colors = 0;
 
@@ -464,7 +503,7 @@  static int update_histogram_frame(struct hist_node *hist, const AVFrame *f)
         const uint32_t *p = (const uint32_t *)(f->data[0] + y*f->linesize[0]);
 
         for (x = 0; x < f->width; x++) {
-            ret = color_inc(hist, p[x]);
+            ret = color_inc(hist, p[x], use_alpha);
             if (ret < 0)
                 return ret;
             nb_diff_colors += ret;
@@ -480,8 +519,8 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 {
     AVFilterContext *ctx = inlink->dst;
     PaletteGenContext *s = ctx->priv;
-    int ret = s->prev_frame ? update_histogram_diff(s->histogram, s->prev_frame, in)
-                            : update_histogram_frame(s->histogram, in);
+    int ret = s->prev_frame ? update_histogram_diff(s->histogram, s->prev_frame, in, s->use_alpha)
+                            : update_histogram_frame(s->histogram, in, s->use_alpha);
 
     if (ret > 0)
         s->nb_refs += ret;
@@ -540,6 +579,16 @@  static int config_output(AVFilterLink *outlink)
     return 0;
 }
 
+static int init(AVFilterContext *ctx)
+{
+    PaletteGenContext* s = ctx->priv;
+
+    if (s->use_alpha && s->reserve_transparent)
+        s->reserve_transparent = 0;
+
+    return 0;
+}
+
 static av_cold void uninit(AVFilterContext *ctx)
 {
     int i;
@@ -572,6 +621,7 @@  const AVFilter ff_vf_palettegen = {
     .name          = "palettegen",
     .description   = NULL_IF_CONFIG_SMALL("Find the optimal palette for a given stream."),
     .priv_size     = sizeof(PaletteGenContext),
+    .init        = init,
     .uninit        = uninit,
     .query_formats = query_formats,
     FILTER_INPUTS(palettegen_inputs),
diff --git a/libavfilter/vf_paletteuse.c b/libavfilter/vf_paletteuse.c
index f9bc28f7d0..2ac30b4e5d 100644
--- a/libavfilter/vf_paletteuse.c
+++ b/libavfilter/vf_paletteuse.c
@@ -28,7 +28,6 @@ 
 #include "libavutil/opt.h"
 #include "libavutil/qsort.h"
 #include "avfilter.h"
-#include "filters.h"
 #include "framesync.h"
 #include "internal.h"
 
@@ -63,7 +62,7 @@  struct color_node {
 };
 
 #define NBITS 5
-#define CACHE_SIZE (1<<(3*NBITS))
+#define CACHE_SIZE (1<<(4*NBITS))
 
 struct cached_color {
     uint32_t color;
@@ -88,6 +87,7 @@  typedef struct PaletteUseContext {
     uint32_t palette[AVPALETTE_COUNT];
     int transparency_index; /* index in the palette of transparency. -1 if there is no transparency in the palette. */
     int trans_thresh;
+    int use_alpha;
     int palette_loaded;
     int dither;
     int new;
@@ -107,7 +107,7 @@  typedef struct PaletteUseContext {
 } PaletteUseContext;
 
 #define OFFSET(x) offsetof(PaletteUseContext, x)
-#define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM
+#define FLAGS (AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM)
 static const AVOption paletteuse_options[] = {
     { "dither", "select dithering mode", OFFSET(dither), AV_OPT_TYPE_INT, {.i64=DITHERING_SIERRA2_4A}, 0, NB_DITHERING-1, FLAGS, "dithering_mode" },
         { "bayer",           "ordered 8x8 bayer dithering (deterministic)",                            0, AV_OPT_TYPE_CONST, {.i64=DITHERING_BAYER},           INT_MIN, INT_MAX, FLAGS, "dithering_mode" },
@@ -120,6 +120,7 @@  static const AVOption paletteuse_options[] = {
         { "rectangle", "process smallest different rectangle", 0, AV_OPT_TYPE_CONST, {.i64=DIFF_MODE_RECTANGLE}, INT_MIN, INT_MAX, FLAGS, "diff_mode" },
     { "new", "take new palette for each output frame", OFFSET(new), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS },
     { "alpha_threshold", "set the alpha threshold for transparency", OFFSET(trans_thresh), AV_OPT_TYPE_INT, {.i64=128}, 0, 255, FLAGS },
+    { "use_alpha", "use alpha channel for mapping", OFFSET(use_alpha), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS },
 
     /* following are the debug options, not part of the official API */
     { "debug_kdtree", "save Graphviz graph of the kdtree in specified file", OFFSET(dot_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
@@ -161,37 +162,41 @@  static av_always_inline uint32_t dither_color(uint32_t px, int er, int eg,
          | av_clip_uint8((px       & 0xff) + ((eb * scale) / (1<<shift)));
 }
 
-static av_always_inline int diff(const uint8_t *c1, const uint8_t *c2, const int trans_thresh)
+static av_always_inline int diff(const uint8_t *c1, const uint8_t *c2, const PaletteUseContext *s)
 {
     // XXX: try L*a*b with CIE76 (dL*dL + da*da + db*db)
+    const int da = c1[0] - c2[0];
     const int dr = c1[1] - c2[1];
     const int dg = c1[2] - c2[2];
     const int db = c1[3] - c2[3];
 
-    if (c1[0] < trans_thresh && c2[0] < trans_thresh) {
+    if (s->use_alpha)
+        return da*da + dr*dr + dg*dg + db*db;
+
+    if (c1[0] < s->trans_thresh && c2[0] < s->trans_thresh) {
         return 0;
-    } else if (c1[0] >= trans_thresh && c2[0] >= trans_thresh) {
+    } else if (c1[0] >= s->trans_thresh && c2[0] >= s->trans_thresh) {
         return dr*dr + dg*dg + db*db;
     } else {
         return 255*255 + 255*255 + 255*255;
     }
 }
 
-static av_always_inline uint8_t colormap_nearest_bruteforce(const uint32_t *palette, const uint8_t *argb, const int trans_thresh)
+static av_always_inline uint8_t colormap_nearest_bruteforce(const PaletteUseContext *s, const uint8_t *argb)
 {
     int i, pal_id = -1, min_dist = INT_MAX;
 
     for (i = 0; i < AVPALETTE_COUNT; i++) {
-        const uint32_t c = palette[i];
+        const uint32_t c = s->palette[i];
 
-        if (c >> 24 >= trans_thresh) { // ignore transparent entry
+        if (s->use_alpha || c >> 24 >= s->trans_thresh) { // ignore transparent entry
             const uint8_t palargb[] = {
-                palette[i]>>24 & 0xff,
-                palette[i]>>16 & 0xff,
-                palette[i]>> 8 & 0xff,
-                palette[i]     & 0xff,
+                s->palette[i]>>24 & 0xff,
+                s->palette[i]>>16 & 0xff,
+                s->palette[i]>> 8 & 0xff,
+                s->palette[i]     & 0xff,
             };
-            const int d = diff(palargb, argb, trans_thresh);
+            const int d = diff(palargb, argb, s);
             if (d < min_dist) {
                 pal_id = i;
                 min_dist = d;
@@ -207,17 +212,17 @@  struct nearest_color {
     int dist_sqd;
 };
 
-static void colormap_nearest_node(const struct color_node *map,
+static void colormap_nearest_node(const PaletteUseContext *s,
+                                  const struct color_node *map,
                                   const int node_pos,
                                   const uint8_t *target,
-                                  const int trans_thresh,
                                   struct nearest_color *nearest)
 {
     const struct color_node *kd = map + node_pos;
-    const int s = kd->split;
+    const int split = kd->split;
     int dx, nearer_kd_id, further_kd_id;
     const uint8_t *current = kd->val;
-    const int current_to_target = diff(target, current, trans_thresh);
+    const int current_to_target = diff(target, current, s);
 
     if (current_to_target < nearest->dist_sqd) {
         nearest->node_pos = node_pos;
@@ -225,23 +230,23 @@  static void colormap_nearest_node(const struct color_node *map,
     }
 
     if (kd->left_id != -1 || kd->right_id != -1) {
-        dx = target[s] - current[s];
+        dx = target[split] - current[split];
 
         if (dx <= 0) nearer_kd_id = kd->left_id,  further_kd_id = kd->right_id;
         else         nearer_kd_id = kd->right_id, further_kd_id = kd->left_id;
 
         if (nearer_kd_id != -1)
-            colormap_nearest_node(map, nearer_kd_id, target, trans_thresh, nearest);
+            colormap_nearest_node(s, map, nearer_kd_id, target, nearest);
 
         if (further_kd_id != -1 && dx*dx < nearest->dist_sqd)
-            colormap_nearest_node(map, further_kd_id, target, trans_thresh, nearest);
+            colormap_nearest_node(s, map, further_kd_id, target, nearest);
     }
 }
 
-static av_always_inline uint8_t colormap_nearest_recursive(const struct color_node *node, const uint8_t *rgb, const int trans_thresh)
+static av_always_inline uint8_t colormap_nearest_recursive(const PaletteUseContext *s, const struct color_node *node, const uint8_t *rgb)
 {
     struct nearest_color res = {.dist_sqd = INT_MAX, .node_pos = -1};
-    colormap_nearest_node(node, 0, rgb, trans_thresh, &res);
+    colormap_nearest_node(s, node, 0, rgb, &res);
     return node[res.node_pos].palette_id;
 }
 
@@ -250,7 +255,7 @@  struct stack_node {
     int dx2;
 };
 
-static av_always_inline uint8_t colormap_nearest_iterative(const struct color_node *root, const uint8_t *target, const int trans_thresh)
+static av_always_inline uint8_t colormap_nearest_iterative(const PaletteUseContext *s, const struct color_node *root, const uint8_t *target)
 {
     int pos = 0, best_node_id = -1, best_dist = INT_MAX, cur_color_id = 0;
     struct stack_node nodes[16];
@@ -260,7 +265,7 @@  static av_always_inline uint8_t colormap_nearest_iterative(const struct color_no
 
         const struct color_node *kd = &root[cur_color_id];
         const uint8_t *current = kd->val;
-        const int current_to_target = diff(target, current, trans_thresh);
+        const int current_to_target = diff(target, current, s);
 
         /* Compare current color node to the target and update our best node if
          * it's actually better. */
@@ -322,10 +327,10 @@  end:
     return root[best_node_id].palette_id;
 }
 
-#define COLORMAP_NEAREST(search, palette, root, target, trans_thresh)                                    \
-    search == COLOR_SEARCH_NNS_ITERATIVE ? colormap_nearest_iterative(root, target, trans_thresh) :      \
-    search == COLOR_SEARCH_NNS_RECURSIVE ? colormap_nearest_recursive(root, target, trans_thresh) :      \
-                                           colormap_nearest_bruteforce(palette, target, trans_thresh)
+#define COLORMAP_NEAREST(s, search, root, target)                                    \
+    search == COLOR_SEARCH_NNS_ITERATIVE ? colormap_nearest_iterative(s, root, target) :      \
+    search == COLOR_SEARCH_NNS_RECURSIVE ? colormap_nearest_recursive(s, root, target) :      \
+                                           colormap_nearest_bruteforce(s, target)
 
 /**
  * Check if the requested color is in the cache already. If not, find it in the
@@ -362,13 +367,13 @@  static av_always_inline int color_get(PaletteUseContext *s, uint32_t color,
     if (!e)
         return AVERROR(ENOMEM);
     e->color = color;
-    e->pal_entry = COLORMAP_NEAREST(search_method, s->palette, s->map, argb_elts, s->trans_thresh);
+    e->pal_entry = COLORMAP_NEAREST(s, search_method, s->map, argb_elts);
 
     return e->pal_entry;
 }
 
 static av_always_inline int get_dst_color_err(PaletteUseContext *s,
-                                              uint32_t c, int *er, int *eg, int *eb,
+                                              uint32_t c, int *ea, int *er, int *eg, int *eb,
                                               const enum color_search_method search_method)
 {
     const uint8_t a = c >> 24 & 0xff;
@@ -381,8 +386,9 @@  static av_always_inline int get_dst_color_err(PaletteUseContext *s,
         return dstx;
     dstc = s->palette[dstx];
     if (dstx == s->transparency_index) {
-        *er = *eg = *eb = 0;
+        *ea =*er = *eg = *eb = 0;
     } else {
+        *ea = (int)a - (int)(dstc >> 24 & 0xff);
         *er = (int)r - (int)(dstc >> 16 & 0xff);
         *eg = (int)g - (int)(dstc >>  8 & 0xff);
         *eb = (int)b - (int)(dstc       & 0xff);
@@ -406,7 +412,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
     for (y = y_start; y < h; y++) {
         for (x = x_start; x < w; x++) {
-            int er, eg, eb;
+            int ea, er, eg, eb;
 
             if (dither == DITHERING_BAYER) {
                 const int d = s->ordered_dither[(y & 7)<<3 | (x & 7)];
@@ -425,7 +431,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
             } else if (dither == DITHERING_HECKBERT) {
                 const int right = x < w - 1, down = y < h - 1;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -437,7 +443,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
             } else if (dither == DITHERING_FLOYD_STEINBERG) {
                 const int right = x < w - 1, down = y < h - 1, left = x > x_start;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -451,7 +457,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
             } else if (dither == DITHERING_SIERRA2) {
                 const int right  = x < w - 1, down  = y < h - 1, left  = x > x_start;
                 const int right2 = x < w - 2,                    left2 = x > x_start + 1;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -470,7 +476,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
             } else if (dither == DITHERING_SIERRA2_4A) {
                 const int right = x < w - 1, down = y < h - 1, left = x > x_start;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -553,8 +559,7 @@  static int disp_tree(const struct color_node *node, const char *fname)
     return 0;
 }
 
-static int debug_accuracy(const struct color_node *node, const uint32_t *palette, const int trans_thresh,
-                          const enum color_search_method search_method)
+static int debug_accuracy(const PaletteUseContext *s)
 {
     int r, g, b, ret = 0;
 
@@ -562,19 +567,26 @@  static int debug_accuracy(const struct color_node *node, const uint32_t *palette
         for (g = 0; g < 256; g++) {
             for (b = 0; b < 256; b++) {
                 const uint8_t argb[] = {0xff, r, g, b};
-                const int r1 = COLORMAP_NEAREST(search_method, palette, node, argb, trans_thresh);
-                const int r2 = colormap_nearest_bruteforce(palette, argb, trans_thresh);
+                const int r1 = COLORMAP_NEAREST(s, s->color_search_method, s->map, argb);
+                const int r2 = colormap_nearest_bruteforce(s, argb);
                 if (r1 != r2) {
-                    const uint32_t c1 = palette[r1];
-                    const uint32_t c2 = palette[r2];
-                    const uint8_t palargb1[] = { 0xff, c1>>16 & 0xff, c1>> 8 & 0xff, c1 & 0xff };
-                    const uint8_t palargb2[] = { 0xff, c2>>16 & 0xff, c2>> 8 & 0xff, c2 & 0xff };
-                    const int d1 = diff(palargb1, argb, trans_thresh);
-                    const int d2 = diff(palargb2, argb, trans_thresh);
+                    const uint32_t c1 = s->palette[r1];
+                    const uint32_t c2 = s->palette[r2];
+                    const uint8_t a1 = s->use_alpha ? c1>>24 & 0xff : 0xff;
+                    const uint8_t a2 = s->use_alpha ? c2>>24 & 0xff : 0xff;
+                    const uint8_t palargb1[] = { a1, c1>>16 & 0xff, c1>> 8 & 0xff, c1 & 0xff };
+                    const uint8_t palargb2[] = { a2, c2>>16 & 0xff, c2>> 8 & 0xff, c2 & 0xff };
+                    const int d1 = diff(palargb1, argb, s);
+                    const int d2 = diff(palargb2, argb, s);
                     if (d1 != d2) {
-                        av_log(NULL, AV_LOG_ERROR,
-                               "/!\\ %02X%02X%02X: %d ! %d (%06"PRIX32" ! %06"PRIX32") / dist: %d ! %d\n",
-                               r, g, b, r1, r2, c1 & 0xffffff, c2 & 0xffffff, d1, d2);
+                        if (s->use_alpha)
+                            av_log(NULL, AV_LOG_ERROR,
+                                   "/!\\ %02X%02X%02X: %d ! %d (%08"PRIX32" ! %08"PRIX32") / dist: %d ! %d\n",
+                                   r, g, b, r1, r2, c1, c2, d1, d2);
+                        else
+                            av_log(NULL, AV_LOG_ERROR,
+                                   "/!\\ %02X%02X%02X: %d ! %d (%06"PRIX32" ! %06"PRIX32") / dist: %d ! %d\n",
+                                   r, g, b, r1, r2, c1 & 0xffffff, c2 & 0xffffff, d1, d2);
                         ret = 1;
                     }
                 }
@@ -590,8 +602,8 @@  struct color {
 };
 
 struct color_rect {
-    uint8_t min[3];
-    uint8_t max[3];
+    uint8_t min[4];
+    uint8_t max[4];
 };
 
 typedef int (*cmp_func)(const void *, const void *);
@@ -612,43 +624,47 @@  DECLARE_CMP_FUNC(b, 3)
 
 static const cmp_func cmp_funcs[] = {cmp_a, cmp_r, cmp_g, cmp_b};
 
-static int get_next_color(const uint8_t *color_used, const uint32_t *palette,
-                          const int trans_thresh,
+static int get_next_color(const uint8_t *color_used, const PaletteUseContext *s,
                           int *component, const struct color_rect *box)
 {
-    int wr, wg, wb;
+    int wa, wr, wg, wb;
     int i, longest = 0;
     unsigned nb_color = 0;
     struct color_rect ranges;
     struct color tmp_pal[256];
     cmp_func cmpf;
 
-    ranges.min[0] = ranges.min[1] = ranges.min[2] = 0xff;
-    ranges.max[0] = ranges.max[1] = ranges.max[2] = 0x00;
+    ranges.min[0] = ranges.min[1] = ranges.min[2]  = ranges.min[3]= 0xff;
+    ranges.max[0] = ranges.max[1] = ranges.max[2]  = ranges.max[3]= 0x00;
 
     for (i = 0; i < AVPALETTE_COUNT; i++) {
-        const uint32_t c = palette[i];
+        const uint32_t c = s->palette[i];
         const uint8_t a = c >> 24 & 0xff;
         const uint8_t r = c >> 16 & 0xff;
         const uint8_t g = c >>  8 & 0xff;
         const uint8_t b = c       & 0xff;
 
-        if (a < trans_thresh) {
+        if (!s->use_alpha && a < s->trans_thresh) {
             continue;
         }
 
-        if (color_used[i] || (a != 0xff) ||
-            r < box->min[0] || g < box->min[1] || b < box->min[2] ||
-            r > box->max[0] || g > box->max[1] || b > box->max[2])
+        if (color_used[i] || (a != 0xff && !s->use_alpha) ||
+            r < box->min[1] || g < box->min[2] || b < box->min[3] ||
+            r > box->max[1] || g > box->max[2] || b > box->max[3])
             continue;
 
-        if (r < ranges.min[0]) ranges.min[0] = r;
-        if (g < ranges.min[1]) ranges.min[1] = g;
-        if (b < ranges.min[2]) ranges.min[2] = b;
+        if (s->use_alpha && (a < box->min[0] || a > box->max[0]))
+            continue;
+
+        if (a < ranges.min[0]) ranges.min[0] = a;
+        if (r < ranges.min[1]) ranges.min[1] = r;
+        if (g < ranges.min[2]) ranges.min[2] = g;
+        if (b < ranges.min[3]) ranges.min[3] = b;
 
-        if (r > ranges.max[0]) ranges.max[0] = r;
-        if (g > ranges.max[1]) ranges.max[1] = g;
-        if (b > ranges.max[2]) ranges.max[2] = b;
+        if (a > ranges.max[0]) ranges.max[0] = a;
+        if (r > ranges.max[1]) ranges.max[1] = r;
+        if (g > ranges.max[2]) ranges.max[2] = g;
+        if (b > ranges.max[3]) ranges.max[3] = b;
 
         tmp_pal[nb_color].value  = c;
         tmp_pal[nb_color].pal_id = i;
@@ -660,12 +676,22 @@  static int get_next_color(const uint8_t *color_used, const uint32_t *palette,
         return -1;
 
     /* define longest axis that will be the split component */
-    wr = ranges.max[0] - ranges.min[0];
-    wg = ranges.max[1] - ranges.min[1];
-    wb = ranges.max[2] - ranges.min[2];
-    if (wr >= wg && wr >= wb) longest = 1;
-    if (wg >= wr && wg >= wb) longest = 2;
-    if (wb >= wr && wb >= wg) longest = 3;
+    wa = ranges.max[0] - ranges.min[0];
+    wr = ranges.max[1] - ranges.min[1];
+    wg = ranges.max[2] - ranges.min[2];
+    wb = ranges.max[3] - ranges.min[3];
+
+    if (s->use_alpha) {
+        if (wa >= wr && wa >= wb && wa >= wg) longest = 0;
+        if (wr >= wg && wr >= wb && wr >= wa) longest = 1;
+        if (wg >= wr && wg >= wb && wg >= wa) longest = 2;
+        if (wb >= wr && wb >= wg && wb >= wa) longest = 3;
+    } else {
+        if (wr >= wg && wr >= wb) longest = 1;
+        if (wg >= wr && wg >= wb) longest = 2;
+        if (wb >= wr && wb >= wg) longest = 3;
+    }
+
     cmpf = cmp_funcs[longest];
     *component = longest;
 
@@ -678,8 +704,7 @@  static int get_next_color(const uint8_t *color_used, const uint32_t *palette,
 static int colormap_insert(struct color_node *map,
                            uint8_t *color_used,
                            int *nb_used,
-                           const uint32_t *palette,
-                           const int trans_thresh,
+                           const PaletteUseContext *s,
                            const struct color_rect *box)
 {
     uint32_t c;
@@ -687,14 +712,14 @@  static int colormap_insert(struct color_node *map,
     int node_left_id = -1, node_right_id = -1;
     struct color_node *node;
     struct color_rect box1, box2;
-    const int pal_id = get_next_color(color_used, palette, trans_thresh, &component, box);
+    const int pal_id = get_next_color(color_used, s, &component, box);
 
     if (pal_id < 0)
         return -1;
 
     /* create new node with that color */
     cur_id = (*nb_used)++;
-    c = palette[pal_id];
+    c = s->palette[pal_id];
     node = &map[cur_id];
     node->split = component;
     node->palette_id = pal_id;
@@ -707,13 +732,13 @@  static int colormap_insert(struct color_node *map,
 
     /* get the two boxes this node creates */
     box1 = box2 = *box;
-    box1.max[component-1] = node->val[component];
-    box2.min[component-1] = FFMIN(node->val[component] + 1, 255);
+    box1.max[component] = node->val[component];
+    box2.min[component] = FFMIN(node->val[component] + 1, 255);
 
-    node_left_id = colormap_insert(map, color_used, nb_used, palette, trans_thresh, &box1);
+    node_left_id = colormap_insert(map, color_used, nb_used, s, &box1);
 
-    if (box2.min[component-1] <= box2.max[component-1])
-        node_right_id = colormap_insert(map, color_used, nb_used, palette, trans_thresh, &box2);
+    if (box2.min[component] <= box2.max[component])
+        node_right_id = colormap_insert(map, color_used, nb_used, s, &box2);
 
     node->left_id  = node_left_id;
     node->right_id = node_right_id;
@@ -728,6 +753,13 @@  static int cmp_pal_entry(const void *a, const void *b)
     return c1 - c2;
 }
 
+static int cmp_pal_entry_alpha(const void *a, const void *b)
+{
+    const int c1 = *(const uint32_t *)a;
+    const int c2 = *(const uint32_t *)b;
+    return c1 - c2;
+}
+
 static void load_colormap(PaletteUseContext *s)
 {
     int i, nb_used = 0;
@@ -735,12 +767,13 @@  static void load_colormap(PaletteUseContext *s)
     uint32_t last_color = 0;
     struct color_rect box;
 
-    if (s->transparency_index >= 0) {
+    if (!s->use_alpha && s->transparency_index >= 0) {
         FFSWAP(uint32_t, s->palette[s->transparency_index], s->palette[255]);
     }
 
     /* disable transparent colors and dups */
-    qsort(s->palette, AVPALETTE_COUNT-(s->transparency_index >= 0), sizeof(*s->palette), cmp_pal_entry);
+    qsort(s->palette, AVPALETTE_COUNT-(s->transparency_index >= 0), sizeof(*s->palette),
+        s->use_alpha ? cmp_pal_entry_alpha : cmp_pal_entry);
 
     for (i = 0; i < AVPALETTE_COUNT; i++) {
         const uint32_t c = s->palette[i];
@@ -749,22 +782,22 @@  static void load_colormap(PaletteUseContext *s)
             continue;
         }
         last_color = c;
-        if (c >> 24 < s->trans_thresh) {
+        if (!s->use_alpha && c >> 24 < s->trans_thresh) {
             color_used[i] = 1; // ignore transparent color(s)
             continue;
         }
     }
 
-    box.min[0] = box.min[1] = box.min[2] = 0x00;
-    box.max[0] = box.max[1] = box.max[2] = 0xff;
+    box.min[0] = box.min[1] = box.min[2] = box.min[3] = 0x00;
+    box.max[0] = box.max[1] = box.max[2] = box.max[3] = 0xff;
 
-    colormap_insert(s->map, color_used, &nb_used, s->palette, s->trans_thresh, &box);
+    colormap_insert(s->map, color_used, &nb_used, s, &box);
 
     if (s->dot_filename)
         disp_tree(s->map, s->dot_filename);
 
     if (s->debug_accuracy) {
-        if (!debug_accuracy(s->map, s->palette, s->trans_thresh, s->color_search_method))
+        if (!debug_accuracy(s))
             av_log(NULL, AV_LOG_INFO, "Accuracy check passed\n");
     }
 }
@@ -778,16 +811,18 @@  static void debug_mean_error(PaletteUseContext *s, const AVFrame *in1,
     uint8_t  *src2 =             in2->data[0];
     const int src1_linesize = in1->linesize[0] >> 2;
     const int src2_linesize = in2->linesize[0];
-    const float div = in1->width * in1->height * 3;
+    const float div = in1->width * in1->height * s->use_alpha ? 4 : 3;
     unsigned mean_err = 0;
 
     for (y = 0; y < in1->height; y++) {
         for (x = 0; x < in1->width; x++) {
             const uint32_t c1 = src1[x];
             const uint32_t c2 = palette[src2[x]];
-            const uint8_t argb1[] = {0xff, c1 >> 16 & 0xff, c1 >> 8 & 0xff, c1 & 0xff};
-            const uint8_t argb2[] = {0xff, c2 >> 16 & 0xff, c2 >> 8 & 0xff, c2 & 0xff};
-            mean_err += diff(argb1, argb2, s->trans_thresh);
+            const uint8_t a1 = s->use_alpha ? c1>>24 & 0xff : 0xff;
+            const uint8_t a2 = s->use_alpha ? c2>>24 & 0xff : 0xff;
+            const uint8_t argb1[] = {a1, c1 >> 16 & 0xff, c1 >> 8 & 0xff, c1 & 0xff};
+            const uint8_t argb2[] = {a2, c2 >> 16 & 0xff, c2 >> 8 & 0xff, c2 & 0xff};
+            mean_err += diff(argb1, argb2, s);
         }
         src1 += src1_linesize;
         src2 += src2_linesize;
@@ -987,7 +1022,7 @@  static void load_palette(PaletteUseContext *s, const AVFrame *palette_frame)
     for (y = 0; y < palette_frame->height; y++) {
         for (x = 0; x < palette_frame->width; x++) {
             s->palette[i] = p[x];
-            if (p[x]>>24 < s->trans_thresh) {
+            if (!s->use_alpha && p[x]>>24 < s->trans_thresh) {
                 s->transparency_index = i; // we are assuming at most one transparent color in palette
             }
             i++;