diff mbox series

[FFmpeg-devel,3/4] wavpack: fully support stream parameter changes

Message ID 20200405203241.13033-3-anton@khirnov.net
State Accepted
Commit f1e3e9e2042a77891d8a75ef501f95b9b820da11
Headers show
Series [FFmpeg-devel,1/4] pthread_frame: make sure ff_thread_release_buffer always cleans the frame | expand

Checks

Context: andriy/ffmpeg-patchwork | Check: success | Description: Make fate finished

Commit Message

Anton Khirnov April 5, 2020, 8:32 p.m. UTC
Fix invalid memory access on DSD streams with changing channel count.
---
 libavcodec/wavpack.c | 122 +++++++++++++++++++++++++++++++------------
 1 file changed, 90 insertions(+), 32 deletions(-)

Comments

David Bryant April 6, 2020, 3:21 a.m. UTC | #1
On 4/5/20 1:32 PM, Anton Khirnov wrote:
> Fix invalid memory access on DSD streams with changing channel count.
> ---
>  libavcodec/wavpack.c | 122 +++++++++++++++++++++++++++++++------------
>  1 file changed, 90 insertions(+), 32 deletions(-)
>
> diff --git a/libavcodec/wavpack.c b/libavcodec/wavpack.c
> index b27262b94e..9cc4104dd0 100644
> --- a/libavcodec/wavpack.c
> +++ b/libavcodec/wavpack.c
> @@ -20,6 +20,7 @@
>   * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
>   */
>  
> +#include "libavutil/buffer.h"
>  #include "libavutil/channel_layout.h"
>  
>  #define BITSTREAM_READER_LE
> @@ -109,7 +110,10 @@ typedef struct WavpackContext {
>      AVFrame *frame;
>      ThreadFrame curr_frame, prev_frame;
>      Modulation modulation;
> +
> +    AVBufferRef *dsd_ref;
>      DSDContext *dsdctx;
> +    int dsd_channels;
>  } WavpackContext;
>  
>  #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
> @@ -978,6 +982,32 @@ static av_cold int wv_alloc_frame_context(WavpackContext *c)
>      return 0;
>  }
>  
> +static int wv_dsd_reset(WavpackContext *s, int channels)
> +{
> +    int i;
> +
> +    s->dsdctx = NULL;
> +    s->dsd_channels = 0;
> +    av_buffer_unref(&s->dsd_ref);
> +
> +    if (!channels)
> +        return 0;
> +
> +    if (channels > INT_MAX / sizeof(*s->dsdctx))
> +        return AVERROR(EINVAL);
> +
> +    s->dsd_ref = av_buffer_allocz(channels * sizeof(*s->dsdctx));
> +    if (!s->dsd_ref)
> +        return AVERROR(ENOMEM);
> +    s->dsdctx = (DSDContext*)s->dsd_ref->data;
> +    s->dsd_channels = channels;
> +
> +    for (i = 0; i < channels; i++)
> +        memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
> +
> +    return 0;
> +}
> +
>  #if HAVE_THREADS
>  static int init_thread_copy(AVCodecContext *avctx)
>  {
> @@ -1008,6 +1038,17 @@ static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
>              return ret;
>      }
>  
> +    av_buffer_unref(&fdst->dsd_ref);
> +    fdst->dsdctx = NULL;
> +    fdst->dsd_channels = 0;
> +    if (fsrc->dsd_ref) {
> +        fdst->dsd_ref = av_buffer_ref(fsrc->dsd_ref);
> +        if (!fdst->dsd_ref)
> +            return AVERROR(ENOMEM);
> +        fdst->dsdctx = (DSDContext*)fdst->dsd_ref->data;
> +        fdst->dsd_channels = fsrc->dsd_channels;
> +    }
> +
>      return 0;
>  }
>  #endif
> @@ -1025,15 +1066,9 @@ static av_cold int wavpack_decode_init(AVCodecContext *avctx)
>      s->curr_frame.f = av_frame_alloc();
>      s->prev_frame.f = av_frame_alloc();
>  
> -    // the DSD to PCM context is shared (and used serially) between all decoding threads
> -    s->dsdctx = av_calloc(avctx->channels, sizeof(DSDContext));
> -
> -    if (!s->curr_frame.f || !s->prev_frame.f || !s->dsdctx)
> +    if (!s->curr_frame.f || !s->prev_frame.f)
>          return AVERROR(ENOMEM);
>  
> -    for (int i = 0; i < avctx->channels; i++)
> -        memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
> -
>      ff_init_dsd_data();
>  
>      return 0;
> @@ -1053,8 +1088,7 @@ static av_cold int wavpack_decode_end(AVCodecContext *avctx)
>      ff_thread_release_buffer(avctx, &s->prev_frame);
>      av_frame_free(&s->prev_frame.f);
>  
> -    if (!avctx->internal->is_copy)
> -        av_freep(&s->dsdctx);
> +    av_buffer_unref(&s->dsd_ref);
>  
>      return 0;
>  }
> @@ -1065,6 +1099,7 @@ static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
>      WavpackContext *wc = avctx->priv_data;
>      WavpackFrameContext *s;
>      GetByteContext gb;
> +    enum AVSampleFormat sample_fmt;
>      void *samples_l = NULL, *samples_r = NULL;
>      int ret;
>      int got_terms   = 0, got_weights = 0, got_samples = 0,
> @@ -1102,7 +1137,15 @@ static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
>          return AVERROR_INVALIDDATA;
>      }
>      s->frame_flags = bytestream2_get_le32(&gb);
> -    bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
> +
> +    if (s->frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA))
> +        sample_fmt = AV_SAMPLE_FMT_FLTP;
> +    else if ((s->frame_flags & 0x03) <= 1)
> +        sample_fmt = AV_SAMPLE_FMT_S16P;
> +    else
> +        sample_fmt          = AV_SAMPLE_FMT_S32P;
> +
> +    bpp            = av_get_bytes_per_sample(sample_fmt);
>      orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
>      multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
>  
> @@ -1436,11 +1479,11 @@ static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
>              av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
>              return AVERROR_INVALIDDATA;
>          }
> -        if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLTP) {
> +        if (!got_float && sample_fmt == AV_SAMPLE_FMT_FLTP) {
>              av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
>              return AVERROR_INVALIDDATA;
>          }
> -        if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLTP) {
> +        if (s->got_extra_bits && sample_fmt != AV_SAMPLE_FMT_FLTP) {
>              const int size   = get_bits_left(&s->gb_extra_bits);
>              const int wanted = s->samples * s->extra_bits << s->stereo_in;
>              if (size < wanted) {
> @@ -1462,27 +1505,54 @@ static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
>      }
>  
>      if (!wc->ch_offset) {
> +        int      new_channels = avctx->channels;
> +        uint64_t new_chmask   = avctx->channel_layout;
> +        int new_samplerate;
>          int sr = (s->frame_flags >> 23) & 0xf;
>          if (sr == 0xf) {
>              if (!sample_rate) {
>                  av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
>                  return AVERROR_INVALIDDATA;
>              }
> -            avctx->sample_rate = sample_rate * rate_x;
> +            new_samplerate = sample_rate * rate_x;
>          } else
> -            avctx->sample_rate = wv_rates[sr] * rate_x;
> +            new_samplerate = wv_rates[sr] * rate_x;
>  
>          if (multiblock) {
>              if (chan)
> -                avctx->channels = chan;
> +                new_channels = chan;
>              if (chmask)
> -                avctx->channel_layout = chmask;
> +                new_chmask = chmask;
>          } else {
> -            avctx->channels       = s->stereo ? 2 : 1;
> -            avctx->channel_layout = s->stereo ? AV_CH_LAYOUT_STEREO :
> -                                                AV_CH_LAYOUT_MONO;
> +            new_channels = s->stereo ? 2 : 1;
> +            new_chmask   = s->stereo ? AV_CH_LAYOUT_STEREO :
> +                                       AV_CH_LAYOUT_MONO;
> +        }
> +
> +        if (new_chmask &&
> +            av_get_channel_layout_nb_channels(new_chmask) != new_channels) {
> +            av_log(avctx, AV_LOG_ERROR, "Channel mask does not match the channel count\n");
> +            return AVERROR_INVALIDDATA;
>          }
>  
> +        /* clear DSD state if stream properties change */
> +        if (new_channels   != wc->dsd_channels      ||
> +            new_chmask     != avctx->channel_layout ||
> +            new_samplerate != avctx->sample_rate    ||
> +            !!got_dsd      != !!wc->dsdctx) {
> +            ret = wv_dsd_reset(wc, got_dsd ? new_channels : 0);
> +            if (ret < 0) {
> +                av_log(avctx, AV_LOG_ERROR, "Error reinitializing the DSD context\n");
> +                return ret;
> +            }
> +            ff_thread_release_buffer(avctx, &wc->curr_frame);
> +        }
> +        avctx->channels            = new_channels;
> +        avctx->channel_layout      = new_chmask;
> +        avctx->sample_rate         = new_samplerate;
> +        avctx->sample_fmt          = sample_fmt;
> +        avctx->bits_per_raw_sample = orig_bpp;
> +
>          ff_thread_release_buffer(avctx, &wc->prev_frame);
>          FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);
>  
> @@ -1546,10 +1616,7 @@ static void wavpack_decode_flush(AVCodecContext *avctx)
>  {
>      WavpackContext *s = avctx->priv_data;
>  
> -    if (!avctx->internal->is_copy) {
> -        for (int i = 0; i < avctx->channels; i++)
> -            memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
> -    }
> +    wv_dsd_reset(s, 0);
>  }
>  
>  static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
> @@ -1590,15 +1657,6 @@ static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
>  
>      s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
>  
> -    if (frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA)) {
> -        avctx->sample_fmt = AV_SAMPLE_FMT_FLTP;
> -    } else if ((frame_flags & 0x03) <= 1) {
> -        avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
> -    } else {
> -        avctx->sample_fmt          = AV_SAMPLE_FMT_S32P;
> -        avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
> -    }
> -
>      while (buf_size > WV_HEADER_SIZE) {
>          frame_size = AV_RL32(buf + 4) - 12;
>          buf       += 20;

I was working on implementing this myself, but this is a better solution. I did not use AVBufferRef, but achieved a similar
thing with multiple pointers back to the initial context (one for the DSD context and one for the DSD channel count).

I have tested this with my reference files, and also a new one that artificially changes the channel configuration in DSD mode.
That file crashed the previous version (even without Matroska) and now works fine. Good job!
diff mbox series

Patch

diff --git a/libavcodec/wavpack.c b/libavcodec/wavpack.c
index b27262b94e..9cc4104dd0 100644
--- a/libavcodec/wavpack.c
+++ b/libavcodec/wavpack.c
@@ -20,6 +20,7 @@ 
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
+#include "libavutil/buffer.h"
 #include "libavutil/channel_layout.h"
 
 #define BITSTREAM_READER_LE
@@ -109,7 +110,10 @@  typedef struct WavpackContext {
     AVFrame *frame;
     ThreadFrame curr_frame, prev_frame;
     Modulation modulation;
+
+    AVBufferRef *dsd_ref;
     DSDContext *dsdctx;
+    int dsd_channels;
 } WavpackContext;
 
 #define LEVEL_DECAY(a)  (((a) + 0x80) >> 8)
@@ -978,6 +982,32 @@  static av_cold int wv_alloc_frame_context(WavpackContext *c)
     return 0;
 }
 
+static int wv_dsd_reset(WavpackContext *s, int channels)
+{
+    int i;
+
+    s->dsdctx = NULL;
+    s->dsd_channels = 0;
+    av_buffer_unref(&s->dsd_ref);
+
+    if (!channels)
+        return 0;
+
+    if (channels > INT_MAX / sizeof(*s->dsdctx))
+        return AVERROR(EINVAL);
+
+    s->dsd_ref = av_buffer_allocz(channels * sizeof(*s->dsdctx));
+    if (!s->dsd_ref)
+        return AVERROR(ENOMEM);
+    s->dsdctx = (DSDContext*)s->dsd_ref->data;
+    s->dsd_channels = channels;
+
+    for (i = 0; i < channels; i++)
+        memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
+
+    return 0;
+}
+
 #if HAVE_THREADS
 static int init_thread_copy(AVCodecContext *avctx)
 {
@@ -1008,6 +1038,17 @@  static int update_thread_context(AVCodecContext *dst, const AVCodecContext *src)
             return ret;
     }
 
+    av_buffer_unref(&fdst->dsd_ref);
+    fdst->dsdctx = NULL;
+    fdst->dsd_channels = 0;
+    if (fsrc->dsd_ref) {
+        fdst->dsd_ref = av_buffer_ref(fsrc->dsd_ref);
+        if (!fdst->dsd_ref)
+            return AVERROR(ENOMEM);
+        fdst->dsdctx = (DSDContext*)fdst->dsd_ref->data;
+        fdst->dsd_channels = fsrc->dsd_channels;
+    }
+
     return 0;
 }
 #endif
@@ -1025,15 +1066,9 @@  static av_cold int wavpack_decode_init(AVCodecContext *avctx)
     s->curr_frame.f = av_frame_alloc();
     s->prev_frame.f = av_frame_alloc();
 
-    // the DSD to PCM context is shared (and used serially) between all decoding threads
-    s->dsdctx = av_calloc(avctx->channels, sizeof(DSDContext));
-
-    if (!s->curr_frame.f || !s->prev_frame.f || !s->dsdctx)
+    if (!s->curr_frame.f || !s->prev_frame.f)
         return AVERROR(ENOMEM);
 
-    for (int i = 0; i < avctx->channels; i++)
-        memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
-
     ff_init_dsd_data();
 
     return 0;
@@ -1053,8 +1088,7 @@  static av_cold int wavpack_decode_end(AVCodecContext *avctx)
     ff_thread_release_buffer(avctx, &s->prev_frame);
     av_frame_free(&s->prev_frame.f);
 
-    if (!avctx->internal->is_copy)
-        av_freep(&s->dsdctx);
+    av_buffer_unref(&s->dsd_ref);
 
     return 0;
 }
@@ -1065,6 +1099,7 @@  static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
     WavpackContext *wc = avctx->priv_data;
     WavpackFrameContext *s;
     GetByteContext gb;
+    enum AVSampleFormat sample_fmt;
     void *samples_l = NULL, *samples_r = NULL;
     int ret;
     int got_terms   = 0, got_weights = 0, got_samples = 0,
@@ -1102,7 +1137,15 @@  static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
         return AVERROR_INVALIDDATA;
     }
     s->frame_flags = bytestream2_get_le32(&gb);
-    bpp            = av_get_bytes_per_sample(avctx->sample_fmt);
+
+    if (s->frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA))
+        sample_fmt = AV_SAMPLE_FMT_FLTP;
+    else if ((s->frame_flags & 0x03) <= 1)
+        sample_fmt = AV_SAMPLE_FMT_S16P;
+    else
+        sample_fmt          = AV_SAMPLE_FMT_S32P;
+
+    bpp            = av_get_bytes_per_sample(sample_fmt);
     orig_bpp       = ((s->frame_flags & 0x03) + 1) << 3;
     multiblock     = (s->frame_flags & WV_SINGLE_BLOCK) != WV_SINGLE_BLOCK;
 
@@ -1436,11 +1479,11 @@  static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
             av_log(avctx, AV_LOG_ERROR, "Hybrid config not found\n");
             return AVERROR_INVALIDDATA;
         }
-        if (!got_float && avctx->sample_fmt == AV_SAMPLE_FMT_FLTP) {
+        if (!got_float && sample_fmt == AV_SAMPLE_FMT_FLTP) {
             av_log(avctx, AV_LOG_ERROR, "Float information not found\n");
             return AVERROR_INVALIDDATA;
         }
-        if (s->got_extra_bits && avctx->sample_fmt != AV_SAMPLE_FMT_FLTP) {
+        if (s->got_extra_bits && sample_fmt != AV_SAMPLE_FMT_FLTP) {
             const int size   = get_bits_left(&s->gb_extra_bits);
             const int wanted = s->samples * s->extra_bits << s->stereo_in;
             if (size < wanted) {
@@ -1462,27 +1505,54 @@  static int wavpack_decode_block(AVCodecContext *avctx, int block_no,
     }
 
     if (!wc->ch_offset) {
+        int      new_channels = avctx->channels;
+        uint64_t new_chmask   = avctx->channel_layout;
+        int new_samplerate;
         int sr = (s->frame_flags >> 23) & 0xf;
         if (sr == 0xf) {
             if (!sample_rate) {
                 av_log(avctx, AV_LOG_ERROR, "Custom sample rate missing.\n");
                 return AVERROR_INVALIDDATA;
             }
-            avctx->sample_rate = sample_rate * rate_x;
+            new_samplerate = sample_rate * rate_x;
         } else
-            avctx->sample_rate = wv_rates[sr] * rate_x;
+            new_samplerate = wv_rates[sr] * rate_x;
 
         if (multiblock) {
             if (chan)
-                avctx->channels = chan;
+                new_channels = chan;
             if (chmask)
-                avctx->channel_layout = chmask;
+                new_chmask = chmask;
         } else {
-            avctx->channels       = s->stereo ? 2 : 1;
-            avctx->channel_layout = s->stereo ? AV_CH_LAYOUT_STEREO :
-                                                AV_CH_LAYOUT_MONO;
+            new_channels = s->stereo ? 2 : 1;
+            new_chmask   = s->stereo ? AV_CH_LAYOUT_STEREO :
+                                       AV_CH_LAYOUT_MONO;
+        }
+
+        if (new_chmask &&
+            av_get_channel_layout_nb_channels(new_chmask) != new_channels) {
+            av_log(avctx, AV_LOG_ERROR, "Channel mask does not match the channel count\n");
+            return AVERROR_INVALIDDATA;
         }
 
+        /* clear DSD state if stream properties change */
+        if (new_channels   != wc->dsd_channels      ||
+            new_chmask     != avctx->channel_layout ||
+            new_samplerate != avctx->sample_rate    ||
+            !!got_dsd      != !!wc->dsdctx) {
+            ret = wv_dsd_reset(wc, got_dsd ? new_channels : 0);
+            if (ret < 0) {
+                av_log(avctx, AV_LOG_ERROR, "Error reinitializing the DSD context\n");
+                return ret;
+            }
+            ff_thread_release_buffer(avctx, &wc->curr_frame);
+        }
+        avctx->channels            = new_channels;
+        avctx->channel_layout      = new_chmask;
+        avctx->sample_rate         = new_samplerate;
+        avctx->sample_fmt          = sample_fmt;
+        avctx->bits_per_raw_sample = orig_bpp;
+
         ff_thread_release_buffer(avctx, &wc->prev_frame);
         FFSWAP(ThreadFrame, wc->curr_frame, wc->prev_frame);
 
@@ -1546,10 +1616,7 @@  static void wavpack_decode_flush(AVCodecContext *avctx)
 {
     WavpackContext *s = avctx->priv_data;
 
-    if (!avctx->internal->is_copy) {
-        for (int i = 0; i < avctx->channels; i++)
-            memset(s->dsdctx[i].buf, 0x69, sizeof(s->dsdctx[i].buf));
-    }
+    wv_dsd_reset(s, 0);
 }
 
 static int dsd_channel(AVCodecContext *avctx, void *frmptr, int jobnr, int threadnr)
@@ -1590,15 +1657,6 @@  static int wavpack_decode_frame(AVCodecContext *avctx, void *data,
 
     s->modulation = (frame_flags & WV_DSD_DATA) ? MODULATION_DSD : MODULATION_PCM;
 
-    if (frame_flags & (WV_FLOAT_DATA | WV_DSD_DATA)) {
-        avctx->sample_fmt = AV_SAMPLE_FMT_FLTP;
-    } else if ((frame_flags & 0x03) <= 1) {
-        avctx->sample_fmt = AV_SAMPLE_FMT_S16P;
-    } else {
-        avctx->sample_fmt          = AV_SAMPLE_FMT_S32P;
-        avctx->bits_per_raw_sample = ((frame_flags & 0x03) + 1) << 3;
-    }
-
     while (buf_size > WV_HEADER_SIZE) {
         frame_size = AV_RL32(buf + 4) - 12;
         buf       += 20;