diff mbox series

[FFmpeg-devel] avfilter/palettegen, paletteuse: Extend the palette conversion filters to support palettes with alpha

Message ID MN2PR04MB59810446F0C93D0C52F9B3CEBAA69@MN2PR04MB5981.namprd04.prod.outlook.com
State Superseded, archived
Headers show
Series [FFmpeg-devel] avfilter/palettegen, paletteuse: Extend the palette conversion filters to support palettes with alpha | expand

Checks

Context Check Description
andriy/make_x86 success Make finished
andriy/make_fate_x86 success Make fate finished
andriy/makex86 warning New warnings during build
andriy/make_ppc success Make finished
andriy/make_fate_ppc success Make fate finished

Commit Message

Soft Works Sept. 26, 2021, 12:43 a.m. UTC
Usage example:

ffmpeg -y -loglevel verbose -i "..\fate-suite\apng\o_sample.png" -filter_complex "split[split1][split2];[split1]palettegen=max_colors=254:use_alpha=1[pal1];[split2][pal1]paletteuse=use_alpha=1" -frames:v 1 out.png

Signed-off-by: softworkz <softworkz@hotmail.com>
---
 doc/filters.texi            |   8 ++
 libavfilter/vf_palettegen.c | 140 ++++++++++++++--------
 libavfilter/vf_paletteuse.c | 225 +++++++++++++++++++++---------------
 3 files changed, 233 insertions(+), 140 deletions(-)

Comments

Michael Niedermayer Sept. 26, 2021, 5:01 p.m. UTC | #1
On Sun, Sep 26, 2021 at 12:43:37AM +0000, Soft Works wrote:
> Usage example:
> 
> ffmpeg -y -loglevel verbose -i "..\fate-suite\apng\o_sample.png" -filter_complex "split[split1][split2];[split1]palettegen=max_colors=254:use_alpha=1[pal1];[split2][pal1]paletteuse=use_alpha=1" -frames:v 1 out.png
> 
> Signed-off-by: softworkz <softworkz@hotmail.com>
> ---
>  doc/filters.texi            |   8 ++
>  libavfilter/vf_palettegen.c | 140 ++++++++++++++--------
>  libavfilter/vf_paletteuse.c | 225 +++++++++++++++++++++---------------
>  3 files changed, 233 insertions(+), 140 deletions(-)
> 
> diff --git a/doc/filters.texi b/doc/filters.texi
> index 36113e5c4b..7e4806235c 100644
> --- a/doc/filters.texi
> +++ b/doc/filters.texi
> @@ -16454,6 +16454,9 @@ Compute new histogram for each frame.
>  @end table
>  
>  Default value is @var{full}.
> +@item use_alpha
> +Create a palette of colors with alpha components.
> +Setting this, will automatically disable 'reserve_transparent'.
>  @end table
>  
>  The filter also exports the frame metadata @code{lavfi.color_quant_ratio}
> @@ -16532,6 +16535,11 @@ will be treated as completely opaque, and values below this threshold will be
>  treated as completely transparent.
>  
>  The option must be an integer value in the range [0,255]. Default is @var{128}.
> +
> +@item use_alpha
> +Apply the palette by taking alpha values into account. Only useful with 
> +palettes that are containing multiple colors with alpha components.
> +Setting this will automatically disable 'alpha_treshold'.
>  @end table
>  
>  @subsection Examples
> diff --git a/libavfilter/vf_palettegen.c b/libavfilter/vf_palettegen.c
> index 4c2fbd36d7..7a74a3752f 100644
> --- a/libavfilter/vf_palettegen.c
> +++ b/libavfilter/vf_palettegen.c
> @@ -59,7 +59,7 @@ enum {
>  };
>  
>  #define NBITS 5
> -#define HIST_SIZE (1<<(3*NBITS))
> +#define HIST_SIZE (1<<(4*NBITS))
>  
>  typedef struct PaletteGenContext {
>      const AVClass *class;
> @@ -67,6 +67,7 @@ typedef struct PaletteGenContext {
>      int max_colors;
>      int reserve_transparent;
>      int stats_mode;
> +    int use_alpha;
>  
>      AVFrame *prev_frame;                    // previous frame used for the diff stats_mode
>      struct hist_node histogram[HIST_SIZE];  // histogram/hashtable of the colors
> @@ -88,6 +89,7 @@ static const AVOption palettegen_options[] = {
>          { "full", "compute full frame histograms", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_ALL_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
>          { "diff", "compute histograms only for the part that differs from previous frame", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_DIFF_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
>          { "single", "compute new histogram for each frame", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_SINGLE_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
> +    { "use_alpha", "create a palette including alpha values", OFFSET(use_alpha), AV_OPT_TYPE_BOOL, {.i64 = 0}, 0, 1, FLAGS },
>      { NULL }
>  };
>  
> @@ -113,15 +115,16 @@ static int cmp_##name(const void *pa, const void *pb)   \
>  {                                                       \
>      const struct color_ref * const *a = pa;             \
>      const struct color_ref * const *b = pb;             \
> -    return   (int)((*a)->color >> (8 * (2 - (pos))) & 0xff)  \
> -           - (int)((*b)->color >> (8 * (2 - (pos))) & 0xff); \
> +    return   (int)((*a)->color >> (8 * (3 - (pos))) & 0xff)  \
> +           - (int)((*b)->color >> (8 * (3 - (pos))) & 0xff); \
>  }
>  
> -DECLARE_CMP_FUNC(r, 0)
> -DECLARE_CMP_FUNC(g, 1)
> -DECLARE_CMP_FUNC(b, 2)
> +DECLARE_CMP_FUNC(a, 0)
> +DECLARE_CMP_FUNC(r, 1)
> +DECLARE_CMP_FUNC(g, 2)
> +DECLARE_CMP_FUNC(b, 3)
>  
> -static const cmp_func cmp_funcs[] = {cmp_r, cmp_g, cmp_b};
> +static const cmp_func cmp_funcs[] = {cmp_a, cmp_r, cmp_g, cmp_b};
>  
>  /**
>   * Simple color comparison for sorting the final palette
> @@ -143,6 +146,17 @@ static av_always_inline int diff(const uint32_t a, const uint32_t b)
>      return dr*dr + dg*dg + db*db;
>  }
>  
> +static av_always_inline int diff_alpha(const uint32_t a, const uint32_t b)
> +{
> +    const uint8_t c1[] = {a >> 24 & 0xff, a >> 16 & 0xff, a >> 8 & 0xff, a & 0xff};
> +    const uint8_t c2[] = {b >> 24 & 0xff, b >> 16 & 0xff, b >> 8 & 0xff, b & 0xff};
> +    const int da = c1[0] - c2[0];
> +    const int dr = c1[1] - c2[1];
> +    const int dg = c1[2] - c2[2];
> +    const int db = c1[3] - c2[3];
> +    return da*da + dr*dr + dg*dg + db*db;
> +}
> +
>  /**
>   * Find the next box to split: pick the one with the highest variance
>   */
> @@ -164,7 +178,10 @@ static int get_next_box_id_to_split(PaletteGenContext *s)
>  
>                  for (i = 0; i < box->len; i++) {
>                      const struct color_ref *ref = s->refs[box->start + i];
> -                    variance += diff(ref->color, box->color) * ref->count;
> +                    if (s->use_alpha)
> +                        variance += (int64_t)diff_alpha(ref->color, box->color) * ref->count;
> +                    else
> +                        variance += (int64_t)diff(ref->color, box->color) * ref->count;
>                  }
>                  box->variance = variance;
>              }
> @@ -184,24 +201,31 @@ static int get_next_box_id_to_split(PaletteGenContext *s)
>   * specified box. Takes into account the weight of each color.
>   */
>  static uint32_t get_avg_color(struct color_ref * const *refs,
> -                              const struct range_box *box)
> +                              const struct range_box *box, int use_alpha)
>  {
>      int i;
>      const int n = box->len;
> -    uint64_t r = 0, g = 0, b = 0, div = 0;
> +    uint64_t a = 0, r = 0, g = 0, b = 0, div = 0;
>  
>      for (i = 0; i < n; i++) {
>          const struct color_ref *ref = refs[box->start + i];
> -        r += (ref->color >> 16 & 0xff) * ref->count;
> -        g += (ref->color >>  8 & 0xff) * ref->count;
> -        b += (ref->color       & 0xff) * ref->count;
> +        if (use_alpha)
> +            a += (ref->color >> 24 & 0xff) * ref->count;
> +        r += (ref->color     >> 16 & 0xff) * ref->count;
> +        g += (ref->color     >>  8 & 0xff) * ref->count;
> +        b += (ref->color           & 0xff) * ref->count;
>          div += ref->count;
>      }
>  
> +    if (use_alpha)
> +        a = a / div;
>      r = r / div;
>      g = g / div;
>      b = b / div;
>  
> +    if (use_alpha)
> +        return a<<24 | r<<16 | g<<8 | b;
> +
>      return 0xffU<<24 | r<<16 | g<<8 | b;
>  }
>  
> @@ -220,8 +244,8 @@ static void split_box(PaletteGenContext *s, struct range_box *box, int n)
>      av_assert0(box->len     >= 1);
>      av_assert0(new_box->len >= 1);
>  
> -    box->color     = get_avg_color(s->refs, box);
> -    new_box->color = get_avg_color(s->refs, new_box);
> +    box->color     = get_avg_color(s->refs, box, s->use_alpha);
> +    new_box->color = get_avg_color(s->refs, new_box, s->use_alpha);
>      box->variance     = -1;
>      new_box->variance = -1;
>  }

> @@ -242,7 +266,7 @@ static void write_palette(AVFilterContext *ctx, AVFrame *out)
>              if (box_id < s->nb_boxes) {
>                  pal[x] = s->boxes[box_id++].color;
>                  if ((x || y) && pal[x] == last_color)
> -                    av_log(ctx, AV_LOG_WARNING, "Dupped color: %08"PRIX32"\n", pal[x]);
> +                    av_log(ctx, AV_LOG_WARNING, "Duped color: %08"PRIX32"\n", pal[x]);
>                  last_color = pal[x];
>              } else {
>                  pal[x] = last_color; // pad with last color

should be in a seperate patch, this is not related to alpha


[...]

> -    return r<<(NBITS*2) | g<<NBITS | b;
> +    return r << (NBITS * 2) | g << NBITS | b;

[...]
> -    const int s = kd->split;
> +    const int split = kd->split;

also unrelated to the alpha change
all these changes are ok but please do them in a seperate patch, that makes the commits
easier to read in the future

thx

[...]
Soft Works Sept. 26, 2021, 5:22 p.m. UTC | #2
> -----Original Message-----
> From: ffmpeg-devel <ffmpeg-devel-bounces@ffmpeg.org> On Behalf Of
> Michael Niedermayer
> Sent: Sunday, 26 September 2021 19:01
> To: FFmpeg development discussions and patches <ffmpeg-
> devel@ffmpeg.org>
> Subject: Re: [FFmpeg-devel] [PATCH] avfilter/palettegen, paletteuse:
> Extend the palette conversion filters to support palettes with alpha
> 
> 
> [...]
> > -    const int s = kd->split;
> > +    const int split = kd->split;
> 
> also unrelated to the alpha change

This change is required due to the newly added parameter 's'.
The context is named 's' in all other functions, that's why
was required to name the previous local 's' variable.

Thanks,
softworkz
Soft Works Sept. 26, 2021, 5:33 p.m. UTC | #3
> -----Original Message-----
> From: ffmpeg-devel <ffmpeg-devel-bounces@ffmpeg.org> On Behalf Of
> Michael Niedermayer
> Sent: Sunday, 26 September 2021 19:01
> To: FFmpeg development discussions and patches <ffmpeg-
> devel@ffmpeg.org>
> Subject: Re: [FFmpeg-devel] [PATCH] avfilter/palettegen, paletteuse:
> Extend the palette conversion filters to support palettes with alpha
> 

A general coding question on the use of the comma operator:

  if (dx <= 0) nearer_kd_id = kd->left_id,  further_kd_id = kd->right_id;
  else         nearer_kd_id = kd->right_id, further_kd_id = kd->left_id;

I was tempted to change this, but I'm not sure how it is considered 
by the developers here?


softworkz
diff mbox series

Patch

diff --git a/doc/filters.texi b/doc/filters.texi
index 36113e5c4b..7e4806235c 100644
--- a/doc/filters.texi
+++ b/doc/filters.texi
@@ -16454,6 +16454,9 @@  Compute new histogram for each frame.
 @end table
 
 Default value is @var{full}.
+@item use_alpha
+Create a palette of colors with alpha components.
+Setting this, will automatically disable 'reserve_transparent'.
 @end table
 
 The filter also exports the frame metadata @code{lavfi.color_quant_ratio}
@@ -16532,6 +16535,11 @@  will be treated as completely opaque, and values below this threshold will be
 treated as completely transparent.
 
 The option must be an integer value in the range [0,255]. Default is @var{128}.
+
+@item use_alpha
+Apply the palette by taking alpha values into account. Only useful with 
+palettes that are containing multiple colors with alpha components.
+Setting this will automatically disable 'alpha_treshold'.
 @end table
 
 @subsection Examples
diff --git a/libavfilter/vf_palettegen.c b/libavfilter/vf_palettegen.c
index 4c2fbd36d7..7a74a3752f 100644
--- a/libavfilter/vf_palettegen.c
+++ b/libavfilter/vf_palettegen.c
@@ -59,7 +59,7 @@  enum {
 };
 
 #define NBITS 5
-#define HIST_SIZE (1<<(3*NBITS))
+#define HIST_SIZE (1<<(4*NBITS))
 
 typedef struct PaletteGenContext {
     const AVClass *class;
@@ -67,6 +67,7 @@  typedef struct PaletteGenContext {
     int max_colors;
     int reserve_transparent;
     int stats_mode;
+    int use_alpha;
 
     AVFrame *prev_frame;                    // previous frame used for the diff stats_mode
     struct hist_node histogram[HIST_SIZE];  // histogram/hashtable of the colors
@@ -88,6 +89,7 @@  static const AVOption palettegen_options[] = {
         { "full", "compute full frame histograms", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_ALL_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
         { "diff", "compute histograms only for the part that differs from previous frame", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_DIFF_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
         { "single", "compute new histogram for each frame", 0, AV_OPT_TYPE_CONST, {.i64=STATS_MODE_SINGLE_FRAMES}, INT_MIN, INT_MAX, FLAGS, "mode" },
+    { "use_alpha", "create a palette including alpha values", OFFSET(use_alpha), AV_OPT_TYPE_BOOL, {.i64 = 0}, 0, 1, FLAGS },
     { NULL }
 };
 
@@ -113,15 +115,16 @@  static int cmp_##name(const void *pa, const void *pb)   \
 {                                                       \
     const struct color_ref * const *a = pa;             \
     const struct color_ref * const *b = pb;             \
-    return   (int)((*a)->color >> (8 * (2 - (pos))) & 0xff)  \
-           - (int)((*b)->color >> (8 * (2 - (pos))) & 0xff); \
+    return   (int)((*a)->color >> (8 * (3 - (pos))) & 0xff)  \
+           - (int)((*b)->color >> (8 * (3 - (pos))) & 0xff); \
 }
 
-DECLARE_CMP_FUNC(r, 0)
-DECLARE_CMP_FUNC(g, 1)
-DECLARE_CMP_FUNC(b, 2)
+DECLARE_CMP_FUNC(a, 0)
+DECLARE_CMP_FUNC(r, 1)
+DECLARE_CMP_FUNC(g, 2)
+DECLARE_CMP_FUNC(b, 3)
 
-static const cmp_func cmp_funcs[] = {cmp_r, cmp_g, cmp_b};
+static const cmp_func cmp_funcs[] = {cmp_a, cmp_r, cmp_g, cmp_b};
 
 /**
  * Simple color comparison for sorting the final palette
@@ -143,6 +146,17 @@  static av_always_inline int diff(const uint32_t a, const uint32_t b)
     return dr*dr + dg*dg + db*db;
 }
 
+static av_always_inline int diff_alpha(const uint32_t a, const uint32_t b)
+{
+    const uint8_t c1[] = {a >> 24 & 0xff, a >> 16 & 0xff, a >> 8 & 0xff, a & 0xff};
+    const uint8_t c2[] = {b >> 24 & 0xff, b >> 16 & 0xff, b >> 8 & 0xff, b & 0xff};
+    const int da = c1[0] - c2[0];
+    const int dr = c1[1] - c2[1];
+    const int dg = c1[2] - c2[2];
+    const int db = c1[3] - c2[3];
+    return da*da + dr*dr + dg*dg + db*db;
+}
+
 /**
  * Find the next box to split: pick the one with the highest variance
  */
@@ -164,7 +178,10 @@  static int get_next_box_id_to_split(PaletteGenContext *s)
 
                 for (i = 0; i < box->len; i++) {
                     const struct color_ref *ref = s->refs[box->start + i];
-                    variance += diff(ref->color, box->color) * ref->count;
+                    if (s->use_alpha)
+                        variance += (int64_t)diff_alpha(ref->color, box->color) * ref->count;
+                    else
+                        variance += (int64_t)diff(ref->color, box->color) * ref->count;
                 }
                 box->variance = variance;
             }
@@ -184,24 +201,31 @@  static int get_next_box_id_to_split(PaletteGenContext *s)
  * specified box. Takes into account the weight of each color.
  */
 static uint32_t get_avg_color(struct color_ref * const *refs,
-                              const struct range_box *box)
+                              const struct range_box *box, int use_alpha)
 {
     int i;
     const int n = box->len;
-    uint64_t r = 0, g = 0, b = 0, div = 0;
+    uint64_t a = 0, r = 0, g = 0, b = 0, div = 0;
 
     for (i = 0; i < n; i++) {
         const struct color_ref *ref = refs[box->start + i];
-        r += (ref->color >> 16 & 0xff) * ref->count;
-        g += (ref->color >>  8 & 0xff) * ref->count;
-        b += (ref->color       & 0xff) * ref->count;
+        if (use_alpha)
+            a += (ref->color >> 24 & 0xff) * ref->count;
+        r += (ref->color     >> 16 & 0xff) * ref->count;
+        g += (ref->color     >>  8 & 0xff) * ref->count;
+        b += (ref->color           & 0xff) * ref->count;
         div += ref->count;
     }
 
+    if (use_alpha)
+        a = a / div;
     r = r / div;
     g = g / div;
     b = b / div;
 
+    if (use_alpha)
+        return a<<24 | r<<16 | g<<8 | b;
+
     return 0xffU<<24 | r<<16 | g<<8 | b;
 }
 
@@ -220,8 +244,8 @@  static void split_box(PaletteGenContext *s, struct range_box *box, int n)
     av_assert0(box->len     >= 1);
     av_assert0(new_box->len >= 1);
 
-    box->color     = get_avg_color(s->refs, box);
-    new_box->color = get_avg_color(s->refs, new_box);
+    box->color     = get_avg_color(s->refs, box, s->use_alpha);
+    new_box->color = get_avg_color(s->refs, new_box, s->use_alpha);
     box->variance     = -1;
     new_box->variance = -1;
 }
@@ -242,7 +266,7 @@  static void write_palette(AVFilterContext *ctx, AVFrame *out)
             if (box_id < s->nb_boxes) {
                 pal[x] = s->boxes[box_id++].color;
                 if ((x || y) && pal[x] == last_color)
-                    av_log(ctx, AV_LOG_WARNING, "Dupped color: %08"PRIX32"\n", pal[x]);
+                    av_log(ctx, AV_LOG_WARNING, "Duped color: %08"PRIX32"\n", pal[x]);
                 last_color = pal[x];
             } else {
                 pal[x] = last_color; // pad with last color
@@ -251,7 +275,7 @@  static void write_palette(AVFilterContext *ctx, AVFrame *out)
         pal += pal_linesize;
     }
 
-    if (s->reserve_transparent) {
+    if (s->reserve_transparent && !s->use_alpha) {
         av_assert0(s->nb_boxes < 256);
         pal[out->width - pal_linesize - 1] = AV_RB32(&s->transparency_color) >> 8;
     }
@@ -319,40 +343,49 @@  static AVFrame *get_palette_frame(AVFilterContext *ctx)
     box = &s->boxes[box_id];
     box->len = s->nb_refs;
     box->sorted_by = -1;
-    box->color = get_avg_color(s->refs, box);
+    box->color = get_avg_color(s->refs, box, s->use_alpha);
     box->variance = -1;
     s->nb_boxes = 1;
 
     while (box && box->len > 1) {
-        int i, rr, gr, br, longest;
+        int i, ar, rr, gr, br, longest;
         uint64_t median, box_weight = 0;
 
         /* compute the box weight (sum all the weights of the colors in the
          * range) and its boundings */
-        uint8_t min[3] = {0xff, 0xff, 0xff};
-        uint8_t max[3] = {0x00, 0x00, 0x00};
+        uint8_t min[4] = {0xff, 0xff, 0xff, 0xff};
+        uint8_t max[4] = {0x00, 0x00, 0x00, 0x00};
         for (i = box->start; i < box->start + box->len; i++) {
             const struct color_ref *ref = s->refs[i];
             const uint32_t rgb = ref->color;
-            const uint8_t r = rgb >> 16 & 0xff, g = rgb >> 8 & 0xff, b = rgb & 0xff;
-            min[0] = FFMIN(r, min[0]), max[0] = FFMAX(r, max[0]);
-            min[1] = FFMIN(g, min[1]), max[1] = FFMAX(g, max[1]);
-            min[2] = FFMIN(b, min[2]), max[2] = FFMAX(b, max[2]);
+            const uint8_t a = rgb >> 24 & 0xff, r = rgb >> 16 & 0xff, g = rgb >> 8 & 0xff, b = rgb & 0xff;
+            min[0] = FFMIN(a, min[0]); max[0] = FFMAX(a, max[0]);
+            min[1] = FFMIN(r, min[1]); max[1] = FFMAX(r, max[1]);
+            min[2] = FFMIN(g, min[2]); max[2] = FFMAX(g, max[2]);
+            min[3] = FFMIN(b, min[3]); max[3] = FFMAX(b, max[3]);
             box_weight += ref->count;
         }
 
         /* define the axis to sort by according to the widest range of colors */
-        rr = max[0] - min[0];
-        gr = max[1] - min[1];
-        br = max[2] - min[2];
-        longest = 1; // pick green by default (the color the eye is the most sensitive to)
-        if (br >= rr && br >= gr) longest = 2;
-        if (rr >= gr && rr >= br) longest = 0;
-        if (gr >= rr && gr >= br) longest = 1; // prefer green again
-
-        ff_dlog(ctx, "box #%02X [%6d..%-6d] (%6d) w:%-6"PRIu64" ranges:[%2x %2x %2x] sort by %c (already sorted:%c) ",
+        ar = max[0] - min[0];
+        rr = max[1] - min[1];
+        gr = max[2] - min[2];
+        br = max[3] - min[3];
+        longest = 2; // pick green by default (the color the eye is the most sensitive to)
+        if (s->use_alpha) {
+            if (ar >= rr && ar >= br && ar >= gr) longest = 0;
+            if (br >= rr && br >= gr && br >= ar) longest = 3;
+            if (rr >= gr && rr >= br && rr >= ar) longest = 1;
+            if (gr >= rr && gr >= br && gr >= ar) longest = 2; // prefer green again
+        } else {
+            if (br >= rr && br >= gr) longest = 3;
+            if (rr >= gr && rr >= br) longest = 1;
+            if (gr >= rr && gr >= br) longest = 2; // prefer green again
+        }
+
+        ff_dlog(ctx, "box #%02X [%6d..%-6d] (%6d) w:%-6"PRIu64" ranges:[%2x %2x %2x %2x] sort by %c (already sorted:%c) ",
                 box_id, box->start, box->start + box->len - 1, box->len, box_weight,
-                rr, gr, br, "rgb"[longest], box->sorted_by == longest ? 'y':'n');
+                ar, rr, gr, br, "argb"[longest], box->sorted_by == longest ? 'y' : 'n');
 
         /* sort the range by its longest axis if it's not already sorted */
         if (box->sorted_by != longest) {
@@ -394,21 +427,27 @@  static AVFrame *get_palette_frame(AVFilterContext *ctx)
  * It keeps the NBITS least significant bit of each component to make it
  * "random" even if the scene doesn't have much different colors.
  */
-static inline unsigned color_hash(uint32_t color)
+static inline unsigned color_hash(uint32_t color, int use_alpha)
 {
     const uint8_t r = color >> 16 & ((1<<NBITS)-1);
     const uint8_t g = color >>  8 & ((1<<NBITS)-1);
     const uint8_t b = color       & ((1<<NBITS)-1);
-    return r<<(NBITS*2) | g<<NBITS | b;
+
+    if (use_alpha) {
+        const uint8_t a = color >> 24 & ((1 << NBITS) - 1);
+        return a << (NBITS * 3) | r << (NBITS * 2) | g << NBITS | b;
+    }
+
+    return r << (NBITS * 2) | g << NBITS | b;
 }
 
 /**
  * Locate the color in the hash table and increment its counter.
  */
-static int color_inc(struct hist_node *hist, uint32_t color)
+static int color_inc(struct hist_node *hist, uint32_t color, int use_alpha)
 {
     int i;
-    const unsigned hash = color_hash(color);
+    const unsigned hash = color_hash(color, use_alpha);
     struct hist_node *node = &hist[hash];
     struct color_ref *e;
 
@@ -433,7 +472,7 @@  static int color_inc(struct hist_node *hist, uint32_t color)
  * Update histogram when pixels differ from previous frame.
  */
 static int update_histogram_diff(struct hist_node *hist,
-                                 const AVFrame *f1, const AVFrame *f2)
+                                 const AVFrame *f1, const AVFrame *f2, int use_alpha)
 {
     int x, y, ret, nb_diff_colors = 0;
 
@@ -444,7 +483,7 @@  static int update_histogram_diff(struct hist_node *hist,
         for (x = 0; x < f1->width; x++) {
             if (p[x] == q[x])
                 continue;
-            ret = color_inc(hist, p[x]);
+            ret = color_inc(hist, p[x], use_alpha);
             if (ret < 0)
                 return ret;
             nb_diff_colors += ret;
@@ -456,7 +495,7 @@  static int update_histogram_diff(struct hist_node *hist,
 /**
  * Simple histogram of the frame.
  */
-static int update_histogram_frame(struct hist_node *hist, const AVFrame *f)
+static int update_histogram_frame(struct hist_node *hist, const AVFrame *f, int use_alpha)
 {
     int x, y, ret, nb_diff_colors = 0;
 
@@ -464,7 +503,7 @@  static int update_histogram_frame(struct hist_node *hist, const AVFrame *f)
         const uint32_t *p = (const uint32_t *)(f->data[0] + y*f->linesize[0]);
 
         for (x = 0; x < f->width; x++) {
-            ret = color_inc(hist, p[x]);
+            ret = color_inc(hist, p[x], use_alpha);
             if (ret < 0)
                 return ret;
             nb_diff_colors += ret;
@@ -480,8 +519,8 @@  static int filter_frame(AVFilterLink *inlink, AVFrame *in)
 {
     AVFilterContext *ctx = inlink->dst;
     PaletteGenContext *s = ctx->priv;
-    int ret = s->prev_frame ? update_histogram_diff(s->histogram, s->prev_frame, in)
-                            : update_histogram_frame(s->histogram, in);
+    int ret = s->prev_frame ? update_histogram_diff(s->histogram, s->prev_frame, in, s->use_alpha)
+                            : update_histogram_frame(s->histogram, in, s->use_alpha);
 
     if (ret > 0)
         s->nb_refs += ret;
@@ -540,6 +579,16 @@  static int config_output(AVFilterLink *outlink)
     return 0;
 }
 
+static int init(AVFilterContext *ctx)
+{
+    PaletteGenContext* s = ctx->priv;
+
+    if (s->use_alpha && s->reserve_transparent)
+        s->reserve_transparent = 0;
+
+    return 0;
+}
+
 static av_cold void uninit(AVFilterContext *ctx)
 {
     int i;
@@ -572,6 +621,7 @@  const AVFilter ff_vf_palettegen = {
     .name          = "palettegen",
     .description   = NULL_IF_CONFIG_SMALL("Find the optimal palette for a given stream."),
     .priv_size     = sizeof(PaletteGenContext),
+    .init        = init,
     .uninit        = uninit,
     .query_formats = query_formats,
     FILTER_INPUTS(palettegen_inputs),
diff --git a/libavfilter/vf_paletteuse.c b/libavfilter/vf_paletteuse.c
index f9bc28f7d0..2ac30b4e5d 100644
--- a/libavfilter/vf_paletteuse.c
+++ b/libavfilter/vf_paletteuse.c
@@ -28,7 +28,6 @@ 
 #include "libavutil/opt.h"
 #include "libavutil/qsort.h"
 #include "avfilter.h"
-#include "filters.h"
 #include "framesync.h"
 #include "internal.h"
 
@@ -63,7 +62,7 @@  struct color_node {
 };
 
 #define NBITS 5
-#define CACHE_SIZE (1<<(3*NBITS))
+#define CACHE_SIZE (1<<(4*NBITS))
 
 struct cached_color {
     uint32_t color;
@@ -88,6 +87,7 @@  typedef struct PaletteUseContext {
     uint32_t palette[AVPALETTE_COUNT];
     int transparency_index; /* index in the palette of transparency. -1 if there is no transparency in the palette. */
     int trans_thresh;
+    int use_alpha;
     int palette_loaded;
     int dither;
     int new;
@@ -107,7 +107,7 @@  typedef struct PaletteUseContext {
 } PaletteUseContext;
 
 #define OFFSET(x) offsetof(PaletteUseContext, x)
-#define FLAGS AV_OPT_FLAG_FILTERING_PARAM|AV_OPT_FLAG_VIDEO_PARAM
+#define FLAGS (AV_OPT_FLAG_FILTERING_PARAM | AV_OPT_FLAG_VIDEO_PARAM)
 static const AVOption paletteuse_options[] = {
     { "dither", "select dithering mode", OFFSET(dither), AV_OPT_TYPE_INT, {.i64=DITHERING_SIERRA2_4A}, 0, NB_DITHERING-1, FLAGS, "dithering_mode" },
         { "bayer",           "ordered 8x8 bayer dithering (deterministic)",                            0, AV_OPT_TYPE_CONST, {.i64=DITHERING_BAYER},           INT_MIN, INT_MAX, FLAGS, "dithering_mode" },
@@ -120,6 +120,7 @@  static const AVOption paletteuse_options[] = {
         { "rectangle", "process smallest different rectangle", 0, AV_OPT_TYPE_CONST, {.i64=DIFF_MODE_RECTANGLE}, INT_MIN, INT_MAX, FLAGS, "diff_mode" },
     { "new", "take new palette for each output frame", OFFSET(new), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS },
     { "alpha_threshold", "set the alpha threshold for transparency", OFFSET(trans_thresh), AV_OPT_TYPE_INT, {.i64=128}, 0, 255, FLAGS },
+    { "use_alpha", "use alpha channel for mapping", OFFSET(use_alpha), AV_OPT_TYPE_BOOL, {.i64=0}, 0, 1, FLAGS },
 
     /* following are the debug options, not part of the official API */
     { "debug_kdtree", "save Graphviz graph of the kdtree in specified file", OFFSET(dot_filename), AV_OPT_TYPE_STRING, {.str=NULL}, 0, 0, FLAGS },
@@ -161,37 +162,41 @@  static av_always_inline uint32_t dither_color(uint32_t px, int er, int eg,
          | av_clip_uint8((px       & 0xff) + ((eb * scale) / (1<<shift)));
 }
 
-static av_always_inline int diff(const uint8_t *c1, const uint8_t *c2, const int trans_thresh)
+static av_always_inline int diff(const uint8_t *c1, const uint8_t *c2, const PaletteUseContext *s)
 {
     // XXX: try L*a*b with CIE76 (dL*dL + da*da + db*db)
+    const int da = c1[0] - c2[0];
     const int dr = c1[1] - c2[1];
     const int dg = c1[2] - c2[2];
     const int db = c1[3] - c2[3];
 
-    if (c1[0] < trans_thresh && c2[0] < trans_thresh) {
+    if (s->use_alpha)
+        return da*da + dr*dr + dg*dg + db*db;
+
+    if (c1[0] < s->trans_thresh && c2[0] < s->trans_thresh) {
         return 0;
-    } else if (c1[0] >= trans_thresh && c2[0] >= trans_thresh) {
+    } else if (c1[0] >= s->trans_thresh && c2[0] >= s->trans_thresh) {
         return dr*dr + dg*dg + db*db;
     } else {
         return 255*255 + 255*255 + 255*255;
     }
 }
 
-static av_always_inline uint8_t colormap_nearest_bruteforce(const uint32_t *palette, const uint8_t *argb, const int trans_thresh)
+static av_always_inline uint8_t colormap_nearest_bruteforce(const PaletteUseContext *s, const uint8_t *argb)
 {
     int i, pal_id = -1, min_dist = INT_MAX;
 
     for (i = 0; i < AVPALETTE_COUNT; i++) {
-        const uint32_t c = palette[i];
+        const uint32_t c = s->palette[i];
 
-        if (c >> 24 >= trans_thresh) { // ignore transparent entry
+        if (s->use_alpha || c >> 24 >= s->trans_thresh) { // ignore transparent entry
             const uint8_t palargb[] = {
-                palette[i]>>24 & 0xff,
-                palette[i]>>16 & 0xff,
-                palette[i]>> 8 & 0xff,
-                palette[i]     & 0xff,
+                s->palette[i]>>24 & 0xff,
+                s->palette[i]>>16 & 0xff,
+                s->palette[i]>> 8 & 0xff,
+                s->palette[i]     & 0xff,
             };
-            const int d = diff(palargb, argb, trans_thresh);
+            const int d = diff(palargb, argb, s);
             if (d < min_dist) {
                 pal_id = i;
                 min_dist = d;
@@ -207,17 +212,17 @@  struct nearest_color {
     int dist_sqd;
 };
 
-static void colormap_nearest_node(const struct color_node *map,
+static void colormap_nearest_node(const PaletteUseContext *s,
+                                  const struct color_node *map,
                                   const int node_pos,
                                   const uint8_t *target,
-                                  const int trans_thresh,
                                   struct nearest_color *nearest)
 {
     const struct color_node *kd = map + node_pos;
-    const int s = kd->split;
+    const int split = kd->split;
     int dx, nearer_kd_id, further_kd_id;
     const uint8_t *current = kd->val;
-    const int current_to_target = diff(target, current, trans_thresh);
+    const int current_to_target = diff(target, current, s);
 
     if (current_to_target < nearest->dist_sqd) {
         nearest->node_pos = node_pos;
@@ -225,23 +230,23 @@  static void colormap_nearest_node(const struct color_node *map,
     }
 
     if (kd->left_id != -1 || kd->right_id != -1) {
-        dx = target[s] - current[s];
+        dx = target[split] - current[split];
 
         if (dx <= 0) nearer_kd_id = kd->left_id,  further_kd_id = kd->right_id;
         else         nearer_kd_id = kd->right_id, further_kd_id = kd->left_id;
 
         if (nearer_kd_id != -1)
-            colormap_nearest_node(map, nearer_kd_id, target, trans_thresh, nearest);
+            colormap_nearest_node(s, map, nearer_kd_id, target, nearest);
 
         if (further_kd_id != -1 && dx*dx < nearest->dist_sqd)
-            colormap_nearest_node(map, further_kd_id, target, trans_thresh, nearest);
+            colormap_nearest_node(s, map, further_kd_id, target, nearest);
     }
 }
 
-static av_always_inline uint8_t colormap_nearest_recursive(const struct color_node *node, const uint8_t *rgb, const int trans_thresh)
+static av_always_inline uint8_t colormap_nearest_recursive(const PaletteUseContext *s, const struct color_node *node, const uint8_t *rgb)
 {
     struct nearest_color res = {.dist_sqd = INT_MAX, .node_pos = -1};
-    colormap_nearest_node(node, 0, rgb, trans_thresh, &res);
+    colormap_nearest_node(s, node, 0, rgb, &res);
     return node[res.node_pos].palette_id;
 }
 
@@ -250,7 +255,7 @@  struct stack_node {
     int dx2;
 };
 
-static av_always_inline uint8_t colormap_nearest_iterative(const struct color_node *root, const uint8_t *target, const int trans_thresh)
+static av_always_inline uint8_t colormap_nearest_iterative(const PaletteUseContext *s, const struct color_node *root, const uint8_t *target)
 {
     int pos = 0, best_node_id = -1, best_dist = INT_MAX, cur_color_id = 0;
     struct stack_node nodes[16];
@@ -260,7 +265,7 @@  static av_always_inline uint8_t colormap_nearest_iterative(const struct color_no
 
         const struct color_node *kd = &root[cur_color_id];
         const uint8_t *current = kd->val;
-        const int current_to_target = diff(target, current, trans_thresh);
+        const int current_to_target = diff(target, current, s);
 
         /* Compare current color node to the target and update our best node if
          * it's actually better. */
@@ -322,10 +327,10 @@  end:
     return root[best_node_id].palette_id;
 }
 
-#define COLORMAP_NEAREST(search, palette, root, target, trans_thresh)                                    \
-    search == COLOR_SEARCH_NNS_ITERATIVE ? colormap_nearest_iterative(root, target, trans_thresh) :      \
-    search == COLOR_SEARCH_NNS_RECURSIVE ? colormap_nearest_recursive(root, target, trans_thresh) :      \
-                                           colormap_nearest_bruteforce(palette, target, trans_thresh)
+#define COLORMAP_NEAREST(s, search, root, target)                                    \
+    search == COLOR_SEARCH_NNS_ITERATIVE ? colormap_nearest_iterative(s, root, target) :      \
+    search == COLOR_SEARCH_NNS_RECURSIVE ? colormap_nearest_recursive(s, root, target) :      \
+                                           colormap_nearest_bruteforce(s, target)
 
 /**
  * Check if the requested color is in the cache already. If not, find it in the
@@ -362,13 +367,13 @@  static av_always_inline int color_get(PaletteUseContext *s, uint32_t color,
     if (!e)
         return AVERROR(ENOMEM);
     e->color = color;
-    e->pal_entry = COLORMAP_NEAREST(search_method, s->palette, s->map, argb_elts, s->trans_thresh);
+    e->pal_entry = COLORMAP_NEAREST(s, search_method, s->map, argb_elts);
 
     return e->pal_entry;
 }
 
 static av_always_inline int get_dst_color_err(PaletteUseContext *s,
-                                              uint32_t c, int *er, int *eg, int *eb,
+                                              uint32_t c, int *ea, int *er, int *eg, int *eb,
                                               const enum color_search_method search_method)
 {
     const uint8_t a = c >> 24 & 0xff;
@@ -381,8 +386,9 @@  static av_always_inline int get_dst_color_err(PaletteUseContext *s,
         return dstx;
     dstc = s->palette[dstx];
     if (dstx == s->transparency_index) {
-        *er = *eg = *eb = 0;
+        *ea =*er = *eg = *eb = 0;
     } else {
+        *ea = (int)a - (int)(dstc >> 24 & 0xff);
         *er = (int)r - (int)(dstc >> 16 & 0xff);
         *eg = (int)g - (int)(dstc >>  8 & 0xff);
         *eb = (int)b - (int)(dstc       & 0xff);
@@ -406,7 +412,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
     for (y = y_start; y < h; y++) {
         for (x = x_start; x < w; x++) {
-            int er, eg, eb;
+            int ea, er, eg, eb;
 
             if (dither == DITHERING_BAYER) {
                 const int d = s->ordered_dither[(y & 7)<<3 | (x & 7)];
@@ -425,7 +431,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
             } else if (dither == DITHERING_HECKBERT) {
                 const int right = x < w - 1, down = y < h - 1;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -437,7 +443,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
             } else if (dither == DITHERING_FLOYD_STEINBERG) {
                 const int right = x < w - 1, down = y < h - 1, left = x > x_start;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -451,7 +457,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
             } else if (dither == DITHERING_SIERRA2) {
                 const int right  = x < w - 1, down  = y < h - 1, left  = x > x_start;
                 const int right2 = x < w - 2,                    left2 = x > x_start + 1;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -470,7 +476,7 @@  static av_always_inline int set_frame(PaletteUseContext *s, AVFrame *out, AVFram
 
             } else if (dither == DITHERING_SIERRA2_4A) {
                 const int right = x < w - 1, down = y < h - 1, left = x > x_start;
-                const int color = get_dst_color_err(s, src[x], &er, &eg, &eb, search_method);
+                const int color = get_dst_color_err(s, src[x], &ea, &er, &eg, &eb, search_method);
 
                 if (color < 0)
                     return color;
@@ -553,8 +559,7 @@  static int disp_tree(const struct color_node *node, const char *fname)
     return 0;
 }
 
-static int debug_accuracy(const struct color_node *node, const uint32_t *palette, const int trans_thresh,
-                          const enum color_search_method search_method)
+static int debug_accuracy(const PaletteUseContext *s)
 {
     int r, g, b, ret = 0;
 
@@ -562,19 +567,26 @@  static int debug_accuracy(const struct color_node *node, const uint32_t *palette
         for (g = 0; g < 256; g++) {
             for (b = 0; b < 256; b++) {
                 const uint8_t argb[] = {0xff, r, g, b};
-                const int r1 = COLORMAP_NEAREST(search_method, palette, node, argb, trans_thresh);
-                const int r2 = colormap_nearest_bruteforce(palette, argb, trans_thresh);
+                const int r1 = COLORMAP_NEAREST(s, s->color_search_method, s->map, argb);
+                const int r2 = colormap_nearest_bruteforce(s, argb);
                 if (r1 != r2) {
-                    const uint32_t c1 = palette[r1];
-                    const uint32_t c2 = palette[r2];
-                    const uint8_t palargb1[] = { 0xff, c1>>16 & 0xff, c1>> 8 & 0xff, c1 & 0xff };
-                    const uint8_t palargb2[] = { 0xff, c2>>16 & 0xff, c2>> 8 & 0xff, c2 & 0xff };
-                    const int d1 = diff(palargb1, argb, trans_thresh);
-                    const int d2 = diff(palargb2, argb, trans_thresh);
+                    const uint32_t c1 = s->palette[r1];
+                    const uint32_t c2 = s->palette[r2];
+                    const uint8_t a1 = s->use_alpha ? c1>>24 & 0xff : 0xff;
+                    const uint8_t a2 = s->use_alpha ? c2>>24 & 0xff : 0xff;
+                    const uint8_t palargb1[] = { a1, c1>>16 & 0xff, c1>> 8 & 0xff, c1 & 0xff };
+                    const uint8_t palargb2[] = { a2, c2>>16 & 0xff, c2>> 8 & 0xff, c2 & 0xff };
+                    const int d1 = diff(palargb1, argb, s);
+                    const int d2 = diff(palargb2, argb, s);
                     if (d1 != d2) {
-                        av_log(NULL, AV_LOG_ERROR,
-                               "/!\\ %02X%02X%02X: %d ! %d (%06"PRIX32" ! %06"PRIX32") / dist: %d ! %d\n",
-                               r, g, b, r1, r2, c1 & 0xffffff, c2 & 0xffffff, d1, d2);
+                        if (s->use_alpha)
+                            av_log(NULL, AV_LOG_ERROR,
+                                   "/!\\ %02X%02X%02X: %d ! %d (%08"PRIX32" ! %08"PRIX32") / dist: %d ! %d\n",
+                                   r, g, b, r1, r2, c1, c2, d1, d2);
+                        else
+                            av_log(NULL, AV_LOG_ERROR,
+                                   "/!\\ %02X%02X%02X: %d ! %d (%06"PRIX32" ! %06"PRIX32") / dist: %d ! %d\n",
+                                   r, g, b, r1, r2, c1 & 0xffffff, c2 & 0xffffff, d1, d2);
                         ret = 1;
                     }
                 }
@@ -590,8 +602,8 @@  struct color {
 };
 
 struct color_rect {
-    uint8_t min[3];
-    uint8_t max[3];
+    uint8_t min[4];
+    uint8_t max[4];
 };
 
 typedef int (*cmp_func)(const void *, const void *);
@@ -612,43 +624,47 @@  DECLARE_CMP_FUNC(b, 3)
 
 static const cmp_func cmp_funcs[] = {cmp_a, cmp_r, cmp_g, cmp_b};
 
-static int get_next_color(const uint8_t *color_used, const uint32_t *palette,
-                          const int trans_thresh,
+static int get_next_color(const uint8_t *color_used, const PaletteUseContext *s,
                           int *component, const struct color_rect *box)
 {
-    int wr, wg, wb;
+    int wa, wr, wg, wb;
     int i, longest = 0;
     unsigned nb_color = 0;
     struct color_rect ranges;
     struct color tmp_pal[256];
     cmp_func cmpf;
 
-    ranges.min[0] = ranges.min[1] = ranges.min[2] = 0xff;
-    ranges.max[0] = ranges.max[1] = ranges.max[2] = 0x00;
+    ranges.min[0] = ranges.min[1] = ranges.min[2]  = ranges.min[3]= 0xff;
+    ranges.max[0] = ranges.max[1] = ranges.max[2]  = ranges.max[3]= 0x00;
 
     for (i = 0; i < AVPALETTE_COUNT; i++) {
-        const uint32_t c = palette[i];
+        const uint32_t c = s->palette[i];
         const uint8_t a = c >> 24 & 0xff;
         const uint8_t r = c >> 16 & 0xff;
         const uint8_t g = c >>  8 & 0xff;
         const uint8_t b = c       & 0xff;
 
-        if (a < trans_thresh) {
+        if (!s->use_alpha && a < s->trans_thresh) {
             continue;
         }
 
-        if (color_used[i] || (a != 0xff) ||
-            r < box->min[0] || g < box->min[1] || b < box->min[2] ||
-            r > box->max[0] || g > box->max[1] || b > box->max[2])
+        if (color_used[i] || (a != 0xff && !s->use_alpha) ||
+            r < box->min[1] || g < box->min[2] || b < box->min[3] ||
+            r > box->max[1] || g > box->max[2] || b > box->max[3])
             continue;
 
-        if (r < ranges.min[0]) ranges.min[0] = r;
-        if (g < ranges.min[1]) ranges.min[1] = g;
-        if (b < ranges.min[2]) ranges.min[2] = b;
+        if (s->use_alpha && (a < box->min[0] || a > box->max[0]))
+            continue;
+
+        if (a < ranges.min[0]) ranges.min[0] = a;
+        if (r < ranges.min[1]) ranges.min[1] = r;
+        if (g < ranges.min[2]) ranges.min[2] = g;
+        if (b < ranges.min[3]) ranges.min[3] = b;
 
-        if (r > ranges.max[0]) ranges.max[0] = r;
-        if (g > ranges.max[1]) ranges.max[1] = g;
-        if (b > ranges.max[2]) ranges.max[2] = b;
+        if (a > ranges.max[0]) ranges.max[0] = a;
+        if (r > ranges.max[1]) ranges.max[1] = r;
+        if (g > ranges.max[2]) ranges.max[2] = g;
+        if (b > ranges.max[3]) ranges.max[3] = b;
 
         tmp_pal[nb_color].value  = c;
         tmp_pal[nb_color].pal_id = i;
@@ -660,12 +676,22 @@  static int get_next_color(const uint8_t *color_used, const uint32_t *palette,
         return -1;
 
     /* define longest axis that will be the split component */
-    wr = ranges.max[0] - ranges.min[0];
-    wg = ranges.max[1] - ranges.min[1];
-    wb = ranges.max[2] - ranges.min[2];
-    if (wr >= wg && wr >= wb) longest = 1;
-    if (wg >= wr && wg >= wb) longest = 2;
-    if (wb >= wr && wb >= wg) longest = 3;
+    wa = ranges.max[0] - ranges.min[0];
+    wr = ranges.max[1] - ranges.min[1];
+    wg = ranges.max[2] - ranges.min[2];
+    wb = ranges.max[3] - ranges.min[3];
+
+    if (s->use_alpha) {
+        if (wa >= wr && wa >= wb && wa >= wg) longest = 0;
+        if (wr >= wg && wr >= wb && wr >= wa) longest = 1;
+        if (wg >= wr && wg >= wb && wg >= wa) longest = 2;
+        if (wb >= wr && wb >= wg && wb >= wa) longest = 3;
+    } else {
+        if (wr >= wg && wr >= wb) longest = 1;
+        if (wg >= wr && wg >= wb) longest = 2;
+        if (wb >= wr && wb >= wg) longest = 3;
+    }
+
     cmpf = cmp_funcs[longest];
     *component = longest;
 
@@ -678,8 +704,7 @@  static int get_next_color(const uint8_t *color_used, const uint32_t *palette,
 static int colormap_insert(struct color_node *map,
                            uint8_t *color_used,
                            int *nb_used,
-                           const uint32_t *palette,
-                           const int trans_thresh,
+                           const PaletteUseContext *s,
                            const struct color_rect *box)
 {
     uint32_t c;
@@ -687,14 +712,14 @@  static int colormap_insert(struct color_node *map,
     int node_left_id = -1, node_right_id = -1;
     struct color_node *node;
     struct color_rect box1, box2;
-    const int pal_id = get_next_color(color_used, palette, trans_thresh, &component, box);
+    const int pal_id = get_next_color(color_used, s, &component, box);
 
     if (pal_id < 0)
         return -1;
 
     /* create new node with that color */
     cur_id = (*nb_used)++;
-    c = palette[pal_id];
+    c = s->palette[pal_id];
     node = &map[cur_id];
     node->split = component;
     node->palette_id = pal_id;
@@ -707,13 +732,13 @@  static int colormap_insert(struct color_node *map,
 
     /* get the two boxes this node creates */
     box1 = box2 = *box;
-    box1.max[component-1] = node->val[component];
-    box2.min[component-1] = FFMIN(node->val[component] + 1, 255);
+    box1.max[component] = node->val[component];
+    box2.min[component] = FFMIN(node->val[component] + 1, 255);
 
-    node_left_id = colormap_insert(map, color_used, nb_used, palette, trans_thresh, &box1);
+    node_left_id = colormap_insert(map, color_used, nb_used, s, &box1);
 
-    if (box2.min[component-1] <= box2.max[component-1])
-        node_right_id = colormap_insert(map, color_used, nb_used, palette, trans_thresh, &box2);
+    if (box2.min[component] <= box2.max[component])
+        node_right_id = colormap_insert(map, color_used, nb_used, s, &box2);
 
     node->left_id  = node_left_id;
     node->right_id = node_right_id;
@@ -728,6 +753,13 @@  static int cmp_pal_entry(const void *a, const void *b)
     return c1 - c2;
 }
 
+static int cmp_pal_entry_alpha(const void *a, const void *b)
+{
+    const int c1 = *(const uint32_t *)a;
+    const int c2 = *(const uint32_t *)b;
+    return c1 - c2;
+}
+
 static void load_colormap(PaletteUseContext *s)
 {
     int i, nb_used = 0;
@@ -735,12 +767,13 @@  static void load_colormap(PaletteUseContext *s)
     uint32_t last_color = 0;
     struct color_rect box;
 
-    if (s->transparency_index >= 0) {
+    if (!s->use_alpha && s->transparency_index >= 0) {
         FFSWAP(uint32_t, s->palette[s->transparency_index], s->palette[255]);
     }
 
     /* disable transparent colors and dups */
-    qsort(s->palette, AVPALETTE_COUNT-(s->transparency_index >= 0), sizeof(*s->palette), cmp_pal_entry);
+    qsort(s->palette, AVPALETTE_COUNT-(s->transparency_index >= 0), sizeof(*s->palette),
+        s->use_alpha ? cmp_pal_entry_alpha : cmp_pal_entry);
 
     for (i = 0; i < AVPALETTE_COUNT; i++) {
         const uint32_t c = s->palette[i];
@@ -749,22 +782,22 @@  static void load_colormap(PaletteUseContext *s)
             continue;
         }
         last_color = c;
-        if (c >> 24 < s->trans_thresh) {
+        if (!s->use_alpha && c >> 24 < s->trans_thresh) {
             color_used[i] = 1; // ignore transparent color(s)
             continue;
         }
     }
 
-    box.min[0] = box.min[1] = box.min[2] = 0x00;
-    box.max[0] = box.max[1] = box.max[2] = 0xff;
+    box.min[0] = box.min[1] = box.min[2] = box.min[3] = 0x00;
+    box.max[0] = box.max[1] = box.max[2] = box.max[3] = 0xff;
 
-    colormap_insert(s->map, color_used, &nb_used, s->palette, s->trans_thresh, &box);
+    colormap_insert(s->map, color_used, &nb_used, s, &box);
 
     if (s->dot_filename)
         disp_tree(s->map, s->dot_filename);
 
     if (s->debug_accuracy) {
-        if (!debug_accuracy(s->map, s->palette, s->trans_thresh, s->color_search_method))
+        if (!debug_accuracy(s))
             av_log(NULL, AV_LOG_INFO, "Accuracy check passed\n");
     }
 }
@@ -778,16 +811,18 @@  static void debug_mean_error(PaletteUseContext *s, const AVFrame *in1,
     uint8_t  *src2 =             in2->data[0];
     const int src1_linesize = in1->linesize[0] >> 2;
     const int src2_linesize = in2->linesize[0];
-    const float div = in1->width * in1->height * 3;
+    const float div = in1->width * in1->height * s->use_alpha ? 4 : 3;
     unsigned mean_err = 0;
 
     for (y = 0; y < in1->height; y++) {
         for (x = 0; x < in1->width; x++) {
             const uint32_t c1 = src1[x];
             const uint32_t c2 = palette[src2[x]];
-            const uint8_t argb1[] = {0xff, c1 >> 16 & 0xff, c1 >> 8 & 0xff, c1 & 0xff};
-            const uint8_t argb2[] = {0xff, c2 >> 16 & 0xff, c2 >> 8 & 0xff, c2 & 0xff};
-            mean_err += diff(argb1, argb2, s->trans_thresh);
+            const uint8_t a1 = s->use_alpha ? c1>>24 & 0xff : 0xff;
+            const uint8_t a2 = s->use_alpha ? c2>>24 & 0xff : 0xff;
+            const uint8_t argb1[] = {a1, c1 >> 16 & 0xff, c1 >> 8 & 0xff, c1 & 0xff};
+            const uint8_t argb2[] = {a2, c2 >> 16 & 0xff, c2 >> 8 & 0xff, c2 & 0xff};
+            mean_err += diff(argb1, argb2, s);
         }
         src1 += src1_linesize;
         src2 += src2_linesize;
@@ -987,7 +1022,7 @@  static void load_palette(PaletteUseContext *s, const AVFrame *palette_frame)
     for (y = 0; y < palette_frame->height; y++) {
         for (x = 0; x < palette_frame->width; x++) {
             s->palette[i] = p[x];
-            if (p[x]>>24 < s->trans_thresh) {
+            if (!s->use_alpha && p[x]>>24 < s->trans_thresh) {
                 s->transparency_index = i; // we are assuming at most one transparent color in palette
             }
             i++;