diff mbox series

[FFmpeg-devel,1/2] lavu/tx: rewrite internal code as a tree-based codelet constructor

Message ID Mtvl7g3--3-2@lynne.ee
State New
Series [FFmpeg-devel,1/2] lavu/tx: rewrite internal code as a tree-based codelet constructor

Checks

Context Check Description
andriy/make_x86 fail Make failed
andriy/make_ppc fail Make failed
andriy/make_aarch64_jetson fail Make failed

Commit Message

Lynne Jan. 21, 2022, 8:33 a.m. UTC
This commit rewrites the internal transform code into a constructor
that stitches transforms (codelets) together.
This allows transforms to reuse arbitrary parts of other
transforms, and allows transforms to be stacked onto one
another (such as a full iMDCT using a half-iMDCT, which in turn
uses an FFT). It also permits each step to be individually
replaced by assembly or a custom implementation (such as an ASIC).

Patch attached.
Subject: [PATCH 1/2] lavu/tx: rewrite internal code as a tree-based codelet
 constructor

This commit rewrites the internal transform code into a constructor
that stitches transforms (codelets) together.
This allows transforms to reuse arbitrary parts of other
transforms, and allows transforms to be stacked onto one
another (such as a full iMDCT using a half-iMDCT, which in turn
uses an FFT). It also permits each step to be individually
replaced by assembly or a custom implementation (such as an ASIC).
---
 libavutil/Makefile            |    4 +-
 libavutil/tx.c                |  483 +++++++++---
 libavutil/tx.h                |    3 +
 libavutil/tx_priv.h           |  180 +++--
 libavutil/tx_template.c       | 1356 ++++++++++++++++++++-------------
 libavutil/x86/tx_float.asm    |  111 +--
 libavutil/x86/tx_float_init.c |  170 +++--
 7 files changed, 1526 insertions(+), 781 deletions(-)

Comments

Lynne Jan. 21, 2022, 8:51 a.m. UTC | #1
21 Jan 2022, 09:33 by dev@lynne.ee:

> This commit rewrites the internal transform code into a constructor
> that stitches transforms (codelets).
> This allows for transforms to reuse arbitrary parts of other
> transforms, and allows transforms to be stacked onto one
> another (such as a full iMDCT using a half-iMDCT which in turn
> uses an FFT). It also permits for each step to be individually
> replaced by assembly or a custom implementation (such as an ASIC).
>
> Patch attached.
>

Forgot that I disabled double and int32 transforms to speed up
testing, reenabled locally and on my github tx_tree branch.
Also removed some inactive debug code.
https://github.com/cyanreg/FFmpeg/tree/tx_tree
Lynne Jan. 25, 2022, 10:46 a.m. UTC | #2
21 Jan 2022, 09:51 by dev@lynne.ee:

> 21 Jan 2022, 09:33 by dev@lynne.ee:
>
>> This commit rewrites the internal transform code into a constructor
>> that stitches transforms (codelets).
>> This allows for transforms to reuse arbitrary parts of other
>> transforms, and allows transforms to be stacked onto one
>> another (such as a full iMDCT using a half-iMDCT which in turn
>> uses an FFT). It also permits for each step to be individually
>> replaced by assembly or a custom implementation (such as an ASIC).
>>
>> Patch attached.
>>
>
> Forgot that I disabled double and int32 transforms to speed up
> testing, reenabled locally and on my github tx_tree branch.
> Also removed some inactive debug code.
> https://github.com/cyanreg/FFmpeg/tree/tx_tree
>

I fixed bugs and improved the code more, and I think it's ready
for merging now.
The rdft is no longer bound by any convention, and its
scale may be changed by the user, eliminating after-transform
multiplies that are used pretty much everywhere in our code.

If someone (looks at Paul) gives it a test or converts a filter,
it would be nice. I've only tested it on my synthetic benchmarks:
https://github.com/cyanreg/lavu_fft_test

I plan to push the patchset tomorrow unless there are comments.
Mostly done with the aarch64 SIMD; patch coming soon, hopefully.
Paul B Mahol Jan. 25, 2022, 5:17 p.m. UTC | #3
On Tue, Jan 25, 2022 at 11:46 AM Lynne <dev@lynne.ee> wrote:

> 21 Jan 2022, 09:51 by dev@lynne.ee:
>
> > 21 Jan 2022, 09:33 by dev@lynne.ee:
> >
> >> This commit rewrites the internal transform code into a constructor
> >> that stitches transforms (codelets).
> >> This allows for transforms to reuse arbitrary parts of other
> >> transforms, and allows transforms to be stacked onto one
> >> another (such as a full iMDCT using a half-iMDCT which in turn
> >> uses an FFT). It also permits for each step to be individually
> >> replaced by assembly or a custom implementation (such as an ASIC).
> >>
> >> Patch attached.
> >>
> >
> > Forgot that I disabled double and int32 transforms to speed up
> > testing, reenabled locally and on my github tx_tree branch.
> > Also removed some inactive debug code.
> > https://github.com/cyanreg/FFmpeg/tree/tx_tree
> >
>
> I fixed bugs and improved the code more, and I think it's ready
> for merging now.
> The rdft is no longer bound by any convention, and its
> scale may be changed by the user, eliminating after-transform
> multiplies that are used pretty much everywhere in our code.
>
> If someone (looks at Paul) gives it a test or converts a filter,
> it would be nice. I've only tested it on my synthetic benchmarks:
> https://github.com/cyanreg/lavu_fft_test


Will try it once it's applied. Thanks.

>
>
> I plan to push the patchset tomorrow unless there are comments.
> Mostly done with the aarch64's SIMD, patch coming soon, hopefully.
> _______________________________________________
> ffmpeg-devel mailing list
> ffmpeg-devel@ffmpeg.org
> https://ffmpeg.org/mailman/listinfo/ffmpeg-devel
>
> To unsubscribe, visit link above, or email
> ffmpeg-devel-request@ffmpeg.org with subject "unsubscribe".
>
Lynne Jan. 26, 2022, 3:31 a.m. UTC | #4
25 Jan 2022, 18:17 by onemda@gmail.com:

> On Tue, Jan 25, 2022 at 11:46 AM Lynne <dev@lynne.ee> wrote:
>
>> 21 Jan 2022, 09:51 by dev@lynne.ee:
>>
>> > 21 Jan 2022, 09:33 by dev@lynne.ee:
>> >
>> >> This commit rewrites the internal transform code into a constructor
>> >> that stitches transforms (codelets).
>> >> This allows for transforms to reuse arbitrary parts of other
>> >> transforms, and allows transforms to be stacked onto one
>> >> another (such as a full iMDCT using a half-iMDCT which in turn
>> >> uses an FFT). It also permits for each step to be individually
>> >> replaced by assembly or a custom implementation (such as an ASIC).
>> >>
>> >> Patch attached.
>> >>
>> >
>> > Forgot that I disabled double and int32 transforms to speed up
>> > testing, reenabled locally and on my github tx_tree branch.
>> > Also removed some inactive debug code.
>> > https://github.com/cyanreg/FFmpeg/tree/tx_tree
>> >
>>
>> I fixed bugs and improved the code more, and I think it's ready
>> for merging now.
>> The rdft is no longer bound by any convention, and its
>> scale may be changed by the user, eliminating after-transform
>> multiplies that are used pretty much everywhere in our code.
>>
>> If someone (looks at Paul) gives it a test or converts a filter,
>> it would be nice. I've only tested it on my synthetic benchmarks:
>> https://github.com/cyanreg/lavu_fft_test
>>
>
>
> Will try it once it's applied. Thanks.
>

Applied.
It's around 20% faster than lavc's rdft for power-of-two lengths.
Non-power-of-two lengths are partially SIMD'd, so they're usable too.
I'll SIMD the small O(n) rdft loop once I'm done with the NEON and
PFA SIMD. If you find bugs, ping me on IRC.

Patch

diff --git a/libavutil/Makefile b/libavutil/Makefile
index d17876df1a..22a7b15f61 100644
--- a/libavutil/Makefile
+++ b/libavutil/Makefile
@@ -170,8 +170,8 @@  OBJS = adler32.o                                                        \
        tea.o                                                            \
        tx.o                                                             \
        tx_float.o                                                       \
-       tx_double.o                                                      \
-       tx_int32.o                                                       \
+#       tx_double.o                                                      \
+#       tx_int32.o                                                       \
        video_enc_params.o                                               \
        film_grain_params.o                                              \
 
diff --git a/libavutil/tx.c b/libavutil/tx.c
index fa81ada2f1..28fe6c55b9 100644
--- a/libavutil/tx.c
+++ b/libavutil/tx.c
@@ -17,8 +17,9 @@ 
  */
 
 #include "tx_priv.h"
+#include "qsort.h"
 
-int ff_tx_type_is_mdct(enum AVTXType type)
+static av_always_inline int type_is_mdct(enum AVTXType type)
 {
     switch (type) {
     case AV_TX_FLOAT_MDCT:
@@ -42,22 +43,26 @@  static av_always_inline int mulinv(int n, int m)
 }
 
 /* Guaranteed to work for any n, m where gcd(n, m) == 1 */
-int ff_tx_gen_compound_mapping(AVTXContext *s)
+int ff_tx_gen_compound_mapping(AVTXContext *s, int n, int m)
 {
     int *in_map, *out_map;
-    const int n     = s->n;
-    const int m     = s->m;
-    const int inv   = s->inv;
-    const int len   = n*m;
-    const int m_inv = mulinv(m, n);
-    const int n_inv = mulinv(n, m);
-    const int mdct  = ff_tx_type_is_mdct(s->type);
-
-    if (!(s->pfatab = av_malloc(2*len*sizeof(*s->pfatab))))
+    const int inv = s->inv;
+    const int len = n*m;    /* Will not be equal to s->len for MDCTs */
+    const int mdct = type_is_mdct(s->type);
+    int m_inv, n_inv;
+
+    /* Make sure the numbers are coprime */
+    if (av_gcd(n, m) != 1)
+        return AVERROR(EINVAL);
+
+    m_inv = mulinv(m, n);
+    n_inv = mulinv(n, m);
+
+    if (!(s->map = av_malloc(2*len*sizeof(*s->map))))
         return AVERROR(ENOMEM);
 
-    in_map  = s->pfatab;
-    out_map = s->pfatab + n*m;
+    in_map  = s->map;
+    out_map = s->map + len;
 
     /* Ruritanian map for input, CRT map for output, can be swapped */
     for (int j = 0; j < m; j++) {
@@ -92,48 +97,50 @@  int ff_tx_gen_compound_mapping(AVTXContext *s)
     return 0;
 }
 
-static inline int split_radix_permutation(int i, int m, int inverse)
+static inline int split_radix_permutation(int i, int len, int inv)
 {
-    m >>= 1;
-    if (m <= 1)
+    len >>= 1;
+    if (len <= 1)
         return i & 1;
-    if (!(i & m))
-        return split_radix_permutation(i, m, inverse) * 2;
-    m >>= 1;
-    return split_radix_permutation(i, m, inverse) * 4 + 1 - 2*(!(i & m) ^ inverse);
+    if (!(i & len))
+        return split_radix_permutation(i, len, inv) * 2;
+    len >>= 1;
+    return split_radix_permutation(i, len, inv) * 4 + 1 - 2*(!(i & len) ^ inv);
 }
 
 int ff_tx_gen_ptwo_revtab(AVTXContext *s, int invert_lookup)
 {
-    const int m = s->m, inv = s->inv;
+    int len = s->len;
 
-    if (!(s->revtab = av_malloc(s->m*sizeof(*s->revtab))))
-        return AVERROR(ENOMEM);
-    if (!(s->revtab_c = av_malloc(m*sizeof(*s->revtab_c))))
+    if (!(s->map = av_malloc(len*sizeof(*s->map))))
         return AVERROR(ENOMEM);
 
-    /* Default */
-    for (int i = 0; i < m; i++) {
-        int k = -split_radix_permutation(i, m, inv) & (m - 1);
-        if (invert_lookup)
-            s->revtab[i] = s->revtab_c[i] = k;
-        else
-            s->revtab[i] = s->revtab_c[k] = i;
+    if (invert_lookup) {
+        for (int i = 0; i < s->len; i++)
+            s->map[i] = -split_radix_permutation(i, len, s->inv) & (len - 1);
+    } else {
+        for (int i = 0; i < s->len; i++)
+            s->map[-split_radix_permutation(i, len, s->inv) & (len - 1)] = i;
     }
 
     return 0;
 }
 
-int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s, int *revtab)
+int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s)
 {
-    int nb_inplace_idx = 0;
+    int *chain_map, chain_map_idx = 0, len = s->len;
 
-    if (!(s->inplace_idx = av_malloc(s->m*sizeof(*s->inplace_idx))))
+    if (!(s->map = av_malloc(2*len*sizeof(*s->map))))
         return AVERROR(ENOMEM);
 
+    chain_map = &s->map[s->len];
+
+    for (int i = 0; i < len; i++)
+        s->map[-split_radix_permutation(i, len, s->inv) & (len - 1)] = i;
+
     /* The first coefficient is always already in-place */
-    for (int src = 1; src < s->m; src++) {
-        int dst = revtab[src];
+    for (int src = 1; src < s->len; src++) {
+        int dst = s->map[src];
         int found = 0;
 
         if (dst <= src)
@@ -143,48 +150,54 @@  int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s, int *revtab)
          * and if so, skips it, since to fully permute a loop we must only
          * enter it once. */
         do {
-            for (int j = 0; j < nb_inplace_idx; j++) {
-                if (dst == s->inplace_idx[j]) {
+            for (int j = 0; j < chain_map_idx; j++) {
+                if (dst == chain_map[j]) {
                     found = 1;
                     break;
                 }
             }
-            dst = revtab[dst];
+            dst = s->map[dst];
         } while (dst != src && !found);
 
         if (!found)
-            s->inplace_idx[nb_inplace_idx++] = src;
+            chain_map[chain_map_idx++] = src;
     }
 
-    s->inplace_idx[nb_inplace_idx++] = 0;
+    chain_map[chain_map_idx++] = 0;
 
     return 0;
 }
 
 static void parity_revtab_generator(int *revtab, int n, int inv, int offset,
                                     int is_dual, int dual_high, int len,
-                                    int basis, int dual_stride)
+                                    int basis, int dual_stride, int inv_lookup)
 {
     len >>= 1;
 
     if (len <= basis) {
-        int k1, k2, *even, *odd, stride;
+        int k1, k2, stride, even_idx, odd_idx;
 
         is_dual = is_dual && dual_stride;
         dual_high = is_dual & dual_high;
         stride = is_dual ? FFMIN(dual_stride, len) : 0;
 
-        even = &revtab[offset + dual_high*(stride - 2*len)];
-        odd  = &even[len + (is_dual && !dual_high)*len + dual_high*len];
+        even_idx = offset + dual_high*(stride - 2*len);
+        odd_idx  = even_idx + len + (is_dual && !dual_high)*len + dual_high*len;
+
 
         for (int i = 0; i < len; i++) {
             k1 = -split_radix_permutation(offset + i*2 + 0, n, inv) & (n - 1);
             k2 = -split_radix_permutation(offset + i*2 + 1, n, inv) & (n - 1);
-            *even++ = k1;
-            *odd++  = k2;
+            if (inv_lookup) {
+                revtab[even_idx++] = k1;
+                revtab[odd_idx++]  = k2;
+            } else {
+                revtab[k1] = even_idx++;
+                revtab[k2] = odd_idx++;
+            }
             if (stride && !((i + 1) % stride)) {
-                even += stride;
-                odd  += stride;
+                even_idx += stride;
+                odd_idx  += stride;
             }
         }
 
@@ -192,22 +205,52 @@  static void parity_revtab_generator(int *revtab, int n, int inv, int offset,
     }
 
     parity_revtab_generator(revtab, n, inv, offset,
-                            0, 0, len >> 0, basis, dual_stride);
+                            0, 0, len >> 0, basis, dual_stride, inv_lookup);
     parity_revtab_generator(revtab, n, inv, offset + (len >> 0),
-                            1, 0, len >> 1, basis, dual_stride);
+                            1, 0, len >> 1, basis, dual_stride, inv_lookup);
     parity_revtab_generator(revtab, n, inv, offset + (len >> 0) + (len >> 1),
-                            1, 1, len >> 1, basis, dual_stride);
+                            1, 1, len >> 1, basis, dual_stride, inv_lookup);
 }
 
-void ff_tx_gen_split_radix_parity_revtab(int *revtab, int len, int inv,
-                                         int basis, int dual_stride)
+int ff_tx_gen_split_radix_parity_revtab(AVTXContext *s, int invert_lookup,
+                                        int basis, int dual_stride)
 {
+    int len = s->len;
+    int inv = s->inv;
+
+    if (!(s->map = av_mallocz(len*sizeof(*s->map))))
+        return AVERROR(ENOMEM);
+
     basis >>= 1;
     if (len < basis)
-        return;
+        return AVERROR(EINVAL);
+
     av_assert0(!dual_stride || !(dual_stride & (dual_stride - 1)));
     av_assert0(dual_stride <= basis);
-    parity_revtab_generator(revtab, len, inv, 0, 0, 0, len, basis, dual_stride);
+    parity_revtab_generator(s->map, len, inv, 0, 0, 0, len,
+                            basis, dual_stride, invert_lookup);
+
+    return 0;
+}
+
+static void reset_ctx(AVTXContext *s)
+{
+    if (!s)
+        return;
+
+    if (s->sub)
+        for (int i = 0; i < s->nb_sub; i++)
+            reset_ctx(&s->sub[i]);
+
+    if (s->cd_self->uninit)
+        s->cd_self->uninit(s);
+
+    av_freep(&s->sub);
+    av_freep(&s->map);
+    av_freep(&s->exp);
+    av_freep(&s->tmp);
+
+    memset(s, 0, sizeof(*s));
 }
 
 av_cold void av_tx_uninit(AVTXContext **ctx)
@@ -215,53 +258,303 @@  av_cold void av_tx_uninit(AVTXContext **ctx)
     if (!(*ctx))
         return;
 
-    av_free((*ctx)->pfatab);
-    av_free((*ctx)->exptab);
-    av_free((*ctx)->revtab);
-    av_free((*ctx)->revtab_c);
-    av_free((*ctx)->inplace_idx);
-    av_free((*ctx)->tmp);
-
+    reset_ctx(*ctx);
     av_freep(ctx);
 }
 
+/* Null transform when the length is 1 */
+static void ff_tx_null(AVTXContext *s, void *_out, void *_in, ptrdiff_t stride)
+{
+    memcpy(_out, _in, stride);
+}
+
+static const FFTXCodelet ff_tx_null_def = {
+    .name       = "null",
+    .function   = ff_tx_null,
+    .type       = TX_TYPE_ANY,
+    .flags      = AV_TX_UNALIGNED | FF_TX_ALIGNED |
+                  FF_TX_OUT_OF_PLACE | AV_TX_INPLACE,
+    .factors[0] = TX_FACTOR_ANY,
+    .min_len    = 1,
+    .max_len    = 1,
+    .init       = NULL,
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_MAX,
+};
+
+static const FFTXCodelet * const ff_tx_null_list[] = {
+    &ff_tx_null_def,
+};
+
+typedef struct TXCodeletMatch {
+    const FFTXCodelet *cd;
+    int prio;
+} TXCodeletMatch;
+
+static int cmp_matches(TXCodeletMatch *a, TXCodeletMatch *b)
+{
+    int diff = FFDIFFSIGN(b->prio, a->prio);
+    if (!diff)
+        return FFDIFFSIGN(b->cd->factors[0], a->cd->factors[0]);
+    return diff;
+}
+
+static void print_flags(uint64_t flags)
+{
+    av_log(NULL, AV_LOG_WARNING, "Flags: %s%s%s%s%s%s\n",
+           flags & AV_TX_INPLACE ? "inplace+" : "",
+           flags & FF_TX_OUT_OF_PLACE ? "out_of_place+" : "",
+           flags & FF_TX_ALIGNED ? "aligned+" : "",
+           flags & AV_TX_UNALIGNED ? "unaligned+" : "",
+           flags & FF_TX_PRESHUFFLE ? "preshuffle+" : "",
+           flags & AV_TX_FULL_IMDCT ? "full_imdct+" : "");
+}
+
+/* We want all factors to completely cover the length */
+static inline int check_cd_factors(const FFTXCodelet *cd, int len)
+{
+    int all_flag = 0;
+
+    for (int i = 0; i < TX_MAX_SUB; i++) {
+        int factor = cd->factors[i];
+
+        /* Conditions satisfied */
+        if (len == 1)
+            return 1;
+
+        /* No more factors */
+        if (!factor) {
+            break;
+        } else if (factor == TX_FACTOR_ANY) {
+            all_flag = 1;
+            continue;
+        }
+
+        if (factor == 2) { /* Fast path */
+            int bits_2 = ff_ctz(len);
+            if (!bits_2)
+                return 0; /* Factor not supported */
+
+            len >>= bits_2;
+        } else {
+            int res = len % factor;
+            if (res)
+                return 0; /* Factor not supported */
+
+            while (!res) {
+                len /= factor;
+                res = len % factor;
+            }
+        }
+    }
+
+    return all_flag || (len == 1);
+}
+
+av_cold int ff_tx_init_subtx(AVTXContext *s, enum AVTXType type,
+                             uint64_t flags, FFTXCodeletOptions *opts,
+                             int len, int inv, const void *scale)
+{
+    int ret = 0;
+    AVTXContext *sub = NULL;
+    TXCodeletMatch *cd_tmp, *cd_matches = NULL;
+    unsigned int cd_matches_size = 0;
+    int nb_cd_matches = 0;
+
+    /* Array of all compiled codelet lists. Order is irrelevant. */
+    const FFTXCodelet * const * const codelet_list[] = {
+        ff_tx_codelet_list_float_c,
+        ff_tx_codelet_list_float_x86,
+        ff_tx_null_list,
+    };
+    int codelet_list_num = FF_ARRAY_ELEMS(codelet_list);
+
+    /* We still accept functions marked with SLOW, even if the CPU is
+     * marked with the same flag, but we give them lower priority. */
+    const int cpu_flags = av_get_cpu_flags();
+    const int slow_mask = AV_CPU_FLAG_SSE2SLOW | AV_CPU_FLAG_SSE3SLOW  |
+                          AV_CPU_FLAG_ATOM     | AV_CPU_FLAG_SSSE3SLOW |
+                          AV_CPU_FLAG_AVXSLOW  | AV_CPU_FLAG_SLOW_GATHER;
+
+    /* Flags the transform wants */
+    uint64_t req_flags = flags;
+    int penalize_unaligned = 0;
+
+    /* Unaligned codelets are compatible with the aligned flag, with a slight
+     * penalty */
+    if (req_flags & FF_TX_ALIGNED) {
+        req_flags |= AV_TX_UNALIGNED;
+        penalize_unaligned = 1;
+    }
+
+    /* If either flag is set, both are okay, so don't check for an exact match */
+    if ((req_flags & AV_TX_INPLACE) && (req_flags & FF_TX_OUT_OF_PLACE))
+        req_flags &= ~(AV_TX_INPLACE | FF_TX_OUT_OF_PLACE);
+    if ((req_flags & FF_TX_ALIGNED) && (req_flags & AV_TX_UNALIGNED))
+        req_flags &= ~(FF_TX_ALIGNED | AV_TX_UNALIGNED);
+
+    /* Flags the codelet may require to be present */
+    uint64_t inv_req_mask = AV_TX_FULL_IMDCT | FF_TX_PRESHUFFLE;
+
+//    print_flags(req_flags);
+
+    /* Loop through all codelets in all codelet lists to find matches
+     * to the requirements */
+    while (codelet_list_num--) {
+        const FFTXCodelet * const * list = codelet_list[codelet_list_num];
+        const FFTXCodelet *cd = NULL;
+
+        while ((cd = *list++)) {
+            /* Check if the type matches */
+            if (cd->type != TX_TYPE_ANY && type != cd->type)
+                continue;
+
+            /* Check direction for non-orthogonal codelets */
+            if (((cd->flags & FF_TX_FORWARD_ONLY) && inv) ||
+                ((cd->flags & (FF_TX_INVERSE_ONLY | AV_TX_FULL_IMDCT)) && !inv))
+                continue;
+
+            /* Check if the requested flags match from both sides */
+            if (((req_flags    & cd->flags) != (req_flags)) ||
+                ((inv_req_mask & cd->flags) != (req_flags & inv_req_mask)))
+                continue;
+
+            /* Check if length is supported */
+            if ((len < cd->min_len) || (cd->max_len != -1 && (len > cd->max_len)))
+                continue;
+
+            /* Check if the CPU supports the required ISA */
+            if (!(!cd->cpu_flags || (cpu_flags & (cd->cpu_flags & ~slow_mask))))
+                continue;
+
+            /* Check for factors */
+            if (!check_cd_factors(cd, len))
+                continue;
+
+            /* Realloc array and append */
+            cd_tmp = av_fast_realloc(cd_matches, &cd_matches_size,
+                                     sizeof(*cd_tmp) * (nb_cd_matches + 1));
+            if (!cd_tmp) {
+                av_free(cd_matches);
+                return AVERROR(ENOMEM);
+            }
+
+            cd_matches                     = cd_tmp;
+            cd_matches[nb_cd_matches].cd   = cd;
+            cd_matches[nb_cd_matches].prio = cd->prio;
+
+            /* If the CPU has a SLOW flag, and the instruction is also flagged
+             * as being slow for such, reduce its priority */
+            if ((cpu_flags & cd->cpu_flags) & slow_mask)
+                cd_matches[nb_cd_matches].prio -= 64;
+
+            /* Penalize unaligned functions if needed */
+            if ((cd->flags & AV_TX_UNALIGNED) && penalize_unaligned)
+                cd_matches[nb_cd_matches].prio -= 64;
+
+            /* Codelets for specific lengths are generally faster. */
+            if ((len == cd->min_len) && (len == cd->max_len))
+                cd_matches[nb_cd_matches].prio += 64;
+
+            nb_cd_matches++;
+        }
+    }
+
+    /* No matches found */
+    if (!nb_cd_matches)
+        return AVERROR(ENOSYS);
+
+    /* Sort the list */
+    AV_QSORT(cd_matches, nb_cd_matches, TXCodeletMatch, cmp_matches);
+
+    if (!s->sub)
+        s->sub = sub = av_mallocz(TX_MAX_SUB*sizeof(*sub));
+
+    /* Attempt to initialize each */
+    for (int i = 0; i < nb_cd_matches; i++) {
+        const FFTXCodelet *cd = cd_matches[i].cd;
+        AVTXContext *sctx = &s->sub[s->nb_sub];
+
+        sctx->len        = len;
+        sctx->inv        = inv;
+        sctx->type       = type;
+        sctx->flags      = flags;
+        sctx->cd_self    = cd;
+
+        s->fn[s->nb_sub] = cd->function;
+        s->cd[s->nb_sub] = cd;
+
+        ret = 0;
+        if (cd->init)
+            ret = cd->init(sctx, cd, flags, opts, len, inv, scale);
+
+        if (ret >= 0) {
+            s->nb_sub++;
+            goto end;
+        }
+
+        s->fn[s->nb_sub] = NULL;
+        s->cd[s->nb_sub] = NULL;
+
+        reset_ctx(sctx);
+        if (ret == AVERROR(ENOMEM))
+            break;
+    }
+
+    if (sub)
+        av_freep(&s->sub);
+
+    if (ret >= 0)
+        ret = AVERROR(ENOSYS);
+
+end:
+    av_free(cd_matches);
+    return ret;
+}
+
+static void print_tx_structure(AVTXContext *s, int depth)
+{
+    const FFTXCodelet *cd = s->cd_self;
+
+    for (int i = 0; i <= depth; i++)
+        av_log(NULL, AV_LOG_WARNING, "    ");
+    av_log(NULL, AV_LOG_WARNING, "↳ %s - %s, %ipt, %p\n", cd->name,
+           cd->type == TX_TYPE_ANY       ? "all"         :
+           cd->type == AV_TX_FLOAT_FFT   ? "fft_float"   :
+           cd->type == AV_TX_FLOAT_MDCT  ? "mdct_float"  :
+           cd->type == AV_TX_DOUBLE_FFT  ? "fft_double"  :
+           cd->type == AV_TX_DOUBLE_MDCT ? "mdct_double" :
+           cd->type == AV_TX_INT32_FFT   ? "fft_int32"   :
+           cd->type == AV_TX_INT32_MDCT  ? "mdct_int32"  : "unknown",
+           s->len, cd->function);
+
+    for (int i = 0; i < s->nb_sub; i++)
+        print_tx_structure(&s->sub[i], depth + 1);
+}
+
 av_cold int av_tx_init(AVTXContext **ctx, av_tx_fn *tx, enum AVTXType type,
                        int inv, int len, const void *scale, uint64_t flags)
 {
-    int err;
-    AVTXContext *s = av_mallocz(sizeof(*s));
-    if (!s)
-        return AVERROR(ENOMEM);
+    int ret;
+    AVTXContext tmp = { 0 };
 
-    switch (type) {
-    case AV_TX_FLOAT_FFT:
-    case AV_TX_FLOAT_MDCT:
-        if ((err = ff_tx_init_mdct_fft_float(s, tx, type, inv, len, scale, flags)))
-            goto fail;
-        if (ARCH_X86)
-            ff_tx_init_float_x86(s, tx);
-        break;
-    case AV_TX_DOUBLE_FFT:
-    case AV_TX_DOUBLE_MDCT:
-        if ((err = ff_tx_init_mdct_fft_double(s, tx, type, inv, len, scale, flags)))
-            goto fail;
-        break;
-    case AV_TX_INT32_FFT:
-    case AV_TX_INT32_MDCT:
-        if ((err = ff_tx_init_mdct_fft_int32(s, tx, type, inv, len, scale, flags)))
-            goto fail;
-        break;
-    default:
-        err = AVERROR(EINVAL);
-        goto fail;
-    }
+    if (!len || type >= AV_TX_NB)
+        return AVERROR(EINVAL);
 
-    *ctx = s;
+    if (!(flags & AV_TX_UNALIGNED))
+        flags |= FF_TX_ALIGNED;
+    if (!(flags & AV_TX_INPLACE))
+        flags |= FF_TX_OUT_OF_PLACE;
 
-    return 0;
+    ret = ff_tx_init_subtx(&tmp, type, flags, NULL, len, inv, scale);
+    if (ret < 0)
+        return ret;
+
+    *ctx = &tmp.sub[0];
+    *tx  = tmp.fn[0];
+
+    av_log(NULL, AV_LOG_WARNING, "Transform tree:\n");
+    print_tx_structure(*ctx, 0);
 
-fail:
-    av_tx_uninit(&s);
-    *tx = NULL;
-    return err;
+    return ret;
 }
diff --git a/libavutil/tx.h b/libavutil/tx.h
index 55173810ee..4bc1478644 100644
--- a/libavutil/tx.h
+++ b/libavutil/tx.h
@@ -82,6 +82,9 @@  enum AVTXType {
      * Stride must be a non-zero multiple of sizeof(int32_t).
      */
     AV_TX_INT32_MDCT = 5,
+
+    /* Not part of the API, do not use */
+    AV_TX_NB,
 };
 
 /**
diff --git a/libavutil/tx_priv.h b/libavutil/tx_priv.h
index 63dc6bbe6d..a709e6973f 100644
--- a/libavutil/tx_priv.h
+++ b/libavutil/tx_priv.h
@@ -25,17 +25,26 @@ 
 #include "attributes.h"
 
 #ifdef TX_FLOAT
-#define TX_NAME(x) x ## _float
+#define TX_TAB(x) x ## _float
+#define TX_NAME(x) x ## _float_c
+#define TX_NAME_STR(x) x "_float_c"
+#define TX_TYPE(x) AV_TX_FLOAT_ ## x
 #define SCALE_TYPE float
 typedef float FFTSample;
 typedef AVComplexFloat FFTComplex;
 #elif defined(TX_DOUBLE)
-#define TX_NAME(x) x ## _double
+#define TX_TAB(x) x ## _double
+#define TX_NAME(x) x ## _double_c
+#define TX_NAME_STR(x) x "_double_c"
+#define TX_TYPE(x) AV_TX_DOUBLE_ ## x
 #define SCALE_TYPE double
 typedef double FFTSample;
 typedef AVComplexDouble FFTComplex;
 #elif defined(TX_INT32)
-#define TX_NAME(x) x ## _int32
+#define TX_TAB(x) x ## _int32
+#define TX_NAME(x) x ## _int32_c
+#define TX_NAME_STR(x) x "_int32_c"
+#define TX_TYPE(x) AV_TX_INT32_ ## x
 #define SCALE_TYPE float
 typedef int32_t FFTSample;
 typedef AVComplexInt32 FFTComplex;
@@ -103,53 +112,130 @@  typedef void FFTComplex;
 #define CMUL3(c, a, b)                                                         \
     CMUL((c).re, (c).im, (a).re, (a).im, (b).re, (b).im)
 
-#define COSTABLE(size)                                                         \
-    DECLARE_ALIGNED(32, FFTSample, TX_NAME(ff_cos_##size))[size/4 + 1]
+/* Codelet flags, used to pick codelets. Must be a superset of enum AVTXFlags,
+ * but if it runs out of bits, it can be made separate. */
+typedef enum FFTXCodeletFlags {
+    FF_TX_OUT_OF_PLACE  = (1ULL << 63), /* Can be OR'd with AV_TX_INPLACE             */
+    FF_TX_ALIGNED       = (1ULL << 62), /* Cannot be OR'd with AV_TX_UNALIGNED        */
+    FF_TX_PRESHUFFLE    = (1ULL << 61), /* Codelet expects permuted coeffs            */
+    FF_TX_INVERSE_ONLY  = (1ULL << 60), /* For non-orthogonal inverse-only transforms */
+    FF_TX_FORWARD_ONLY  = (1ULL << 59), /* For non-orthogonal forward-only transforms */
+} FFTXCodeletFlags;
+
+typedef enum FFTXCodeletPriority {
+    FF_TX_PRIO_BASE = 0,              /* Baseline priority */
+
+    /* For SIMD, set prio to the register size in bits. */
+
+    FF_TX_PRIO_MIN          = -131072, /* For naive implementations */
+    FF_TX_PRIO_MAX          =  32768,  /* For custom implementations/ASICs */
+} FFTXCodeletPriority;
+
+/* Codelet options */
+typedef struct FFTXCodeletOptions {
+    int invert_lookup;     /* If codelet is flagged as FF_TX_CODELET_PRESHUFFLE,
+                              invert the lookup direction for the map generated */
+} FFTXCodeletOptions;
+
+/* Maximum amount of subtransform functions, subtransforms and factors. Arbitrary. */
+#define TX_MAX_SUB 4
+
+typedef struct FFTXCodelet {
+    const char       *name;                    /* Codelet name, for debugging */
+    av_tx_fn          function;                /* Codelet function, != NULL */
+    enum AVTXType     type;                    /* Type of codelet transform */
+#define TX_TYPE_ANY INT32_MAX   /* Special type to allow all types */
+
+    uint64_t          flags;                   /* A combination of AVTXFlags
+                                                * and FFTXCodeletFlags flags
+                                                * to describe the codelet. */
+
+    int               factors[TX_MAX_SUB];     /* Length factors */
+#define TX_FACTOR_ANY -1        /* When used alone, signals that the codelet
+                                 * supports all factors. Otherwise, if other
+                                 * factors are present, it signals that whatever
+                                 * remains will be supported, as long as the
+                                 * other factors are a component of the length */
+
+    int               min_len;                 /* Minimum length of transform, must be >= 1 */
+    int               max_len;                 /* Maximum length of transform */
+#define TX_LEN_UNLIMITED -1     /* Special length value to permit arbitrarily large transforms */
+
+    int (*init)(AVTXContext *s,                /* Callback for current context initialization. */
+                const struct FFTXCodelet *cd,  /* May be NULL */
+                uint64_t flags,
+                FFTXCodeletOptions *opts,
+                int len, int inv,
+                const void *scale);
+
+    int (*uninit)(AVTXContext *s);             /* Callback for uninitialization. Can be NULL. */
+
+    int cpu_flags;                             /* CPU flags. If any negative flags
+                                                * (e.g. SLOW) are present, the codelet
+                                                * will not be picked.
+                                                * 0x0 signals a C codelet */
+#define FF_TX_CPU_FLAGS_ALL 0x0 /* Special CPU flag for C */
+
+    int prio;                                  /* < 0 = least, 0 = no pref, > 0 = prefer */
+} FFTXCodelet;
 
-/* Used by asm, reorder with care */
 struct AVTXContext {
-    int n;              /* Non-power-of-two part */
-    int m;              /* Power-of-two part */
-    int inv;            /* Is inverse */
-    int type;           /* Type */
-    uint64_t flags;     /* Flags */
-    double scale;       /* Scale */
-
-    FFTComplex *exptab; /* MDCT exptab */
-    FFTComplex    *tmp; /* Temporary buffer needed for all compound transforms */
-    int        *pfatab; /* Input/Output mapping for compound transforms */
-    int        *revtab; /* Input mapping for power of two transforms */
-    int   *inplace_idx; /* Required indices to revtab for in-place transforms */
-
-    int      *revtab_c; /* Revtab for only the C transforms, needed because
-                         * checkasm makes us reuse the same context. */
-
-    av_tx_fn    top_tx; /* Used for computing transforms derived from other
-                         * transforms, like full-length iMDCTs and RDFTs.
-                         * NOTE: Do NOT use this to mix assembly with C code. */
+    /* Fields the root transform and subtransforms use or may use.
+     * NOTE: This section is used by assembly, do not reorder or change */
+    int                len;             /* Length of the transform */
+    int                inv;             /* If transform is inverse */
+    int               *map;             /* Lookup table(s) */
+    FFTComplex        *exp;             /* Any non-pre-baked multiplication factors needed */
+    FFTComplex        *tmp;             /* Temporary buffer, if needed */
+
+    AVTXContext       *sub;             /* Subtransform context(s), if needed */
+    av_tx_fn           fn[TX_MAX_SUB];  /* Function(s) for the subtransforms */
+    int                nb_sub;          /* Number of subtransforms.
+                                         * The reason all of these are set here
+                                         * rather than in each separate context
+                                         * is to eliminate extra pointer
+                                         * dereferences. */
+
+    /* Fields mainly useful/applicable for the root transform or initialization.
+     * Fields below are not used by assembly code. */
+    const FFTXCodelet *cd[TX_MAX_SUB];  /* Subtransform codelets */
+    const FFTXCodelet *cd_self;         /* Codelet for the current context */
+    enum AVTXType      type;            /* Type of transform */
+    uint64_t           flags;           /* A combination of AVTXFlags
+                                           and FFTXCodeletFlags flags
+                                           used when creating */
+    float              scale_f;
+    double             scale_d;
+    void              *opaque;          /* Free to use by implementations */
 };
 
-/* Checks if type is an MDCT */
-int ff_tx_type_is_mdct(enum AVTXType type);
+/* Create a subtransform in the current context with the given parameters.
+ * The flags parameter from FFTXCodelet.init() should be preserved as far
+ * as possible.
+ * MUST be called from the init() callback of each codelet. */
+int ff_tx_init_subtx(AVTXContext *s, enum AVTXType type,
+                     uint64_t flags, FFTXCodeletOptions *opts,
+                     int len, int inv, const void *scale);
 
 /*
  * Generates the PFA permutation table into AVTXContext->pfatab. The end table
  * is appended to the start table.
  */
-int ff_tx_gen_compound_mapping(AVTXContext *s);
+int ff_tx_gen_compound_mapping(AVTXContext *s, int n, int m);
 
 /*
  * Generates a standard-ish (slightly modified) Split-Radix revtab into
- * AVTXContext->revtab
+ * AVTXContext->map. The invert_lookup argument changes how the mapping is
+ * applied: if set to 0, the map is applied as out[map[i]] = in[i]; if set
+ * to 1, as out[i] = in[map[i]].
  */
 int ff_tx_gen_ptwo_revtab(AVTXContext *s, int invert_lookup);
 
 /*
  * Generates an index into AVTXContext->inplace_idx that if followed in the
- * specific order,  allows the revtab to be done in-place. AVTXContext->revtab
+ * specific order, allows the revtab to be done in-place. AVTXContext->map
  * must already exist.
  */
-int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s, int *revtab);
+int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s);
 
 /*
  * This generates a parity-based revtab of length len and direction inv.
@@ -179,25 +265,17 @@  int ff_tx_gen_ptwo_inplace_revtab_idx(AVTXContext *s, int *revtab);
  *
  * If length is smaller than basis/2 this function will not do anything.
  */
-void ff_tx_gen_split_radix_parity_revtab(int *revtab, int len, int inv,
-                                         int basis, int dual_stride);
-
-/* Templated init functions */
-int ff_tx_init_mdct_fft_float(AVTXContext *s, av_tx_fn *tx,
-                              enum AVTXType type, int inv, int len,
-                              const void *scale, uint64_t flags);
-int ff_tx_init_mdct_fft_double(AVTXContext *s, av_tx_fn *tx,
-                               enum AVTXType type, int inv, int len,
-                               const void *scale, uint64_t flags);
-int ff_tx_init_mdct_fft_int32(AVTXContext *s, av_tx_fn *tx,
-                              enum AVTXType type, int inv, int len,
-                              const void *scale, uint64_t flags);
-
-typedef struct CosTabsInitOnce {
-    void (*func)(void);
-    AVOnce control;
-} CosTabsInitOnce;
-
-void ff_tx_init_float_x86(AVTXContext *s, av_tx_fn *tx);
+int ff_tx_gen_split_radix_parity_revtab(AVTXContext *s, int invert_lookup,
+                                        int basis, int dual_stride);
+
+void ff_tx_init_tabs_float (int len);
+extern const FFTXCodelet * const ff_tx_codelet_list_float_c       [];
+extern const FFTXCodelet * const ff_tx_codelet_list_float_x86     [];
+
+void ff_tx_init_tabs_double(int len);
+extern const FFTXCodelet * const ff_tx_codelet_list_double_c      [];
+
+void ff_tx_init_tabs_int32 (int len);
+extern const FFTXCodelet * const ff_tx_codelet_list_int32_c       [];
 
 #endif /* AVUTIL_TX_PRIV_H */
diff --git a/libavutil/tx_template.c b/libavutil/tx_template.c
index cad66a8bc0..bfd27799be 100644
--- a/libavutil/tx_template.c
+++ b/libavutil/tx_template.c
@@ -24,134 +24,160 @@ 
  * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
  */
 
-/* All costabs for a type are defined here */
-COSTABLE(16);
-COSTABLE(32);
-COSTABLE(64);
-COSTABLE(128);
-COSTABLE(256);
-COSTABLE(512);
-COSTABLE(1024);
-COSTABLE(2048);
-COSTABLE(4096);
-COSTABLE(8192);
-COSTABLE(16384);
-COSTABLE(32768);
-COSTABLE(65536);
-COSTABLE(131072);
-DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_53))[4];
-DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_7))[3];
-DECLARE_ALIGNED(32, FFTComplex, TX_NAME(ff_cos_9))[4];
-
-static FFTSample * const cos_tabs[18] = {
-    NULL,
-    NULL,
-    NULL,
-    NULL,
-    TX_NAME(ff_cos_16),
-    TX_NAME(ff_cos_32),
-    TX_NAME(ff_cos_64),
-    TX_NAME(ff_cos_128),
-    TX_NAME(ff_cos_256),
-    TX_NAME(ff_cos_512),
-    TX_NAME(ff_cos_1024),
-    TX_NAME(ff_cos_2048),
-    TX_NAME(ff_cos_4096),
-    TX_NAME(ff_cos_8192),
-    TX_NAME(ff_cos_16384),
-    TX_NAME(ff_cos_32768),
-    TX_NAME(ff_cos_65536),
-    TX_NAME(ff_cos_131072),
-};
-
-static av_always_inline void init_cos_tabs_idx(int index)
-{
-    int m = 1 << index;
-    double freq = 2*M_PI/m;
-    FFTSample *tab = cos_tabs[index];
-
-    for (int i = 0; i < m/4; i++)
-        *tab++ = RESCALE(cos(i*freq));
-
-    *tab = 0;
+#define TABLE_DEF(name, size) \
+    DECLARE_ALIGNED(32, FFTSample, TX_TAB(ff_tx_tab_ ##name))[size]
+
+#define SR_TABLE(len) \
+    TABLE_DEF(len, len/4 + 1)
+
+/* Power of two tables */
+SR_TABLE(8);
+SR_TABLE(16);
+SR_TABLE(32);
+SR_TABLE(64);
+SR_TABLE(128);
+SR_TABLE(256);
+SR_TABLE(512);
+SR_TABLE(1024);
+SR_TABLE(2048);
+SR_TABLE(4096);
+SR_TABLE(8192);
+SR_TABLE(16384);
+SR_TABLE(32768);
+SR_TABLE(65536);
+SR_TABLE(131072);
+
+/* Other factors' tables */
+TABLE_DEF(53, 8);
+TABLE_DEF( 7, 6);
+TABLE_DEF( 9, 8);
+
+typedef struct FFSRTabsInitOnce {
+    void (*func)(void);
+    AVOnce control;
+    int factors[4]; /* Must be sorted high -> low */
+} FFSRTabsInitOnce;
+
+#define INIT_FF_SR_TAB(len)                                        \
+static av_cold void TX_TAB(ff_tx_init_tab_ ##len)(void)            \
+{                                                                  \
+    double freq = 2*M_PI/len;                                      \
+    FFTSample *tab = TX_TAB(ff_tx_tab_ ##len);                     \
+                                                                   \
+    for (int i = 0; i < len/4; i++)                                \
+        *tab++ = RESCALE(cos(i*freq));                             \
+                                                                   \
+    *tab = 0;                                                      \
 }
 
-#define INIT_FF_COS_TABS_FUNC(index, size)                                     \
-static av_cold void init_cos_tabs_ ## size (void)                              \
-{                                                                              \
-    init_cos_tabs_idx(index);                                                  \
-}
+INIT_FF_SR_TAB(8)
+INIT_FF_SR_TAB(16)
+INIT_FF_SR_TAB(32)
+INIT_FF_SR_TAB(64)
+INIT_FF_SR_TAB(128)
+INIT_FF_SR_TAB(256)
+INIT_FF_SR_TAB(512)
+INIT_FF_SR_TAB(1024)
+INIT_FF_SR_TAB(2048)
+INIT_FF_SR_TAB(4096)
+INIT_FF_SR_TAB(8192)
+INIT_FF_SR_TAB(16384)
+INIT_FF_SR_TAB(32768)
+INIT_FF_SR_TAB(65536)
+INIT_FF_SR_TAB(131072)
+
+static FFSRTabsInitOnce sr_tabs_init_once[] = {
+    { TX_TAB(ff_tx_init_tab_8),      AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_16),     AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_32),     AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_64),     AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_128),    AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_256),    AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_512),    AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_1024),   AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_2048),   AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_4096),   AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_8192),   AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_16384),  AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_32768),  AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_65536),  AV_ONCE_INIT },
+    { TX_TAB(ff_tx_init_tab_131072), AV_ONCE_INIT },
+};
 
-INIT_FF_COS_TABS_FUNC(4, 16)
-INIT_FF_COS_TABS_FUNC(5, 32)
-INIT_FF_COS_TABS_FUNC(6, 64)
-INIT_FF_COS_TABS_FUNC(7, 128)
-INIT_FF_COS_TABS_FUNC(8, 256)
-INIT_FF_COS_TABS_FUNC(9, 512)
-INIT_FF_COS_TABS_FUNC(10, 1024)
-INIT_FF_COS_TABS_FUNC(11, 2048)
-INIT_FF_COS_TABS_FUNC(12, 4096)
-INIT_FF_COS_TABS_FUNC(13, 8192)
-INIT_FF_COS_TABS_FUNC(14, 16384)
-INIT_FF_COS_TABS_FUNC(15, 32768)
-INIT_FF_COS_TABS_FUNC(16, 65536)
-INIT_FF_COS_TABS_FUNC(17, 131072)
-
-static av_cold void ff_init_53_tabs(void)
+static av_cold void TX_TAB(ff_tx_init_tab_53)(void)
 {
-    TX_NAME(ff_cos_53)[0] = (FFTComplex){ RESCALE(cos(2 * M_PI / 12)), RESCALE(cos(2 * M_PI / 12)) };
-    TX_NAME(ff_cos_53)[1] = (FFTComplex){ RESCALE(cos(2 * M_PI /  6)), RESCALE(cos(2 * M_PI /  6)) };
-    TX_NAME(ff_cos_53)[2] = (FFTComplex){ RESCALE(cos(2 * M_PI /  5)), RESCALE(sin(2 * M_PI /  5)) };
-    TX_NAME(ff_cos_53)[3] = (FFTComplex){ RESCALE(cos(2 * M_PI / 10)), RESCALE(sin(2 * M_PI / 10)) };
+    TX_TAB(ff_tx_tab_53)[0] = RESCALE(cos(2 * M_PI / 12));
+    TX_TAB(ff_tx_tab_53)[1] = RESCALE(cos(2 * M_PI / 12));
+    TX_TAB(ff_tx_tab_53)[2] = RESCALE(cos(2 * M_PI /  6));
+    TX_TAB(ff_tx_tab_53)[3] = RESCALE(cos(2 * M_PI /  6));
+    TX_TAB(ff_tx_tab_53)[4] = RESCALE(cos(2 * M_PI /  5));
+    TX_TAB(ff_tx_tab_53)[5] = RESCALE(sin(2 * M_PI /  5));
+    TX_TAB(ff_tx_tab_53)[6] = RESCALE(cos(2 * M_PI / 10));
+    TX_TAB(ff_tx_tab_53)[7] = RESCALE(sin(2 * M_PI / 10));
 }
 
-static av_cold void ff_init_7_tabs(void)
+static av_cold void TX_TAB(ff_tx_init_tab_7)(void)
 {
-    TX_NAME(ff_cos_7)[0] = (FFTComplex){ RESCALE(cos(2 * M_PI /  7)), RESCALE(sin(2 * M_PI /  7)) };
-    TX_NAME(ff_cos_7)[1] = (FFTComplex){ RESCALE(sin(2 * M_PI / 28)), RESCALE(cos(2 * M_PI / 28)) };
-    TX_NAME(ff_cos_7)[2] = (FFTComplex){ RESCALE(cos(2 * M_PI / 14)), RESCALE(sin(2 * M_PI / 14)) };
+    TX_TAB(ff_tx_tab_7)[0] = RESCALE(cos(2 * M_PI /  7));
+    TX_TAB(ff_tx_tab_7)[1] = RESCALE(sin(2 * M_PI /  7));
+    TX_TAB(ff_tx_tab_7)[2] = RESCALE(sin(2 * M_PI / 28));
+    TX_TAB(ff_tx_tab_7)[3] = RESCALE(cos(2 * M_PI / 28));
+    TX_TAB(ff_tx_tab_7)[4] = RESCALE(cos(2 * M_PI / 14));
+    TX_TAB(ff_tx_tab_7)[5] = RESCALE(sin(2 * M_PI / 14));
 }
 
-static av_cold void ff_init_9_tabs(void)
+static av_cold void TX_TAB(ff_tx_init_tab_9)(void)
 {
-    TX_NAME(ff_cos_9)[0] = (FFTComplex){ RESCALE(cos(2 * M_PI /  3)), RESCALE( sin(2 * M_PI /  3)) };
-    TX_NAME(ff_cos_9)[1] = (FFTComplex){ RESCALE(cos(2 * M_PI /  9)), RESCALE( sin(2 * M_PI /  9)) };
-    TX_NAME(ff_cos_9)[2] = (FFTComplex){ RESCALE(cos(2 * M_PI / 36)), RESCALE( sin(2 * M_PI / 36)) };
-    TX_NAME(ff_cos_9)[3] = (FFTComplex){ TX_NAME(ff_cos_9)[1].re + TX_NAME(ff_cos_9)[2].im,
-                                         TX_NAME(ff_cos_9)[1].im - TX_NAME(ff_cos_9)[2].re };
+    TX_TAB(ff_tx_tab_9)[0] = RESCALE(cos(2 * M_PI /  3));
+    TX_TAB(ff_tx_tab_9)[1] = RESCALE(sin(2 * M_PI /  3));
+    TX_TAB(ff_tx_tab_9)[2] = RESCALE(cos(2 * M_PI /  9));
+    TX_TAB(ff_tx_tab_9)[3] = RESCALE(sin(2 * M_PI /  9));
+    TX_TAB(ff_tx_tab_9)[4] = RESCALE(cos(2 * M_PI / 36));
+    TX_TAB(ff_tx_tab_9)[5] = RESCALE(sin(2 * M_PI / 36));
+    TX_TAB(ff_tx_tab_9)[6] = TX_TAB(ff_tx_tab_9)[2] + TX_TAB(ff_tx_tab_9)[5];
+    TX_TAB(ff_tx_tab_9)[7] = TX_TAB(ff_tx_tab_9)[3] - TX_TAB(ff_tx_tab_9)[4];
 }
 
-static CosTabsInitOnce cos_tabs_init_once[] = {
-    { ff_init_53_tabs, AV_ONCE_INIT },
-    { ff_init_7_tabs, AV_ONCE_INIT },
-    { ff_init_9_tabs, AV_ONCE_INIT },
-    { NULL },
-    { init_cos_tabs_16, AV_ONCE_INIT },
-    { init_cos_tabs_32, AV_ONCE_INIT },
-    { init_cos_tabs_64, AV_ONCE_INIT },
-    { init_cos_tabs_128, AV_ONCE_INIT },
-    { init_cos_tabs_256, AV_ONCE_INIT },
-    { init_cos_tabs_512, AV_ONCE_INIT },
-    { init_cos_tabs_1024, AV_ONCE_INIT },
-    { init_cos_tabs_2048, AV_ONCE_INIT },
-    { init_cos_tabs_4096, AV_ONCE_INIT },
-    { init_cos_tabs_8192, AV_ONCE_INIT },
-    { init_cos_tabs_16384, AV_ONCE_INIT },
-    { init_cos_tabs_32768, AV_ONCE_INIT },
-    { init_cos_tabs_65536, AV_ONCE_INIT },
-    { init_cos_tabs_131072, AV_ONCE_INIT },
+static FFSRTabsInitOnce nptwo_tabs_init_once[] = {
+    { TX_TAB(ff_tx_init_tab_53),      AV_ONCE_INIT, { 15, 5, 3 } },
+    { TX_TAB(ff_tx_init_tab_9),       AV_ONCE_INIT, {  9 }       },
+    { TX_TAB(ff_tx_init_tab_7),       AV_ONCE_INIT, {  7 }       },
 };
 
-static av_cold void init_cos_tabs(int index)
+av_cold void TX_TAB(ff_tx_init_tabs)(int len)
 {
-    ff_thread_once(&cos_tabs_init_once[index].control,
-                    cos_tabs_init_once[index].func);
+    int factor_2 = ff_ctz(len);
+    if (factor_2) {
+        int idx = factor_2 - 3;
+        for (int i = 0; i <= idx; i++)
+            ff_thread_once(&sr_tabs_init_once[i].control,
+                            sr_tabs_init_once[i].func);
+        len >>= factor_2;
+    }
+
+    for (int i = 0; i < FF_ARRAY_ELEMS(nptwo_tabs_init_once); i++) {
+        int f, f_idx = 0;
+
+        if (len <= 1)
+            return;
+
+        while ((f = nptwo_tabs_init_once[i].factors[f_idx++])) {
+            if (f % len)
+                continue;
+
+            ff_thread_once(&nptwo_tabs_init_once[i].control,
+                            nptwo_tabs_init_once[i].func);
+            len /= f;
+            break;
+        }
+    }
 }
 
 static av_always_inline void fft3(FFTComplex *out, FFTComplex *in,
                                   ptrdiff_t stride)
 {
     FFTComplex tmp[2];
+    const FFTSample *tab = TX_TAB(ff_tx_tab_53);
 #ifdef TX_INT32
     int64_t mtmp[4];
 #endif
@@ -163,19 +189,19 @@  static av_always_inline void fft3(FFTComplex *out, FFTComplex *in,
     out[0*stride].im = in[0].im + tmp[1].im;
 
 #ifdef TX_INT32
-    mtmp[0] = (int64_t)TX_NAME(ff_cos_53)[0].re * tmp[0].re;
-    mtmp[1] = (int64_t)TX_NAME(ff_cos_53)[0].im * tmp[0].im;
-    mtmp[2] = (int64_t)TX_NAME(ff_cos_53)[1].re * tmp[1].re;
-    mtmp[3] = (int64_t)TX_NAME(ff_cos_53)[1].re * tmp[1].im;
+    mtmp[0] = (int64_t)tab[0] * tmp[0].re;
+    mtmp[1] = (int64_t)tab[1] * tmp[0].im;
+    mtmp[2] = (int64_t)tab[2] * tmp[1].re;
+    mtmp[3] = (int64_t)tab[2] * tmp[1].im;
     out[1*stride].re = in[0].re - (mtmp[2] + mtmp[0] + 0x40000000 >> 31);
     out[1*stride].im = in[0].im - (mtmp[3] - mtmp[1] + 0x40000000 >> 31);
     out[2*stride].re = in[0].re - (mtmp[2] - mtmp[0] + 0x40000000 >> 31);
     out[2*stride].im = in[0].im - (mtmp[3] + mtmp[1] + 0x40000000 >> 31);
 #else
-    tmp[0].re = TX_NAME(ff_cos_53)[0].re * tmp[0].re;
-    tmp[0].im = TX_NAME(ff_cos_53)[0].im * tmp[0].im;
-    tmp[1].re = TX_NAME(ff_cos_53)[1].re * tmp[1].re;
-    tmp[1].im = TX_NAME(ff_cos_53)[1].re * tmp[1].im;
+    tmp[0].re = tab[0] * tmp[0].re;
+    tmp[0].im = tab[1] * tmp[0].im;
+    tmp[1].re = tab[2] * tmp[1].re;
+    tmp[1].im = tab[2] * tmp[1].im;
     out[1*stride].re = in[0].re - tmp[1].re + tmp[0].re;
     out[1*stride].im = in[0].im - tmp[1].im - tmp[0].im;
     out[2*stride].re = in[0].re - tmp[1].re - tmp[0].re;
@@ -183,38 +209,39 @@  static av_always_inline void fft3(FFTComplex *out, FFTComplex *in,
 #endif
 }
 
-#define DECL_FFT5(NAME, D0, D1, D2, D3, D4)                                                       \
-static av_always_inline void NAME(FFTComplex *out, FFTComplex *in,                                \
-                                  ptrdiff_t stride)                                               \
-{                                                                                                 \
-    FFTComplex z0[4], t[6];                                                                       \
-                                                                                                  \
-    BF(t[1].im, t[0].re, in[1].re, in[4].re);                                                     \
-    BF(t[1].re, t[0].im, in[1].im, in[4].im);                                                     \
-    BF(t[3].im, t[2].re, in[2].re, in[3].re);                                                     \
-    BF(t[3].re, t[2].im, in[2].im, in[3].im);                                                     \
-                                                                                                  \
-    out[D0*stride].re = in[0].re + t[0].re + t[2].re;                                             \
-    out[D0*stride].im = in[0].im + t[0].im + t[2].im;                                             \
-                                                                                                  \
-    SMUL(t[4].re, t[0].re, TX_NAME(ff_cos_53)[2].re, TX_NAME(ff_cos_53)[3].re, t[2].re, t[0].re); \
-    SMUL(t[4].im, t[0].im, TX_NAME(ff_cos_53)[2].re, TX_NAME(ff_cos_53)[3].re, t[2].im, t[0].im); \
-    CMUL(t[5].re, t[1].re, TX_NAME(ff_cos_53)[2].im, TX_NAME(ff_cos_53)[3].im, t[3].re, t[1].re); \
-    CMUL(t[5].im, t[1].im, TX_NAME(ff_cos_53)[2].im, TX_NAME(ff_cos_53)[3].im, t[3].im, t[1].im); \
-                                                                                                  \
-    BF(z0[0].re, z0[3].re, t[0].re, t[1].re);                                                     \
-    BF(z0[0].im, z0[3].im, t[0].im, t[1].im);                                                     \
-    BF(z0[2].re, z0[1].re, t[4].re, t[5].re);                                                     \
-    BF(z0[2].im, z0[1].im, t[4].im, t[5].im);                                                     \
-                                                                                                  \
-    out[D1*stride].re = in[0].re + z0[3].re;                                                      \
-    out[D1*stride].im = in[0].im + z0[0].im;                                                      \
-    out[D2*stride].re = in[0].re + z0[2].re;                                                      \
-    out[D2*stride].im = in[0].im + z0[1].im;                                                      \
-    out[D3*stride].re = in[0].re + z0[1].re;                                                      \
-    out[D3*stride].im = in[0].im + z0[2].im;                                                      \
-    out[D4*stride].re = in[0].re + z0[0].re;                                                      \
-    out[D4*stride].im = in[0].im + z0[3].im;                                                      \
+#define DECL_FFT5(NAME, D0, D1, D2, D3, D4)                         \
+static av_always_inline void NAME(FFTComplex *out, FFTComplex *in,  \
+                                  ptrdiff_t stride)                 \
+{                                                                   \
+    FFTComplex z0[4], t[6];                                         \
+    const FFTSample *tab = TX_TAB(ff_tx_tab_53);                    \
+                                                                    \
+    BF(t[1].im, t[0].re, in[1].re, in[4].re);                       \
+    BF(t[1].re, t[0].im, in[1].im, in[4].im);                       \
+    BF(t[3].im, t[2].re, in[2].re, in[3].re);                       \
+    BF(t[3].re, t[2].im, in[2].im, in[3].im);                       \
+                                                                    \
+    out[D0*stride].re = in[0].re + t[0].re + t[2].re;               \
+    out[D0*stride].im = in[0].im + t[0].im + t[2].im;               \
+                                                                    \
+    SMUL(t[4].re, t[0].re, tab[4], tab[6], t[2].re, t[0].re);       \
+    SMUL(t[4].im, t[0].im, tab[4], tab[6], t[2].im, t[0].im);       \
+    CMUL(t[5].re, t[1].re, tab[5], tab[7], t[3].re, t[1].re);       \
+    CMUL(t[5].im, t[1].im, tab[5], tab[7], t[3].im, t[1].im);       \
+                                                                    \
+    BF(z0[0].re, z0[3].re, t[0].re, t[1].re);                       \
+    BF(z0[0].im, z0[3].im, t[0].im, t[1].im);                       \
+    BF(z0[2].re, z0[1].re, t[4].re, t[5].re);                       \
+    BF(z0[2].im, z0[1].im, t[4].im, t[5].im);                       \
+                                                                    \
+    out[D1*stride].re = in[0].re + z0[3].re;                        \
+    out[D1*stride].im = in[0].im + z0[0].im;                        \
+    out[D2*stride].re = in[0].re + z0[2].re;                        \
+    out[D2*stride].im = in[0].im + z0[1].im;                        \
+    out[D3*stride].re = in[0].re + z0[1].re;                        \
+    out[D3*stride].im = in[0].im + z0[2].im;                        \
+    out[D4*stride].re = in[0].re + z0[0].re;                        \
+    out[D4*stride].im = in[0].im + z0[3].im;                        \
 }
 
 DECL_FFT5(fft5,     0,  1,  2,  3,  4)
@@ -226,7 +253,7 @@  static av_always_inline void fft7(FFTComplex *out, FFTComplex *in,
                                   ptrdiff_t stride)
 {
     FFTComplex t[6], z[3];
-    const FFTComplex *tab = TX_NAME(ff_cos_7);
+    const FFTComplex *tab = (const FFTComplex *)TX_TAB(ff_tx_tab_7);
 #ifdef TX_INT32
     int64_t mtmp[12];
 #endif
@@ -312,7 +339,7 @@  static av_always_inline void fft7(FFTComplex *out, FFTComplex *in,
 static av_always_inline void fft9(FFTComplex *out, FFTComplex *in,
                                   ptrdiff_t stride)
 {
-    const FFTComplex *tab = TX_NAME(ff_cos_9);
+    const FFTComplex *tab = (const FFTComplex *)TX_TAB(ff_tx_tab_9);
     FFTComplex t[16], w[4], x[5], y[5], z[2];
 #ifdef TX_INT32
     int64_t mtmp[12];
@@ -468,15 +495,16 @@  static av_always_inline void fft15(FFTComplex *out, FFTComplex *in,
     } while (0)
 
 /* z[0...8n-1], w[1...2n-1] */
-static void split_radix_combine(FFTComplex *z, const FFTSample *cos, int n)
+static inline void TX_NAME(ff_tx_fft_sr_combine)(FFTComplex *z,
+                                                 const FFTSample *cos, int len)
 {
-    int o1 = 2*n;
-    int o2 = 4*n;
-    int o3 = 6*n;
+    int o1 = 2*len;
+    int o2 = 4*len;
+    int o3 = 6*len;
     const FFTSample *wim = cos + o1 - 7;
     FFTSample t1, t2, t3, t4, t5, t6, r0, i0, r1, i1;
 
-    for (int i = 0; i < n; i += 4) {
+    for (int i = 0; i < len; i += 4) {
         TRANSFORM(z[0], z[o1 + 0], z[o2 + 0], z[o3 + 0], cos[0], wim[7]);
         TRANSFORM(z[2], z[o1 + 2], z[o2 + 2], z[o3 + 2], cos[2], wim[5]);
         TRANSFORM(z[4], z[o1 + 4], z[o2 + 4], z[o3 + 4], cos[4], wim[3]);
@@ -493,25 +521,62 @@  static void split_radix_combine(FFTComplex *z, const FFTSample *cos, int n)
     }
 }
 
-#define DECL_FFT(n, n2, n4)                            \
-static void fft##n(FFTComplex *z)                      \
-{                                                      \
-    fft##n2(z);                                        \
-    fft##n4(z + n4*2);                                 \
-    fft##n4(z + n4*3);                                 \
-    split_radix_combine(z, TX_NAME(ff_cos_##n), n4/2); \
+static av_cold int TX_NAME(ff_tx_fft_sr_codelet_init)(AVTXContext *s,
+                                                      const FFTXCodelet *cd,
+                                                      uint64_t flags,
+                                                      FFTXCodeletOptions *opts,
+                                                      int len, int inv,
+                                                      const void *scale)
+{
+    TX_TAB(ff_tx_init_tabs)(len);
+    return ff_tx_gen_ptwo_revtab(s, opts ? opts->invert_lookup : 1);
 }
 
-static void fft2(FFTComplex *z)
+#define DECL_SR_CODELET_DEF(n)                         \
+const FFTXCodelet TX_NAME(ff_tx_fft##n##_ns_def) = {   \
+    .name       = TX_NAME_STR("fft" #n "_ns"),         \
+    .function   = TX_NAME(ff_tx_fft##n##_ns),          \
+    .type       = TX_TYPE(FFT),                        \
+    .flags      = AV_TX_INPLACE | AV_TX_UNALIGNED |    \
+                  FF_TX_PRESHUFFLE,                    \
+    .factors[0] = 2,                                   \
+    .min_len    = n,                                   \
+    .max_len    = n,                                   \
+    .init       = TX_NAME(ff_tx_fft_sr_codelet_init),  \
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,                 \
+    .prio       = FF_TX_PRIO_BASE,                     \
+};
+
+#define DECL_SR_CODELET(n, n2, n4)                                   \
+static void TX_NAME(ff_tx_fft##n##_ns)(AVTXContext *s, void *dst,    \
+                                        void *src, ptrdiff_t stride) \
+{                                                                    \
+    FFTComplex *z = dst;                                             \
+    const FFTSample *cos = TX_TAB(ff_tx_tab_##n);                    \
+                                                                     \
+    TX_NAME(ff_tx_fft##n2##_ns)(s, z,        z,        stride);      \
+    TX_NAME(ff_tx_fft##n4##_ns)(s, z + n4*2, z + n4*2, stride);      \
+    TX_NAME(ff_tx_fft##n4##_ns)(s, z + n4*3, z + n4*3, stride);      \
+    TX_NAME(ff_tx_fft_sr_combine)(z, cos, n4 >> 1);                  \
+}                                                                    \
+                                                                     \
+DECL_SR_CODELET_DEF(n)
+
+static void TX_NAME(ff_tx_fft2_ns)(AVTXContext *s, void *dst,
+                                   void *src, ptrdiff_t stride)
 {
+    FFTComplex *z = dst;
     FFTComplex tmp;
+
     BF(tmp.re, z[0].re, z[0].re, z[1].re);
     BF(tmp.im, z[0].im, z[0].im, z[1].im);
     z[1] = tmp;
 }
 
-static void fft4(FFTComplex *z)
+static void TX_NAME(ff_tx_fft4_ns)(AVTXContext *s, void *dst,
+                                   void *src, ptrdiff_t stride)
 {
+    FFTComplex *z = dst;
     FFTSample t1, t2, t3, t4, t5, t6, t7, t8;
 
     BF(t3, t1, z[0].re, z[1].re);
@@ -524,11 +589,14 @@  static void fft4(FFTComplex *z)
     BF(z[2].im, z[0].im, t2, t5);
 }
 
-static void fft8(FFTComplex *z)
+static void TX_NAME(ff_tx_fft8_ns)(AVTXContext *s, void *dst,
+                                   void *src, ptrdiff_t stride)
 {
+    FFTComplex *z = dst;
     FFTSample t1, t2, t3, t4, t5, t6, r0, i0, r1, i1;
+    const FFTSample cos = TX_TAB(ff_tx_tab_8)[1];
 
-    fft4(z);
+    TX_NAME(ff_tx_fft4_ns)(s, z, z, stride);
 
     BF(t1, z[5].re, z[4].re, -z[5].re);
     BF(t2, z[5].im, z[4].im, -z[5].im);
@@ -536,19 +604,23 @@  static void fft8(FFTComplex *z)
     BF(t6, z[7].im, z[6].im, -z[7].im);
 
     BUTTERFLIES(z[0], z[2], z[4], z[6]);
-    TRANSFORM(z[1], z[3], z[5], z[7], RESCALE(M_SQRT1_2), RESCALE(M_SQRT1_2));
+    TRANSFORM(z[1], z[3], z[5], z[7], cos, cos);
 }
 
-static void fft16(FFTComplex *z)
+static void TX_NAME(ff_tx_fft16_ns)(AVTXContext *s, void *dst,
+                                    void *src, ptrdiff_t stride)
 {
+    FFTComplex *z = dst;
+    const FFTSample *cos = TX_TAB(ff_tx_tab_16);
+
     FFTSample t1, t2, t3, t4, t5, t6, r0, i0, r1, i1;
-    FFTSample cos_16_1 = TX_NAME(ff_cos_16)[1];
-    FFTSample cos_16_2 = TX_NAME(ff_cos_16)[2];
-    FFTSample cos_16_3 = TX_NAME(ff_cos_16)[3];
+    FFTSample cos_16_1 = cos[1];
+    FFTSample cos_16_2 = cos[2];
+    FFTSample cos_16_3 = cos[3];
 
-    fft8(z +  0);
-    fft4(z +  8);
-    fft4(z + 12);
+    TX_NAME(ff_tx_fft8_ns)(s, z +  0, z +  0, stride);
+    TX_NAME(ff_tx_fft4_ns)(s, z +  8, z +  8, stride);
+    TX_NAME(ff_tx_fft4_ns)(s, z + 12, z + 12, stride);
 
     t1 = z[ 8].re;
     t2 = z[ 8].im;
@@ -561,90 +633,125 @@  static void fft16(FFTComplex *z)
     TRANSFORM(z[ 3], z[ 7], z[11], z[15], cos_16_3, cos_16_1);
 }
 
-DECL_FFT(32,16,8)
-DECL_FFT(64,32,16)
-DECL_FFT(128,64,32)
-DECL_FFT(256,128,64)
-DECL_FFT(512,256,128)
-DECL_FFT(1024,512,256)
-DECL_FFT(2048,1024,512)
-DECL_FFT(4096,2048,1024)
-DECL_FFT(8192,4096,2048)
-DECL_FFT(16384,8192,4096)
-DECL_FFT(32768,16384,8192)
-DECL_FFT(65536,32768,16384)
-DECL_FFT(131072,65536,32768)
-
-static void (* const fft_dispatch[])(FFTComplex*) = {
-    NULL, fft2, fft4, fft8, fft16, fft32, fft64, fft128, fft256, fft512,
-    fft1024, fft2048, fft4096, fft8192, fft16384, fft32768, fft65536, fft131072
-};
+DECL_SR_CODELET_DEF(2)
+DECL_SR_CODELET_DEF(4)
+DECL_SR_CODELET_DEF(8)
+DECL_SR_CODELET_DEF(16)
+DECL_SR_CODELET(32,16,8)
+DECL_SR_CODELET(64,32,16)
+DECL_SR_CODELET(128,64,32)
+DECL_SR_CODELET(256,128,64)
+DECL_SR_CODELET(512,256,128)
+DECL_SR_CODELET(1024,512,256)
+DECL_SR_CODELET(2048,1024,512)
+DECL_SR_CODELET(4096,2048,1024)
+DECL_SR_CODELET(8192,4096,2048)
+DECL_SR_CODELET(16384,8192,4096)
+DECL_SR_CODELET(32768,16384,8192)
+DECL_SR_CODELET(65536,32768,16384)
+DECL_SR_CODELET(131072,65536,32768)
+
+static void TX_NAME(ff_tx_fft_sr)(AVTXContext *s, void *_dst,
+                                  void *_src, ptrdiff_t stride)
+{
+    FFTComplex *src = _src;
+    FFTComplex *dst = _dst;
+    int *map = s->sub[0].map;
+    int len = s->len;
 
-#define DECL_COMP_FFT(N)                                                       \
-static void compound_fft_##N##xM(AVTXContext *s, void *_out,                   \
-                                 void *_in, ptrdiff_t stride)                  \
-{                                                                              \
-    const int m = s->m, *in_map = s->pfatab, *out_map = in_map + N*m;          \
-    FFTComplex *in = _in;                                                      \
-    FFTComplex *out = _out;                                                    \
-    FFTComplex fft##N##in[N];                                                  \
-    void (*fftp)(FFTComplex *z) = fft_dispatch[av_log2(m)];                    \
-                                                                               \
-    for (int i = 0; i < m; i++) {                                              \
-        for (int j = 0; j < N; j++)                                            \
-            fft##N##in[j] = in[in_map[i*N + j]];                               \
-        fft##N(s->tmp + s->revtab_c[i], fft##N##in, m);                        \
-    }                                                                          \
-                                                                               \
-    for (int i = 0; i < N; i++)                                                \
-        fftp(s->tmp + m*i);                                                    \
-                                                                               \
-    for (int i = 0; i < N*m; i++)                                              \
-        out[i] = s->tmp[out_map[i]];                                           \
-}
+    /* Compilers can't vectorize this gather loop without assuming AVX2,
+     * which they generally don't, at least not without -march=native -mtune=native */
+    for (int i = 0; i < len; i++)
+        dst[i] = src[map[i]];
 
-DECL_COMP_FFT(3)
-DECL_COMP_FFT(5)
-DECL_COMP_FFT(7)
-DECL_COMP_FFT(9)
-DECL_COMP_FFT(15)
+    s->fn[0](&s->sub[0], dst, dst, stride);
+}
 
-static void split_radix_fft(AVTXContext *s, void *_out, void *_in,
-                            ptrdiff_t stride)
+static void TX_NAME(ff_tx_fft_sr_inplace)(AVTXContext *s, void *_dst,
+                                          void *_src, ptrdiff_t stride)
 {
-    FFTComplex *in = _in;
-    FFTComplex *out = _out;
-    int m = s->m, mb = av_log2(m);
+    FFTComplex *dst = _dst;
+    FFTComplex tmp;
+    const int *map = s->sub->map;
+    const int *inplace_idx = s->map;
+    int src_idx, dst_idx;
+
+    src_idx = *inplace_idx++;
+    do {
+        tmp = dst[src_idx];
+        dst_idx = map[src_idx];
+        do {
+            FFSWAP(FFTComplex, tmp, dst[dst_idx]);
+            dst_idx = map[dst_idx];
+        } while (dst_idx != src_idx); /* Could be '>' as well, but that is less predictable */
+        dst[dst_idx] = tmp;
+    } while ((src_idx = *inplace_idx++));
 
-    if (s->flags & AV_TX_INPLACE) {
-        FFTComplex tmp;
-        int src, dst, *inplace_idx = s->inplace_idx;
+    s->fn[0](&s->sub[0], dst, dst, stride);
+}
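The in-place path above avoids a temporary buffer by walking permutation cycles: `s->map` holds one starting index per non-trivial cycle, terminated by 0 (index 0 is always a fixed point of a bit-reversal permutation, so it can double as the sentinel). A standalone sketch on plain ints (names hypothetical) showing the same cycle walk, which applies the scatter permutation data[map[i]] = old data[i]:

```c
#include <assert.h>

/* In-place scatter permutation via cycle walking: after the call,
 * data[map[i]] holds the old data[i], with no scratch array.
 * cycle_starts lists one entry point per non-trivial cycle and is
 * 0-terminated (so index 0 must be a fixed point of map). */
static void permute_inplace(int *data, const int *map, const int *cycle_starts)
{
    int src_idx = *cycle_starts++;
    do {
        int tmp = data[src_idx];
        int dst_idx = map[src_idx];
        do {
            int t = data[dst_idx];     /* FFSWAP(tmp, data[dst_idx]) */
            data[dst_idx] = tmp;
            tmp = t;
            dst_idx = map[dst_idx];
        } while (dst_idx != src_idx);
        data[dst_idx] = tmp;
    } while ((src_idx = *cycle_starts++));
}
```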
 
-        src = *inplace_idx++;
+static av_cold int TX_NAME(ff_tx_fft_sr_init)(AVTXContext *s,
+                                              const FFTXCodelet *cd,
+                                              uint64_t flags,
+                                              FFTXCodeletOptions *opts,
+                                              int len, int inv,
+                                              const void *scale)
+{
+    int ret;
+    FFTXCodeletOptions sub_opts = { 0 };
 
-        do {
-            tmp = out[src];
-            dst = s->revtab_c[src];
-            do {
-                FFSWAP(FFTComplex, tmp, out[dst]);
-                dst = s->revtab_c[dst];
-            } while (dst != src); /* Can be > as well, but is less predictable */
-            out[dst] = tmp;
-        } while ((src = *inplace_idx++));
+    if (flags & AV_TX_INPLACE) {
+        if ((ret = ff_tx_gen_ptwo_inplace_revtab_idx(s)))
+            return ret;
+        sub_opts.invert_lookup = 0;
     } else {
-        for (int i = 0; i < m; i++)
-            out[i] = in[s->revtab_c[i]];
+        /* For a straightforward lookup, it's faster to do it inverted
+         * (gather, rather than scatter). */
+        sub_opts.invert_lookup = 1;
     }
 
-    fft_dispatch[mb](out);
+    flags &= ~FF_TX_OUT_OF_PLACE; /* We want the subtransform to be */
+    flags |=  AV_TX_INPLACE;      /* in-place */
+    flags |=  FF_TX_PRESHUFFLE;   /* This function handles the permute step */
+
+    if ((ret = ff_tx_init_subtx(s, TX_TYPE(FFT), flags, &sub_opts, len, inv, scale)))
+        return ret;
+
+    return 0;
 }
 
-static void naive_fft(AVTXContext *s, void *_out, void *_in,
-                      ptrdiff_t stride)
+const FFTXCodelet TX_NAME(ff_tx_fft_sr_def) = {
+    .name       = TX_NAME_STR("fft_sr"),
+    .function   = TX_NAME(ff_tx_fft_sr),
+    .type       = TX_TYPE(FFT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE,
+    .factors[0] = 2,
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_fft_sr_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_BASE,
+};
+
+const FFTXCodelet TX_NAME(ff_tx_fft_sr_inplace_def) = {
+    .name       = TX_NAME_STR("fft_sr_inplace"),
+    .function   = TX_NAME(ff_tx_fft_sr_inplace),
+    .type       = TX_TYPE(FFT),
+    .flags      = AV_TX_UNALIGNED | AV_TX_INPLACE,
+    .factors[0] = 2,
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_fft_sr_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_BASE,
+};
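The two FFTXCodelet descriptors above are what the tree-based constructor matches against: each advertises supported factors, a length range, CPU flags and a priority. As a rough illustration (MiniCodelet and pick() are invented for this sketch; the real matching in lavu/tx.c also scores CPU flags and full factor lists), selection boils down to "highest-priority codelet whose constraints fit the requested length":

```c
#include <assert.h>
#include <string.h>

/* Toy version of descriptor-based codelet selection: pick the
 * highest-priority entry whose length constraints match. */
typedef struct MiniCodelet {
    const char *name;
    int factor;              /* required factor of len; 0 = any length */
    int min_len, max_len;
    int prio;
} MiniCodelet;

static const MiniCodelet *pick(const MiniCodelet *cds, int n, int len)
{
    const MiniCodelet *best = NULL;
    for (int i = 0; i < n; i++) {
        if (len < cds[i].min_len || len > cds[i].max_len)
            continue;
        if (cds[i].factor && (len % cds[i].factor))
            continue;
        if (!best || cds[i].prio > best->prio)
            best = &cds[i];
    }
    return best;
}
```

With an even-length-only "fft_sr" entry at base priority and a match-anything "fft_naive" entry at minimum priority, a length of 1024 selects the former while 15 falls back to the latter, mirroring the FF_TX_PRIO_BASE / FF_TX_PRIO_MIN split in the defs above.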
+
+static void TX_NAME(ff_tx_fft_naive)(AVTXContext *s, void *_dst, void *_src,
+                                     ptrdiff_t stride)
 {
-    FFTComplex *in = _in;
-    FFTComplex *out = _out;
-    const int n = s->n;
+    FFTComplex *src = _src;
+    FFTComplex *dst = _dst;
+    const int n = s->len;
     double phase = s->inv ? 2.0*M_PI/n : -2.0*M_PI/n;
 
     for(int i = 0; i < n; i++) {
@@ -656,164 +763,218 @@  static void naive_fft(AVTXContext *s, void *_out, void *_in,
                 RESCALE(sin(factor)),
             };
             FFTComplex res;
-            CMUL3(res, in[j], mult);
+            CMUL3(res, src[j], mult);
             tmp.re += res.re;
             tmp.im += res.im;
         }
-        out[i] = tmp;
+        dst[i] = tmp;
     }
 }
 
-#define DECL_COMP_IMDCT(N)                                                     \
-static void compound_imdct_##N##xM(AVTXContext *s, void *_dst, void *_src,     \
-                                   ptrdiff_t stride)                           \
+const FFTXCodelet TX_NAME(ff_tx_fft_naive_def) = {
+    .name       = TX_NAME_STR("fft_naive"),
+    .function   = TX_NAME(ff_tx_fft_naive),
+    .type       = TX_TYPE(FFT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE,
+    .factors[0] = TX_FACTOR_ANY,
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = NULL,
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_MIN,
+};
+
+static av_cold int TX_NAME(ff_tx_fft_pfa_init)(AVTXContext *s,
+                                               const FFTXCodelet *cd,
+                                               uint64_t flags,
+                                               FFTXCodeletOptions *opts,
+                                               int len, int inv,
+                                               const void *scale)
+{
+    int ret;
+    int sub_len = len / cd->factors[0];
+    FFTXCodeletOptions sub_opts = { .invert_lookup = 0 };
+
+    flags &= ~FF_TX_OUT_OF_PLACE; /* We want the subtransform to be */
+    flags |=  AV_TX_INPLACE;      /* in-place */
+    flags |=  FF_TX_PRESHUFFLE;   /* This function handles the permute step */
+
+    if ((ret = ff_tx_init_subtx(s, TX_TYPE(FFT), flags, &sub_opts,
+                                sub_len, inv, scale)))
+        return ret;
+
+    if ((ret = ff_tx_gen_compound_mapping(s, cd->factors[0], sub_len)))
+        return ret;
+
+    if (!(s->tmp = av_malloc(len*sizeof(*s->tmp))))
+        return AVERROR(ENOMEM);
+
+    TX_TAB(ff_tx_init_tabs)(len / sub_len);
+
+    return 0;
+}
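ff_tx_gen_compound_mapping() builds the Good-Thomas (prime-factor) index maps that make the NxM split work: because N and M are coprime, a CRT-style remapping turns the length-N*M DFT into independent length-N and length-M stages with no twiddle factors between them. A quick standalone check of the underlying property (this is the textbook mapping, not necessarily byte-for-byte what the helper emits):

```c
#include <assert.h>

/* Good-Thomas input mapping: for len = a*b with gcd(a, b) == 1, the index
 * (i*b + j*a) mod len, i in [0,a), j in [0,b), visits every input exactly
 * once -- the property the PFA relies on. */
static int pfa_map_is_permutation(int a, int b)
{
    int len = a * b;
    int seen[64] = { 0 };

    if (len > 64)
        return 0;
    for (int i = 0; i < a; i++)
        for (int j = 0; j < b; j++)
            seen[(i * b + j * a) % len]++;
    for (int k = 0; k < len; k++)
        if (seen[k] != 1)
            return 0;
    return 1;
}
```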
+
+#define DECL_COMP_FFT(N)                                                       \
+static void TX_NAME(ff_tx_fft_pfa_##N##xM)(AVTXContext *s, void *_out,         \
+                                           void *_in, ptrdiff_t stride)        \
 {                                                                              \
+    const int m = s->sub->len;                                                 \
+    const int *in_map = s->map, *out_map = in_map + s->len;                    \
+    const int *sub_map = s->sub->map;                                          \
+    FFTComplex *in = _in;                                                      \
+    FFTComplex *out = _out;                                                    \
     FFTComplex fft##N##in[N];                                                  \
-    FFTComplex *z = _dst, *exp = s->exptab;                                    \
-    const int m = s->m, len8 = N*m >> 1;                                       \
-    const int *in_map = s->pfatab, *out_map = in_map + N*m;                    \
-    const FFTSample *src = _src, *in1, *in2;                                   \
-    void (*fftp)(FFTComplex *) = fft_dispatch[av_log2(m)];                     \
-                                                                               \
-    stride /= sizeof(*src); /* To convert it from bytes */                     \
-    in1 = src;                                                                 \
-    in2 = src + ((N*m*2) - 1) * stride;                                        \
                                                                                \
     for (int i = 0; i < m; i++) {                                              \
-        for (int j = 0; j < N; j++) {                                          \
-            const int k = in_map[i*N + j];                                     \
-            FFTComplex tmp = { in2[-k*stride], in1[k*stride] };                \
-            CMUL3(fft##N##in[j], tmp, exp[k >> 1]);                            \
-        }                                                                      \
-        fft##N(s->tmp + s->revtab_c[i], fft##N##in, m);                        \
+        for (int j = 0; j < N; j++)                                            \
+            fft##N##in[j] = in[in_map[i*N + j]];                               \
+        fft##N(s->tmp + sub_map[i], fft##N##in, m);                            \
     }                                                                          \
                                                                                \
     for (int i = 0; i < N; i++)                                                \
-        fftp(s->tmp + m*i);                                                    \
+        s->fn[0](&s->sub[0], s->tmp + m*i, s->tmp + m*i, sizeof(FFTComplex));  \
                                                                                \
-    for (int i = 0; i < len8; i++) {                                           \
-        const int i0 = len8 + i, i1 = len8 - i - 1;                            \
-        const int s0 = out_map[i0], s1 = out_map[i1];                          \
-        FFTComplex src1 = { s->tmp[s1].im, s->tmp[s1].re };                    \
-        FFTComplex src0 = { s->tmp[s0].im, s->tmp[s0].re };                    \
+    for (int i = 0; i < N*m; i++)                                              \
+        out[i] = s->tmp[out_map[i]];                                           \
+}                                                                              \
                                                                                \
-        CMUL(z[i1].re, z[i0].im, src1.re, src1.im, exp[i1].im, exp[i1].re);    \
-        CMUL(z[i0].re, z[i1].im, src0.re, src0.im, exp[i0].im, exp[i0].re);    \
-    }                                                                          \
-}
+const FFTXCodelet TX_NAME(ff_tx_fft_pfa_##N##xM_def) = {                       \
+    .name       = TX_NAME_STR("fft_pfa_" #N "xM"),                             \
+    .function   = TX_NAME(ff_tx_fft_pfa_##N##xM),                              \
+    .type       = TX_TYPE(FFT),                                                \
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE,                        \
+    .factors    = { N, TX_FACTOR_ANY },                                        \
+    .min_len    = N*2,                                                         \
+    .max_len    = TX_LEN_UNLIMITED,                                            \
+    .init       = TX_NAME(ff_tx_fft_pfa_init),                                 \
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,                                         \
+    .prio       = FF_TX_PRIO_BASE,                                             \
+};
 
-DECL_COMP_IMDCT(3)
-DECL_COMP_IMDCT(5)
-DECL_COMP_IMDCT(7)
-DECL_COMP_IMDCT(9)
-DECL_COMP_IMDCT(15)
+DECL_COMP_FFT(3)
+DECL_COMP_FFT(5)
+DECL_COMP_FFT(7)
+DECL_COMP_FFT(9)
+DECL_COMP_FFT(15)
 
-#define DECL_COMP_MDCT(N)                                                      \
-static void compound_mdct_##N##xM(AVTXContext *s, void *_dst, void *_src,      \
-                                  ptrdiff_t stride)                            \
-{                                                                              \
-    FFTSample *src = _src, *dst = _dst;                                        \
-    FFTComplex *exp = s->exptab, tmp, fft##N##in[N];                           \
-    const int m = s->m, len4 = N*m, len3 = len4 * 3, len8 = len4 >> 1;         \
-    const int *in_map = s->pfatab, *out_map = in_map + N*m;                    \
-    void (*fftp)(FFTComplex *) = fft_dispatch[av_log2(m)];                     \
-                                                                               \
-    stride /= sizeof(*dst);                                                    \
-                                                                               \
-    for (int i = 0; i < m; i++) { /* Folding and pre-reindexing */             \
-        for (int j = 0; j < N; j++) {                                          \
-            const int k = in_map[i*N + j];                                     \
-            if (k < len4) {                                                    \
-                tmp.re = FOLD(-src[ len4 + k],  src[1*len4 - 1 - k]);          \
-                tmp.im = FOLD(-src[ len3 + k], -src[1*len3 - 1 - k]);          \
-            } else {                                                           \
-                tmp.re = FOLD(-src[ len4 + k], -src[5*len4 - 1 - k]);          \
-                tmp.im = FOLD( src[-len4 + k], -src[1*len3 - 1 - k]);          \
-            }                                                                  \
-            CMUL(fft##N##in[j].im, fft##N##in[j].re, tmp.re, tmp.im,           \
-                 exp[k >> 1].re, exp[k >> 1].im);                              \
-        }                                                                      \
-        fft##N(s->tmp + s->revtab_c[i], fft##N##in, m);                        \
-    }                                                                          \
-                                                                               \
-    for (int i = 0; i < N; i++)                                                \
-        fftp(s->tmp + m*i);                                                    \
-                                                                               \
-    for (int i = 0; i < len8; i++) {                                           \
-        const int i0 = len8 + i, i1 = len8 - i - 1;                            \
-        const int s0 = out_map[i0], s1 = out_map[i1];                          \
-        FFTComplex src1 = { s->tmp[s1].re, s->tmp[s1].im };                    \
-        FFTComplex src0 = { s->tmp[s0].re, s->tmp[s0].im };                    \
-                                                                               \
-        CMUL(dst[2*i1*stride + stride], dst[2*i0*stride], src0.re, src0.im,    \
-             exp[i0].im, exp[i0].re);                                          \
-        CMUL(dst[2*i0*stride + stride], dst[2*i1*stride], src1.re, src1.im,    \
-             exp[i1].im, exp[i1].re);                                          \
-    }                                                                          \
-}
+static void TX_NAME(ff_tx_mdct_naive_fwd)(AVTXContext *s, void *_dst,
+                                          void *_src, ptrdiff_t stride)
+{
+    FFTSample *src = _src;
+    FFTSample *dst = _dst;
+    double scale = s->scale_d;
+    int len = s->len;
+    const double phase = M_PI/(4.0*len);
 
-DECL_COMP_MDCT(3)
-DECL_COMP_MDCT(5)
-DECL_COMP_MDCT(7)
-DECL_COMP_MDCT(9)
-DECL_COMP_MDCT(15)
+    stride /= sizeof(*dst);
+
+    for (int i = 0; i < len; i++) {
+        double sum = 0.0;
+        for (int j = 0; j < len*2; j++) {
+            int a = (2*j + 1 + len) * (2*i + 1);
+            sum += UNSCALE(src[j]) * cos(a * phase);
+        }
+        dst[i*stride] = RESCALE(sum*scale);
+    }
+}
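For reference, with N = len output coefficients the loop above evaluates the standard forward MDCT (the code's `a * phase` expands to exactly this cosine argument):

```latex
X_k = s \sum_{n=0}^{2N-1} x_n
      \cos\!\left[\frac{\pi}{N}\left(n + \tfrac{1}{2} + \tfrac{N}{2}\right)
                               \left(k + \tfrac{1}{2}\right)\right],
\qquad k = 0, \dots, N-1
```

where s is the user-supplied scale.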
 
-static void monolithic_imdct(AVTXContext *s, void *_dst, void *_src,
-                             ptrdiff_t stride)
+static void TX_NAME(ff_tx_mdct_naive_inv)(AVTXContext *s, void *_dst,
+                                          void *_src, ptrdiff_t stride)
 {
-    FFTComplex *z = _dst, *exp = s->exptab;
-    const int m = s->m, len8 = m >> 1;
-    const FFTSample *src = _src, *in1, *in2;
-    void (*fftp)(FFTComplex *) = fft_dispatch[av_log2(m)];
+    FFTSample *src = _src;
+    FFTSample *dst = _dst;
+    double scale = s->scale_d;
+    int len = s->len >> 1;
+    int len2 = len*2;
+    const double phase = M_PI/(4.0*len2);
 
     stride /= sizeof(*src);
-    in1 = src;
-    in2 = src + ((m*2) - 1) * stride;
 
-    for (int i = 0; i < m; i++) {
-        FFTComplex tmp = { in2[-2*i*stride], in1[2*i*stride] };
-        CMUL3(z[s->revtab_c[i]], tmp, exp[i]);
+    for (int i = 0; i < len; i++) {
+        double sum_d = 0.0;
+        double sum_u = 0.0;
+        double i_d = phase * (4*len  - 2*i - 1);
+        double i_u = phase * (3*len2 + 2*i + 1);
+        for (int j = 0; j < len2; j++) {
+            double a = (2 * j + 1);
+            double a_d = cos(a * i_d);
+            double a_u = cos(a * i_u);
+            double val = UNSCALE(src[j*stride]);
+            sum_d += a_d * val;
+            sum_u += a_u * val;
+        }
+        dst[i +   0] = RESCALE( sum_d*scale);
+        dst[i + len] = RESCALE(-sum_u*scale);
     }
+}
 
-    fftp(z);
+static av_cold int TX_NAME(ff_tx_mdct_naive_init)(AVTXContext *s,
+                                                  const FFTXCodelet *cd,
+                                                  uint64_t flags,
+                                                  FFTXCodeletOptions *opts,
+                                                  int len, int inv,
+                                                  const void *scale)
+{
+    s->scale_d = *((SCALE_TYPE *)scale);
+    s->scale_f = s->scale_d;
+    return 0;
+}
 
-    for (int i = 0; i < len8; i++) {
-        const int i0 = len8 + i, i1 = len8 - i - 1;
-        FFTComplex src1 = { z[i1].im, z[i1].re };
-        FFTComplex src0 = { z[i0].im, z[i0].re };
+const FFTXCodelet TX_NAME(ff_tx_mdct_naive_fwd_def) = {
+    .name       = TX_NAME_STR("mdct_naive_fwd"),
+    .function   = TX_NAME(ff_tx_mdct_naive_fwd),
+    .type       = TX_TYPE(MDCT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | FF_TX_FORWARD_ONLY,
+    .factors    = { 2, TX_FACTOR_ANY }, /* MDCTs need an even number of coefficients/samples */
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_mdct_naive_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_MIN,
+};
 
-        CMUL(z[i1].re, z[i0].im, src1.re, src1.im, exp[i1].im, exp[i1].re);
-        CMUL(z[i0].re, z[i1].im, src0.re, src0.im, exp[i0].im, exp[i0].re);
-    }
-}
+const FFTXCodelet TX_NAME(ff_tx_mdct_naive_inv_def) = {
+    .name       = TX_NAME_STR("mdct_naive_inv"),
+    .function   = TX_NAME(ff_tx_mdct_naive_inv),
+    .type       = TX_TYPE(MDCT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | FF_TX_INVERSE_ONLY,
+    .factors    = { 2, TX_FACTOR_ANY },
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_mdct_naive_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_MIN,
+};
 
-static void monolithic_mdct(AVTXContext *s, void *_dst, void *_src,
-                            ptrdiff_t stride)
+static void TX_NAME(ff_tx_mdct_sr_fwd)(AVTXContext *s, void *_dst, void *_src,
+                                       ptrdiff_t stride)
 {
     FFTSample *src = _src, *dst = _dst;
-    FFTComplex *exp = s->exptab, tmp, *z = _dst;
-    const int m = s->m, len4 = m, len3 = len4 * 3, len8 = len4 >> 1;
-    void (*fftp)(FFTComplex *) = fft_dispatch[av_log2(m)];
+    FFTComplex *exp = s->exp, tmp, *z = _dst;
+    const int len2 = s->len >> 1;
+    const int len4 = s->len >> 2;
+    const int len3 = len2 * 3;
+    const int *sub_map = s->sub->map;
 
     stride /= sizeof(*dst);
 
-    for (int i = 0; i < m; i++) { /* Folding and pre-reindexing */
+    for (int i = 0; i < len2; i++) { /* Folding and pre-reindexing */
         const int k = 2*i;
-        if (k < len4) {
-            tmp.re = FOLD(-src[ len4 + k],  src[1*len4 - 1 - k]);
+        const int idx = sub_map[i];
+        if (k < len2) {
+            tmp.re = FOLD(-src[ len2 + k],  src[1*len2 - 1 - k]);
             tmp.im = FOLD(-src[ len3 + k], -src[1*len3 - 1 - k]);
         } else {
-            tmp.re = FOLD(-src[ len4 + k], -src[5*len4 - 1 - k]);
-            tmp.im = FOLD( src[-len4 + k], -src[1*len3 - 1 - k]);
+            tmp.re = FOLD(-src[ len2 + k], -src[5*len2 - 1 - k]);
+            tmp.im = FOLD( src[-len2 + k], -src[1*len3 - 1 - k]);
         }
-        CMUL(z[s->revtab_c[i]].im, z[s->revtab_c[i]].re, tmp.re, tmp.im,
-             exp[i].re, exp[i].im);
+        CMUL(z[idx].im, z[idx].re, tmp.re, tmp.im, exp[i].re, exp[i].im);
     }
 
-    fftp(z);
+    s->fn[0](&s->sub[0], z, z, sizeof(FFTComplex));
 
-    for (int i = 0; i < len8; i++) {
-        const int i0 = len8 + i, i1 = len8 - i - 1;
+    for (int i = 0; i < len4; i++) {
+        const int i0 = len4 + i, i1 = len4 - i - 1;
         FFTComplex src1 = { z[i1].re, z[i1].im };
         FFTComplex src0 = { z[i0].re, z[i0].im };
 
@@ -824,66 +985,117 @@  static void monolithic_mdct(AVTXContext *s, void *_dst, void *_src,
     }
 }
 
-static void naive_imdct(AVTXContext *s, void *_dst, void *_src,
-                        ptrdiff_t stride)
+static void TX_NAME(ff_tx_mdct_sr_inv)(AVTXContext *s, void *_dst, void *_src,
+                                       ptrdiff_t stride)
 {
-    int len = s->n;
-    int len2 = len*2;
-    FFTSample *src = _src;
-    FFTSample *dst = _dst;
-    double scale = s->scale;
-    const double phase = M_PI/(4.0*len2);
+    FFTComplex *z = _dst, *exp = s->exp;
+    const FFTSample *src = _src, *in1, *in2;
+    const int len2 = s->len >> 1;
+    const int len4 = s->len >> 2;
+    const int *sub_map = s->sub->map;
 
     stride /= sizeof(*src);
+    in1 = src;
+    in2 = src + ((len2*2) - 1) * stride;
 
-    for (int i = 0; i < len; i++) {
-        double sum_d = 0.0;
-        double sum_u = 0.0;
-        double i_d = phase * (4*len  - 2*i - 1);
-        double i_u = phase * (3*len2 + 2*i + 1);
-        for (int j = 0; j < len2; j++) {
-            double a = (2 * j + 1);
-            double a_d = cos(a * i_d);
-            double a_u = cos(a * i_u);
-            double val = UNSCALE(src[j*stride]);
-            sum_d += a_d * val;
-            sum_u += a_u * val;
-        }
-        dst[i +   0] = RESCALE( sum_d*scale);
-        dst[i + len] = RESCALE(-sum_u*scale);
+    for (int i = 0; i < len2; i++) {
+        FFTComplex tmp = { in2[-2*i*stride], in1[2*i*stride] };
+        CMUL3(z[sub_map[i]], tmp, exp[i]);
+    }
+
+    s->fn[0](&s->sub[0], z, z, sizeof(FFTComplex));
+
+    for (int i = 0; i < len4; i++) {
+        const int i0 = len4 + i, i1 = len4 - i - 1;
+        FFTComplex src1 = { z[i1].im, z[i1].re };
+        FFTComplex src0 = { z[i0].im, z[i0].re };
+
+        CMUL(z[i1].re, z[i0].im, src1.re, src1.im, exp[i1].im, exp[i1].re);
+        CMUL(z[i0].re, z[i1].im, src0.re, src0.im, exp[i0].im, exp[i0].re);
     }
 }
 
-static void naive_mdct(AVTXContext *s, void *_dst, void *_src,
-                       ptrdiff_t stride)
+static int TX_NAME(ff_tx_mdct_gen_exp)(AVTXContext *s)
 {
-    int len = s->n*2;
-    FFTSample *src = _src;
-    FFTSample *dst = _dst;
-    double scale = s->scale;
-    const double phase = M_PI/(4.0*len);
+    int len4 = s->len >> 1;
+    double scale = s->scale_d;
+    const double theta = (scale < 0 ? len4 : 0) + 1.0/8.0;
 
-    stride /= sizeof(*dst);
+    if (!(s->exp = av_malloc_array(len4, sizeof(*s->exp))))
+        return AVERROR(ENOMEM);
 
-    for (int i = 0; i < len; i++) {
-        double sum = 0.0;
-        for (int j = 0; j < len*2; j++) {
-            int a = (2*j + 1 + len) * (2*i + 1);
-            sum += UNSCALE(src[j]) * cos(a * phase);
-        }
-        dst[i*stride] = RESCALE(sum*scale);
+    scale = sqrt(fabs(scale));
+    for (int i = 0; i < len4; i++) {
+        const double alpha = M_PI_2 * (i + theta) / len4;
+        s->exp[i].re = RESCALE(cos(alpha) * scale);
+        s->exp[i].im = RESCALE(sin(alpha) * scale);
     }
+
+    return 0;
 }
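The table generated above is the usual pre-/post-rotation twiddle set for expressing an MDCT through a half-length complex FFT. With L = len4 entries and theta = 1/8, each entry is (sketching the math, not the fixed-point RESCALE details):

```latex
e_i = \sqrt{|s|}\,\bigl(\cos\alpha_i + j\sin\alpha_i\bigr),
\qquad
\alpha_i = \frac{\pi}{2}\cdot\frac{i + \theta}{L},
\quad \theta = \tfrac{1}{8}
```

A negative scale adds L to theta, which rotates every twiddle by pi/2; the 1/8 offset is the familiar phase shift used in FFT-based MDCT implementations.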
 
-static void full_imdct_wrapper_fn(AVTXContext *s, void *_dst, void *_src,
-                                  ptrdiff_t stride)
+static av_cold int TX_NAME(ff_tx_mdct_sr_init)(AVTXContext *s,
+                                               const FFTXCodelet *cd,
+                                               uint64_t flags,
+                                               FFTXCodeletOptions *opts,
+                                               int len, int inv,
+                                               const void *scale)
+{
+    int ret;
+    FFTXCodeletOptions sub_opts = { .invert_lookup = 0 };
+
+    s->scale_d = *((SCALE_TYPE *)scale);
+    s->scale_f = s->scale_d;
+
+    flags &= ~FF_TX_OUT_OF_PLACE; /* We want the subtransform to be */
+    flags |=  AV_TX_INPLACE;      /* in-place */
+    flags |=  FF_TX_PRESHUFFLE;   /* This function handles the permute step */
+
+    if ((ret = ff_tx_init_subtx(s, TX_TYPE(FFT), flags, &sub_opts, len >> 1,
+                                inv, scale)))
+        return ret;
+
+    if ((ret = TX_NAME(ff_tx_mdct_gen_exp)(s)))
+        return ret;
+
+    return 0;
+}
+
+const FFTXCodelet TX_NAME(ff_tx_mdct_sr_fwd_def) = {
+    .name       = TX_NAME_STR("mdct_sr_fwd"),
+    .function   = TX_NAME(ff_tx_mdct_sr_fwd),
+    .type       = TX_TYPE(MDCT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | FF_TX_FORWARD_ONLY,
+    .factors[0] = 2,
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_mdct_sr_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_BASE,
+};
+
+const FFTXCodelet TX_NAME(ff_tx_mdct_sr_inv_def) = {
+    .name       = TX_NAME_STR("mdct_sr_inv"),
+    .function   = TX_NAME(ff_tx_mdct_sr_inv),
+    .type       = TX_TYPE(MDCT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | FF_TX_INVERSE_ONLY,
+    .factors[0] = 2,
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_mdct_sr_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_BASE,
+};
+
+static void TX_NAME(ff_tx_mdct_inv_full)(AVTXContext *s, void *_dst,
+                                         void *_src, ptrdiff_t stride)
 {
-    int len = s->m*s->n*4;
+    int len  = s->len << 1;
     int len2 = len >> 1;
     int len4 = len >> 2;
     FFTSample *dst = _dst;
 
-    s->top_tx(s, dst + len4, _src, stride);
+    s->fn[0](&s->sub[0], dst + len4, _src, stride);
 
     stride /= sizeof(*dst);
 
@@ -893,132 +1105,246 @@  static void full_imdct_wrapper_fn(AVTXContext *s, void *_dst, void *_src,
     }
 }
 
-static int gen_mdct_exptab(AVTXContext *s, int len4, double scale)
+static av_cold int TX_NAME(ff_tx_mdct_inv_full_init)(AVTXContext *s,
+                                                     const FFTXCodelet *cd,
+                                                     uint64_t flags,
+                                                     FFTXCodeletOptions *opts,
+                                                     int len, int inv,
+                                                     const void *scale)
 {
-    const double theta = (scale < 0 ? len4 : 0) + 1.0/8.0;
+    int ret;
 
-    if (!(s->exptab = av_malloc_array(len4, sizeof(*s->exptab))))
-        return AVERROR(ENOMEM);
+    s->scale_d = *((SCALE_TYPE *)scale);
+    s->scale_f = s->scale_d;
 
-    scale = sqrt(fabs(scale));
-    for (int i = 0; i < len4; i++) {
-        const double alpha = M_PI_2 * (i + theta) / len4;
-        s->exptab[i].re = RESCALE(cos(alpha) * scale);
-        s->exptab[i].im = RESCALE(sin(alpha) * scale);
-    }
+    flags &= ~AV_TX_FULL_IMDCT;
+
+    if ((ret = ff_tx_init_subtx(s, TX_TYPE(MDCT), flags, NULL, len, 1, scale)))
+        return ret;
 
     return 0;
 }
 
-int TX_NAME(ff_tx_init_mdct_fft)(AVTXContext *s, av_tx_fn *tx,
-                                 enum AVTXType type, int inv, int len,
-                                 const void *scale, uint64_t flags)
+const FFTXCodelet TX_NAME(ff_tx_mdct_inv_full_def) = {
+    .name       = TX_NAME_STR("mdct_inv_full"),
+    .function   = TX_NAME(ff_tx_mdct_inv_full),
+    .type       = TX_TYPE(MDCT),
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | AV_TX_FULL_IMDCT,
+    .factors    = { 2, TX_FACTOR_ANY },
+    .min_len    = 2,
+    .max_len    = TX_LEN_UNLIMITED,
+    .init       = TX_NAME(ff_tx_mdct_inv_full_init),
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,
+    .prio       = FF_TX_PRIO_BASE,
+};
+
+static av_cold int TX_NAME(ff_tx_mdct_pfa_init)(AVTXContext *s,
+                                                const FFTXCodelet *cd,
+                                                uint64_t flags,
+                                                FFTXCodeletOptions *opts,
+                                                int len, int inv,
+                                                const void *scale)
 {
-    const int is_mdct = ff_tx_type_is_mdct(type);
-    int err, l, n = 1, m = 1, max_ptwo = 1 << (FF_ARRAY_ELEMS(fft_dispatch) - 1);
+    int ret, sub_len;
+    FFTXCodeletOptions sub_opts = { .invert_lookup = 0 };
 
-    if (is_mdct)
-        len >>= 1;
+    len >>= 1;
+    sub_len = len / cd->factors[0];
 
-    l = len;
+    s->scale_d = *((SCALE_TYPE *)scale);
+    s->scale_f = s->scale_d;
 
-#define CHECK_FACTOR(DST, FACTOR, SRC)                                         \
-    if (DST == 1 && !(SRC % FACTOR)) {                                         \
-        DST = FACTOR;                                                          \
-        SRC /= FACTOR;                                                         \
-    }
-    CHECK_FACTOR(n, 15, len)
-    CHECK_FACTOR(n,  9, len)
-    CHECK_FACTOR(n,  7, len)
-    CHECK_FACTOR(n,  5, len)
-    CHECK_FACTOR(n,  3, len)
-#undef CHECK_FACTOR
-
-    /* len must be a power of two now */
-    if (!(len & (len - 1)) && len >= 2 && len <= max_ptwo) {
-        m = len;
-        len = 1;
-    }
+    flags &= ~FF_TX_OUT_OF_PLACE; /* We want the subtransform to be */
+    flags |=  AV_TX_INPLACE;      /* in-place */
+    flags |=  FF_TX_PRESHUFFLE;   /* This function handles the permute step */
 
-    s->n = n;
-    s->m = m;
-    s->inv = inv;
-    s->type = type;
-    s->flags = flags;
-
-    /* If we weren't able to split the length into factors we can handle,
-     * resort to using the naive and slow FT. This also filters out
-     * direct 3, 5 and 15 transforms as they're too niche. */
-    if (len > 1 || m == 1) {
-        if (is_mdct && (l & 1)) /* Odd (i)MDCTs are not supported yet */
-            return AVERROR(ENOSYS);
-        if (flags & AV_TX_INPLACE) /* Neither are in-place naive transforms */
-            return AVERROR(ENOSYS);
-        s->n = l;
-        s->m = 1;
-        *tx = naive_fft;
-        if (is_mdct) {
-            s->scale = *((SCALE_TYPE *)scale);
-            *tx = inv ? naive_imdct : naive_mdct;
-            if (inv && (flags & AV_TX_FULL_IMDCT)) {
-                s->top_tx = *tx;
-                *tx = full_imdct_wrapper_fn;
-            }
-        }
-        return 0;
-    }
+    if ((ret = ff_tx_init_subtx(s, TX_TYPE(FFT), flags, &sub_opts,
+                                sub_len, inv, scale)))
+        return ret;
 
-    if (n > 1 && m > 1) { /* 2D transform case */
-        if ((err = ff_tx_gen_compound_mapping(s)))
-            return err;
-        if (!(s->tmp = av_malloc(n*m*sizeof(*s->tmp))))
-            return AVERROR(ENOMEM);
-        if (!(m & (m - 1))) {
-            *tx = n == 3 ? compound_fft_3xM :
-                  n == 5 ? compound_fft_5xM :
-                  n == 7 ? compound_fft_7xM :
-                  n == 9 ? compound_fft_9xM :
-                           compound_fft_15xM;
-            if (is_mdct)
-                *tx = n == 3 ? inv ? compound_imdct_3xM  : compound_mdct_3xM :
-                      n == 5 ? inv ? compound_imdct_5xM  : compound_mdct_5xM :
-                      n == 7 ? inv ? compound_imdct_7xM  : compound_mdct_7xM :
-                      n == 9 ? inv ? compound_imdct_9xM  : compound_mdct_9xM :
-                               inv ? compound_imdct_15xM : compound_mdct_15xM;
-        }
-    } else { /* Direct transform case */
-        *tx = split_radix_fft;
-        if (is_mdct)
-            *tx = inv ? monolithic_imdct : monolithic_mdct;
-    }
+    if ((ret = ff_tx_gen_compound_mapping(s, cd->factors[0], sub_len)))
+        return ret;
 
-    if (n == 3 || n == 5 || n == 15)
-        init_cos_tabs(0);
-    else if (n == 7)
-        init_cos_tabs(1);
-    else if (n == 9)
-        init_cos_tabs(2);
-
-    if (m != 1 && !(m & (m - 1))) {
-        if ((err = ff_tx_gen_ptwo_revtab(s, n == 1 && !is_mdct && !(flags & AV_TX_INPLACE))))
-            return err;
-        if (flags & AV_TX_INPLACE) {
-            if (is_mdct) /* In-place MDCTs are not supported yet */
-                return AVERROR(ENOSYS);
-            if ((err = ff_tx_gen_ptwo_inplace_revtab_idx(s, s->revtab_c)))
-                return err;
-        }
-        for (int i = 4; i <= av_log2(m); i++)
-            init_cos_tabs(i);
-    }
+    if ((ret = TX_NAME(ff_tx_mdct_gen_exp)(s)))
+        return ret;
 
-    if (is_mdct) {
-        if (inv && (flags & AV_TX_FULL_IMDCT)) {
-            s->top_tx = *tx;
-            *tx = full_imdct_wrapper_fn;
-        }
-        return gen_mdct_exptab(s, n*m, *((SCALE_TYPE *)scale));
-    }
+    if (!(s->tmp = av_malloc(len*sizeof(*s->tmp))))
+        return AVERROR(ENOMEM);
+
+    TX_TAB(ff_tx_init_tabs)(len / sub_len);
 
     return 0;
 }
+
+#define DECL_COMP_IMDCT(N)                                                     \
+static void TX_NAME(ff_tx_mdct_pfa_##N##xM_inv)(AVTXContext *s, void *_dst,    \
+                                                void *_src, ptrdiff_t stride)  \
+{                                                                              \
+    FFTComplex fft##N##in[N];                                                  \
+    FFTComplex *z = _dst, *exp = s->exp;                                       \
+    const FFTSample *src = _src, *in1, *in2;                                   \
+    const int len4 = s->len >> 2;                                              \
+    const int m = s->sub->len;                                                 \
+    const int *in_map = s->map, *out_map = in_map + N*m;                       \
+    const int *sub_map = s->sub->map;                                          \
+                                                                               \
+    stride /= sizeof(*src); /* To convert it from bytes */                     \
+    in1 = src;                                                                 \
+    in2 = src + ((N*m*2) - 1) * stride;                                        \
+                                                                               \
+    for (int i = 0; i < m; i++) {                                              \
+        for (int j = 0; j < N; j++) {                                          \
+            const int k = in_map[i*N + j];                                     \
+            FFTComplex tmp = { in2[-k*stride], in1[k*stride] };                \
+            CMUL3(fft##N##in[j], tmp, exp[k >> 1]);                            \
+        }                                                                      \
+        fft##N(s->tmp + sub_map[i], fft##N##in, m);                            \
+    }                                                                          \
+                                                                               \
+    for (int i = 0; i < N; i++)                                                \
+        s->fn[0](&s->sub[0], s->tmp + m*i, s->tmp + m*i, sizeof(FFTComplex));  \
+                                                                               \
+    for (int i = 0; i < len4; i++) {                                           \
+        const int i0 = len4 + i, i1 = len4 - i - 1;                            \
+        const int s0 = out_map[i0], s1 = out_map[i1];                          \
+        FFTComplex src1 = { s->tmp[s1].im, s->tmp[s1].re };                    \
+        FFTComplex src0 = { s->tmp[s0].im, s->tmp[s0].re };                    \
+                                                                               \
+        CMUL(z[i1].re, z[i0].im, src1.re, src1.im, exp[i1].im, exp[i1].re);    \
+        CMUL(z[i0].re, z[i1].im, src0.re, src0.im, exp[i0].im, exp[i0].re);    \
+    }                                                                          \
+}                                                                              \
+                                                                               \
+const FFTXCodelet TX_NAME(ff_tx_mdct_pfa_##N##xM_inv_def) = {                  \
+    .name       = TX_NAME_STR("mdct_pfa_" #N "xM_inv"),                        \
+    .function   = TX_NAME(ff_tx_mdct_pfa_##N##xM_inv),                         \
+    .type       = TX_TYPE(MDCT),                                               \
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | FF_TX_INVERSE_ONLY,   \
+    .factors    = { N, TX_FACTOR_ANY },                                        \
+    .min_len    = N*2,                                                         \
+    .max_len    = TX_LEN_UNLIMITED,                                            \
+    .init       = TX_NAME(ff_tx_mdct_pfa_init),                                \
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,                                         \
+    .prio       = FF_TX_PRIO_BASE,                                             \
+};
+
+DECL_COMP_IMDCT(3)
+DECL_COMP_IMDCT(5)
+DECL_COMP_IMDCT(7)
+DECL_COMP_IMDCT(9)
+DECL_COMP_IMDCT(15)
+
+#define DECL_COMP_MDCT(N)                                                      \
+static void TX_NAME(ff_tx_mdct_pfa_##N##xM_fwd)(AVTXContext *s, void *_dst,    \
+                                                void *_src, ptrdiff_t stride)  \
+{                                                                              \
+    FFTComplex fft##N##in[N];                                                  \
+    FFTSample *src = _src, *dst = _dst;                                        \
+    FFTComplex *exp = s->exp, tmp;                                             \
+    const int m = s->sub->len;                                                 \
+    const int len4 = N*m;                                                      \
+    const int len3 = len4 * 3;                                                 \
+    const int len8 = s->len >> 2;                                              \
+    const int *in_map = s->map, *out_map = in_map + N*m;                       \
+    const int *sub_map = s->sub->map;                                          \
+                                                                               \
+    stride /= sizeof(*dst);                                                    \
+                                                                               \
+    for (int i = 0; i < m; i++) { /* Folding and pre-reindexing */             \
+        for (int j = 0; j < N; j++) {                                          \
+            const int k = in_map[i*N + j];                                     \
+            if (k < len4) {                                                    \
+                tmp.re = FOLD(-src[ len4 + k],  src[1*len4 - 1 - k]);          \
+                tmp.im = FOLD(-src[ len3 + k], -src[1*len3 - 1 - k]);          \
+            } else {                                                           \
+                tmp.re = FOLD(-src[ len4 + k], -src[5*len4 - 1 - k]);          \
+                tmp.im = FOLD( src[-len4 + k], -src[1*len3 - 1 - k]);          \
+            }                                                                  \
+            CMUL(fft##N##in[j].im, fft##N##in[j].re, tmp.re, tmp.im,           \
+                 exp[k >> 1].re, exp[k >> 1].im);                              \
+        }                                                                      \
+        fft##N(s->tmp + sub_map[i], fft##N##in, m);                            \
+    }                                                                          \
+                                                                               \
+    for (int i = 0; i < N; i++)                                                \
+        s->fn[0](&s->sub[0], s->tmp + m*i, s->tmp + m*i, sizeof(FFTComplex));  \
+                                                                               \
+    for (int i = 0; i < len8; i++) {                                           \
+        const int i0 = len8 + i, i1 = len8 - i - 1;                            \
+        const int s0 = out_map[i0], s1 = out_map[i1];                          \
+        FFTComplex src1 = { s->tmp[s1].re, s->tmp[s1].im };                    \
+        FFTComplex src0 = { s->tmp[s0].re, s->tmp[s0].im };                    \
+                                                                               \
+        CMUL(dst[2*i1*stride + stride], dst[2*i0*stride], src0.re, src0.im,    \
+             exp[i0].im, exp[i0].re);                                          \
+        CMUL(dst[2*i0*stride + stride], dst[2*i1*stride], src1.re, src1.im,    \
+             exp[i1].im, exp[i1].re);                                          \
+    }                                                                          \
+}                                                                              \
+                                                                               \
+const FFTXCodelet TX_NAME(ff_tx_mdct_pfa_##N##xM_fwd_def) = {                  \
+    .name       = TX_NAME_STR("mdct_pfa_" #N "xM_fwd"),                        \
+    .function   = TX_NAME(ff_tx_mdct_pfa_##N##xM_fwd),                         \
+    .type       = TX_TYPE(MDCT),                                               \
+    .flags      = AV_TX_UNALIGNED | FF_TX_OUT_OF_PLACE | FF_TX_FORWARD_ONLY,   \
+    .factors    = { N, TX_FACTOR_ANY },                                        \
+    .min_len    = N*2,                                                         \
+    .max_len    = TX_LEN_UNLIMITED,                                            \
+    .init       = TX_NAME(ff_tx_mdct_pfa_init),                                \
+    .cpu_flags  = FF_TX_CPU_FLAGS_ALL,                                         \
+    .prio       = FF_TX_PRIO_BASE,                                             \
+};
+
+DECL_COMP_MDCT(3)
+DECL_COMP_MDCT(5)
+DECL_COMP_MDCT(7)
+DECL_COMP_MDCT(9)
+DECL_COMP_MDCT(15)
+
+const FFTXCodelet * const TX_NAME(ff_tx_codelet_list)[] = {
+    /* Split-Radix codelets */
+    &TX_NAME(ff_tx_fft2_ns_def),
+    &TX_NAME(ff_tx_fft4_ns_def),
+    &TX_NAME(ff_tx_fft8_ns_def),
+    &TX_NAME(ff_tx_fft16_ns_def),
+    &TX_NAME(ff_tx_fft32_ns_def),
+    &TX_NAME(ff_tx_fft64_ns_def),
+    &TX_NAME(ff_tx_fft128_ns_def),
+    &TX_NAME(ff_tx_fft256_ns_def),
+    &TX_NAME(ff_tx_fft512_ns_def),
+    &TX_NAME(ff_tx_fft1024_ns_def),
+    &TX_NAME(ff_tx_fft2048_ns_def),
+    &TX_NAME(ff_tx_fft4096_ns_def),
+    &TX_NAME(ff_tx_fft8192_ns_def),
+    &TX_NAME(ff_tx_fft16384_ns_def),
+    &TX_NAME(ff_tx_fft32768_ns_def),
+    &TX_NAME(ff_tx_fft65536_ns_def),
+    &TX_NAME(ff_tx_fft131072_ns_def),
+
+    /* Standalone transforms */
+    &TX_NAME(ff_tx_fft_sr_def),
+    &TX_NAME(ff_tx_fft_sr_inplace_def),
+    &TX_NAME(ff_tx_fft_pfa_3xM_def),
+    &TX_NAME(ff_tx_fft_pfa_5xM_def),
+    &TX_NAME(ff_tx_fft_pfa_7xM_def),
+    &TX_NAME(ff_tx_fft_pfa_9xM_def),
+    &TX_NAME(ff_tx_fft_pfa_15xM_def),
+    &TX_NAME(ff_tx_fft_naive_def),
+    &TX_NAME(ff_tx_mdct_sr_fwd_def),
+    &TX_NAME(ff_tx_mdct_sr_inv_def),
+    &TX_NAME(ff_tx_mdct_pfa_3xM_fwd_def),
+    &TX_NAME(ff_tx_mdct_pfa_5xM_fwd_def),
+    &TX_NAME(ff_tx_mdct_pfa_7xM_fwd_def),
+    &TX_NAME(ff_tx_mdct_pfa_9xM_fwd_def),
+    &TX_NAME(ff_tx_mdct_pfa_15xM_fwd_def),
+    &TX_NAME(ff_tx_mdct_pfa_3xM_inv_def),
+    &TX_NAME(ff_tx_mdct_pfa_5xM_inv_def),
+    &TX_NAME(ff_tx_mdct_pfa_7xM_inv_def),
+    &TX_NAME(ff_tx_mdct_pfa_9xM_inv_def),
+    &TX_NAME(ff_tx_mdct_pfa_15xM_inv_def),
+    &TX_NAME(ff_tx_mdct_naive_fwd_def),
+    &TX_NAME(ff_tx_mdct_naive_inv_def),
+    &TX_NAME(ff_tx_mdct_inv_full_def),
+
+    NULL,
+};
diff --git a/libavutil/x86/tx_float.asm b/libavutil/x86/tx_float.asm
index 4d2283fae1..963e6cad66 100644
--- a/libavutil/x86/tx_float.asm
+++ b/libavutil/x86/tx_float.asm
@@ -31,6 +31,8 @@ 
 
 %include "x86util.asm"
 
+%define private_prefix ff_tx
+
 %if ARCH_X86_64
 %define ptr resq
 %else
@@ -39,25 +41,22 @@ 
 
 %assign i 16
 %rep 14
-cextern cos_ %+ i %+ _float ; ff_cos_i_float...
+cextern tab_ %+ i %+ _float ; ff_tab_i_float...
 %assign i (i << 1)
 %endrep
 
 struc AVTXContext
-    .n:           resd 1 ; Non-power-of-two part
-    .m:           resd 1 ; Power-of-two part
-    .inv:         resd 1 ; Is inverse
-    .type:        resd 1 ; Type
-    .flags:       resq 1 ; Flags
-    .scale:       resq 1 ; Scale
-
-    .exptab:       ptr 1 ; MDCT exptab
-    .tmp:          ptr 1 ; Temporary buffer needed for all compound transforms
-    .pfatab:       ptr 1 ; Input/Output mapping for compound transforms
-    .revtab:       ptr 1 ; Input mapping for power of two transforms
-    .inplace_idx:  ptr 1 ; Required indices to revtab for in-place transforms
-
-    .top_tx        ptr 1 ;  Used for transforms derived from other transforms
+    .len:          resd 1 ; Length
+    .inv:          resd 1 ; Inverse flag
+    .map:           ptr 1 ; Lookup table(s)
+    .exp:           ptr 1 ; Exponentiation factors
+    .tmp:           ptr 1 ; Temporary data
+
+    .sub:           ptr 1 ; Subcontexts
+    .fn:            ptr 4 ; Subcontext functions
+    .nb_sub:       resd 1 ; Subcontext count
+
+    ; Everything else is inaccessible from assembly
 endstruc
 
 SECTION_RODATA 32
@@ -485,8 +484,8 @@  SECTION .text
     movaps [outq + 10*mmsize], tx1_o0
     movaps [outq + 14*mmsize], tx2_o0
 
-    movaps tw_e,           [cos_64_float + mmsize]
-    vperm2f128 tw_o, tw_o, [cos_64_float + 64 - 4*7 - mmsize], 0x23
+    movaps tw_e,           [tab_64_float + mmsize]
+    vperm2f128 tw_o, tw_o, [tab_64_float + 64 - 4*7 - mmsize], 0x23
 
     movaps m0, [outq +  1*mmsize]
     movaps m1, [outq +  3*mmsize]
@@ -708,14 +707,21 @@  cglobal fft4_ %+ %1 %+ _float, 4, 4, 3, ctx, out, in, stride
 FFT4 fwd, 0
 FFT4 inv, 1
 
+%macro FFT8_FN 2
 INIT_XMM sse3
-cglobal fft8_float, 4, 4, 6, ctx, out, in, tmp
-    mov ctxq, [ctxq + AVTXContext.revtab]
-
+cglobal fft8_ %+ %1, 4, 4, 6, ctx, out, in, tmp
+%if %2
+    mov ctxq, [ctxq + AVTXContext.map]
     LOAD64_LUT m0, inq, ctxq, (mmsize/2)*0, tmpq
     LOAD64_LUT m1, inq, ctxq, (mmsize/2)*1, tmpq
     LOAD64_LUT m2, inq, ctxq, (mmsize/2)*2, tmpq
     LOAD64_LUT m3, inq, ctxq, (mmsize/2)*3, tmpq
+%else
+    movaps m0, [inq + 0*mmsize]
+    movaps m1, [inq + 1*mmsize]
+    movaps m2, [inq + 2*mmsize]
+    movaps m3, [inq + 3*mmsize]
+%endif
 
     FFT8 m0, m1, m2, m3, m4, m5
 
@@ -730,13 +736,22 @@  cglobal fft8_float, 4, 4, 6, ctx, out, in, tmp
     movups [outq + 3*mmsize], m1
 
     RET
+%endmacro
 
-INIT_YMM avx
-cglobal fft8_float, 4, 4, 4, ctx, out, in, tmp
-    mov ctxq, [ctxq + AVTXContext.revtab]
+FFT8_FN float,    1
+FFT8_FN ns_float, 0
 
+%macro FFT8_AVX_FN 2
+INIT_YMM avx
+cglobal fft8_ %+ %1, 4, 4, 4, ctx, out, in, tmp
+%if %2
+    mov ctxq, [ctxq + AVTXContext.map]
     LOAD64_LUT m0, inq, ctxq, (mmsize/2)*0, tmpq, m2
     LOAD64_LUT m1, inq, ctxq, (mmsize/2)*1, tmpq, m3
+%else
+    movaps m0, [inq + 0*mmsize]
+    movaps m1, [inq + 1*mmsize]
+%endif
 
     FFT8_AVX m0, m1, m2, m3
 
@@ -750,11 +765,15 @@  cglobal fft8_float, 4, 4, 4, ctx, out, in, tmp
     vextractf128 [outq + 16*3], m0, 1
 
     RET
+%endmacro
+
+FFT8_AVX_FN float,    1
+FFT8_AVX_FN ns_float, 0
 
 %macro FFT16_FN 1
 INIT_YMM %1
 cglobal fft16_float, 4, 4, 8, ctx, out, in, tmp
-    mov ctxq, [ctxq + AVTXContext.revtab]
+    mov ctxq, [ctxq + AVTXContext.map]
 
     LOAD64_LUT m0, inq, ctxq, (mmsize/2)*0, tmpq, m4
     LOAD64_LUT m1, inq, ctxq, (mmsize/2)*1, tmpq, m5
@@ -786,7 +805,7 @@  FFT16_FN fma3
 %macro FFT32_FN 1
 INIT_YMM %1
 cglobal fft32_float, 4, 4, 16, ctx, out, in, tmp
-    mov ctxq, [ctxq + AVTXContext.revtab]
+    mov ctxq, [ctxq + AVTXContext.map]
 
     LOAD64_LUT m4, inq, ctxq, (mmsize/2)*4, tmpq,  m8,  m9
     LOAD64_LUT m5, inq, ctxq, (mmsize/2)*5, tmpq, m10, m11
@@ -800,8 +819,8 @@  cglobal fft32_float, 4, 4, 16, ctx, out, in, tmp
     LOAD64_LUT m2, inq, ctxq, (mmsize/2)*2, tmpq, m12, m13
     LOAD64_LUT m3, inq, ctxq, (mmsize/2)*3, tmpq, m14, m15
 
-    movaps m8,         [cos_32_float]
-    vperm2f128 m9, m9, [cos_32_float + 4*8 - 4*7], 0x23
+    movaps m8,         [tab_32_float]
+    vperm2f128 m9, m9, [tab_32_float + 4*8 - 4*7], 0x23
 
     FFT16 m0, m1, m2, m3, m10, m11, m12, m13
 
@@ -858,8 +877,8 @@  ALIGN 16
     POP lenq
     sub outq, (%1*4) + (%1*2) + (%1/2)
 
-    lea rtabq, [cos_ %+ %1 %+ _float]
-    lea itabq, [cos_ %+ %1 %+ _float + %1 - 4*7]
+    lea rtabq, [tab_ %+ %1 %+ _float]
+    lea itabq, [tab_ %+ %1 %+ _float + %1 - 4*7]
 
 %if %0 > 1
     cmp tgtq, %1
@@ -883,9 +902,9 @@  ALIGN 16
 
 %macro FFT_SPLIT_RADIX_FN 1
 INIT_YMM %1
-cglobal split_radix_fft_float, 4, 8, 16, 272, lut, out, in, len, tmp, itab, rtab, tgt
-    movsxd lenq, dword [lutq + AVTXContext.m]
-    mov lutq, [lutq + AVTXContext.revtab]
+cglobal fft_sr_float, 4, 8, 16, 272, lut, out, in, len, tmp, itab, rtab, tgt
+    movsxd lenq, dword [lutq + AVTXContext.len]
+    mov lutq, [lutq + AVTXContext.map]
     mov tgtq, lenq
 
 ; Bottom-most/32-point transform ===============================================
@@ -903,8 +922,8 @@  ALIGN 16
     LOAD64_LUT m2, inq, lutq, (mmsize/2)*2, tmpq, m12, m13
     LOAD64_LUT m3, inq, lutq, (mmsize/2)*3, tmpq, m14, m15
 
-    movaps m8,         [cos_32_float]
-    vperm2f128 m9, m9, [cos_32_float + 32 - 4*7], 0x23
+    movaps m8,         [tab_32_float]
+    vperm2f128 m9, m9, [tab_32_float + 32 - 4*7], 0x23
 
     FFT16 m0, m1, m2, m3, m10, m11, m12, m13
 
@@ -961,8 +980,8 @@  ALIGN 16
 
     FFT16 tx2_e0, tx2_e1, tx2_o0, tx2_o1, tmp1, tmp2, tw_e, tw_o
 
-    movaps tw_e,           [cos_64_float]
-    vperm2f128 tw_o, tw_o, [cos_64_float + 64 - 4*7], 0x23
+    movaps tw_e,           [tab_64_float]
+    vperm2f128 tw_o, tw_o, [tab_64_float + 64 - 4*7], 0x23
 
     add lutq, (mmsize/2)*8
     cmp tgtq, 64
@@ -989,8 +1008,8 @@  ALIGN 16
     POP lenq
     sub outq, 24*mmsize
 
-    lea rtabq, [cos_128_float]
-    lea itabq, [cos_128_float + 128 - 4*7]
+    lea rtabq, [tab_128_float]
+    lea itabq, [tab_128_float + 128 - 4*7]
 
     cmp tgtq, 128
     je .deinterleave
@@ -1016,8 +1035,8 @@  ALIGN 16
     POP lenq
     sub outq, 48*mmsize
 
-    lea rtabq, [cos_256_float]
-    lea itabq, [cos_256_float + 256 - 4*7]
+    lea rtabq, [tab_256_float]
+    lea itabq, [tab_256_float + 256 - 4*7]
 
     cmp tgtq, 256
     je .deinterleave
@@ -1044,8 +1063,8 @@  ALIGN 16
     POP lenq
     sub outq, 96*mmsize
 
-    lea rtabq, [cos_512_float]
-    lea itabq, [cos_512_float + 512 - 4*7]
+    lea rtabq, [tab_512_float]
+    lea itabq, [tab_512_float + 512 - 4*7]
 
     cmp tgtq, 512
     je .deinterleave
@@ -1079,8 +1098,8 @@  ALIGN 16
     POP lenq
     sub outq, 192*mmsize
 
-    lea rtabq, [cos_1024_float]
-    lea itabq, [cos_1024_float + 1024 - 4*7]
+    lea rtabq, [tab_1024_float]
+    lea itabq, [tab_1024_float + 1024 - 4*7]
 
     cmp tgtq, 1024
     je .deinterleave
@@ -1160,8 +1179,8 @@  FFT_SPLIT_RADIX_DEF 131072
     vextractf128 [outq + 13*mmsize +  0], tw_e,   1
     vextractf128 [outq + 13*mmsize + 16], tx2_e0, 1
 
-    movaps tw_e,           [cos_64_float + mmsize]
-    vperm2f128 tw_o, tw_o, [cos_64_float + 64 - 4*7 - mmsize], 0x23
+    movaps tw_e,           [tab_64_float + mmsize]
+    vperm2f128 tw_o, tw_o, [tab_64_float + 64 - 4*7 - mmsize], 0x23
 
     movaps m0, [outq +  1*mmsize]
     movaps m1, [outq +  3*mmsize]
diff --git a/libavutil/x86/tx_float_init.c b/libavutil/x86/tx_float_init.c
index 8b77a5f29f..9e9de35530 100644
--- a/libavutil/x86/tx_float_init.c
+++ b/libavutil/x86/tx_float_init.c
@@ -21,86 +21,112 @@ 
 #include "libavutil/attributes.h"
 #include "libavutil/x86/cpu.h"
 
-void ff_fft2_float_sse3     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft4_inv_float_sse2 (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft4_fwd_float_sse2 (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft8_float_sse3     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft8_float_avx      (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft16_float_avx     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft16_float_fma3    (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft32_float_avx     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_fft32_float_fma3    (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+#include "config.h"
 
-void ff_split_radix_fft_float_avx (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
-void ff_split_radix_fft_float_avx2(AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft2_float_sse3     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft4_inv_float_sse2 (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft4_fwd_float_sse2 (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft8_float_sse3     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft8_ns_float_sse3  (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft8_float_avx      (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft8_ns_float_avx   (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft16_float_avx     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft16_float_fma3    (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft32_float_avx     (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft32_float_fma3    (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
 
-av_cold void ff_tx_init_float_x86(AVTXContext *s, av_tx_fn *tx)
-{
-    int cpu_flags = av_get_cpu_flags();
-    int gen_revtab = 0, basis, revtab_interleave;
+void ff_tx_fft_sr_float_avx    (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
+void ff_tx_fft_sr_float_avx2   (AVTXContext *s, void *out, void *in, ptrdiff_t stride);
 
-    if (s->flags & AV_TX_UNALIGNED)
-        return;
-
-    if (ff_tx_type_is_mdct(s->type))
-        return;
+#define DECL_INIT_FN(basis, interleave)                                        \
+static av_cold av_unused int                                                   \
+    ff_tx_fft_sr_codelet_init_b ##basis## _i ##interleave## _x86               \
+    (AVTXContext *s,                                                           \
+     const FFTXCodelet *cd,                                                    \
+     uint64_t flags,                                                           \
+     FFTXCodeletOptions *opts,                                                 \
+     int len, int inv,                                                         \
+     const void *scale)                                                        \
+{                                                                              \
+    const int inv_lookup = opts ? opts->invert_lookup : 1;                     \
+    ff_tx_init_tabs_float(len);                                                \
+    return ff_tx_gen_split_radix_parity_revtab(s, inv_lookup,                  \
+                                               basis, interleave);             \
+}
 
-#define TXFN(fn, gentab, sr_basis, interleave) \
-    do {                                       \
-        *tx = fn;                              \
-        gen_revtab = gentab;                   \
-        basis = sr_basis;                      \
-        revtab_interleave = interleave;        \
-    } while (0)
+#define ff_tx_fft_sr_codelet_init_b0_i0_x86 NULL
+DECL_INIT_FN(8, 0)
+DECL_INIT_FN(8, 2)
 
-    if (s->n == 1) {
-        if (EXTERNAL_SSE2(cpu_flags)) {
-            if (s->m == 4 && s->inv)
-                TXFN(ff_fft4_inv_float_sse2, 0, 0, 0);
-            else if (s->m == 4)
-                TXFN(ff_fft4_fwd_float_sse2, 0, 0, 0);
-        }
+#define DECL_SR_CD_DEF(fn_name, len, init_fn, fn_prio, cpu, fn_flags) \
+const FFTXCodelet ff_tx_ ##fn_name## _def = {                         \
+    .name       = #fn_name,                                           \
+    .function   = ff_tx_ ##fn_name,                                   \
+    .type       = TX_TYPE(FFT),                                       \
+    .flags      = FF_TX_OUT_OF_PLACE | FF_TX_ALIGNED | fn_flags,      \
+    .factors[0] = 2,                                                  \
+    .min_len    = len,                                                \
+    .max_len    = len,                                                \
+    .init       = ff_tx_fft_sr_codelet_init_ ##init_fn## _x86,        \
+    .cpu_flags  = AV_CPU_FLAG_ ##cpu,                                 \
+    .prio       = fn_prio,                                            \
+};
 
-        if (EXTERNAL_SSE3(cpu_flags)) {
-            if (s->m == 2)
-                TXFN(ff_fft2_float_sse3, 0, 0, 0);
-            else if (s->m == 8)
-                TXFN(ff_fft8_float_sse3, 1, 8, 0);
-        }
+DECL_SR_CD_DEF(fft2_float_sse3,      2, b0_i0, 128, SSE3, AV_TX_INPLACE)
+DECL_SR_CD_DEF(fft4_fwd_float_sse2,  4, b0_i0, 128, SSE2, AV_TX_INPLACE | FF_TX_FORWARD_ONLY)
+DECL_SR_CD_DEF(fft4_inv_float_sse2,  4, b0_i0, 128, SSE2, AV_TX_INPLACE | FF_TX_INVERSE_ONLY)
+DECL_SR_CD_DEF(fft8_float_sse3,      8, b8_i0, 128, SSE3, AV_TX_INPLACE)
+DECL_SR_CD_DEF(fft8_ns_float_sse3,   8, b8_i0, 192, SSE3, AV_TX_INPLACE | FF_TX_PRESHUFFLE)
+DECL_SR_CD_DEF(fft8_float_avx,       8, b8_i0, 256, AVX,  AV_TX_INPLACE)
+DECL_SR_CD_DEF(fft8_ns_float_avx,    8, b8_i0, 320, AVX,  AV_TX_INPLACE | FF_TX_PRESHUFFLE)
+DECL_SR_CD_DEF(fft16_float_avx,     16, b8_i2, 256, AVX,  AV_TX_INPLACE)
+DECL_SR_CD_DEF(fft16_float_fma3,    16, b8_i2, 288, FMA3, AV_TX_INPLACE)
+DECL_SR_CD_DEF(fft32_float_avx,     32, b8_i2, 256, AVX,  AV_TX_INPLACE)
+DECL_SR_CD_DEF(fft32_float_fma3,    32, b8_i2, 288, FMA3, AV_TX_INPLACE)
 
-        if (EXTERNAL_AVX_FAST(cpu_flags)) {
-            if (s->m == 8)
-                TXFN(ff_fft8_float_avx, 1, 8, 0);
-            else if (s->m == 16)
-                TXFN(ff_fft16_float_avx, 1, 8, 2);
-#if ARCH_X86_64
-            else if (s->m == 32)
-                TXFN(ff_fft32_float_avx, 1, 8, 2);
-            else if (s->m >= 64 && s->m <= 131072 && !(s->flags & AV_TX_INPLACE))
-                TXFN(ff_split_radix_fft_float_avx, 1, 8, 2);
-#endif
-        }
+const FFTXCodelet ff_tx_fft_sr_float_avx_def = {
+    .name       = "fft_sr_float_avx",
+    .function   = ff_tx_fft_sr_float_avx,
+    .type       = TX_TYPE(FFT),
+    .flags      = FF_TX_ALIGNED | FF_TX_OUT_OF_PLACE,
+    .factors[0] = 2,
+    .min_len    = 64,
+    .max_len    = 131072,
+    .init       = ff_tx_fft_sr_codelet_init_b8_i2_x86,
+    .cpu_flags  = AV_CPU_FLAG_AVX,
+    .prio       = 256,
+};
 
-        if (EXTERNAL_FMA3_FAST(cpu_flags)) {
-            if (s->m == 16)
-                TXFN(ff_fft16_float_fma3, 1, 8, 2);
-#if ARCH_X86_64
-            else if (s->m == 32)
-                TXFN(ff_fft32_float_fma3, 1, 8, 2);
-#endif
-        }
+const FFTXCodelet ff_tx_fft_sr_float_avx2_def = {
+    .name       = "fft_sr_float_avx2",
+    .function   = ff_tx_fft_sr_float_avx2,
+    .type       = TX_TYPE(FFT),
+    .flags      = FF_TX_ALIGNED | FF_TX_OUT_OF_PLACE,
+    .factors[0] = 2,
+    .min_len    = 64,
+    .max_len    = 131072,
+    .init       = ff_tx_fft_sr_codelet_init_b8_i2_x86,
+    .cpu_flags  = AV_CPU_FLAG_AVX2,
+    .prio       = 288,
+};
 
-#if ARCH_X86_64
-        if (EXTERNAL_AVX2_FAST(cpu_flags)) {
-            if (s->m >= 64 && s->m <= 131072 && !(s->flags & AV_TX_INPLACE))
-                TXFN(ff_split_radix_fft_float_avx2, 1, 8, 2);
-        }
-#endif
-    }
+const FFTXCodelet * const ff_tx_codelet_list_float_x86[] = {
+    /* Split-Radix codelets */
+    &ff_tx_fft2_float_sse3_def,
+    &ff_tx_fft4_fwd_float_sse2_def,
+    &ff_tx_fft4_inv_float_sse2_def,
+    &ff_tx_fft8_float_sse3_def,
+    &ff_tx_fft8_ns_float_sse3_def,
+    &ff_tx_fft8_float_avx_def,
+    &ff_tx_fft8_ns_float_avx_def,
+    &ff_tx_fft16_float_avx_def,
+    &ff_tx_fft16_float_fma3_def,
+    &ff_tx_fft32_float_avx_def,
+    &ff_tx_fft32_float_fma3_def,
 
-    if (gen_revtab)
-        ff_tx_gen_split_radix_parity_revtab(s->revtab, s->m, s->inv, basis,
-                                            revtab_interleave);
+    /* Standalone transforms */
+    &ff_tx_fft_sr_float_avx_def,
+    &ff_tx_fft_sr_float_avx2_def,
 
-#undef TXFN
-}
+    NULL,
+};