diff mbox series

[FFmpeg-devel,3/3] nlmeans_vulkan: parallelize workgroup invocations

Message ID Ng9fYdg--3-9@lynne.ee
State New
Headers show
Series [FFmpeg-devel,1/3] nlmeans_vulkan: fix width/height for chroma plane weights calculation | expand

Commit Message

Lynne Oct. 7, 2023, 3:07 p.m. UTC
Removes the clever subgroup parallel prefix computation,
and instead just computes the prefix inline.
Cuts down the number of dispatches by a huge amount.

Provides a ~12x speedup (2.5fps to 30fps on a 7900XTX,
2.1fps to 24fps on an Ada).

Patch attached.
diff mbox series

Patch

From a51dd2ace418974c7e8b24a3762bd7495d3b3b10 Mon Sep 17 00:00:00 2001
From: Lynne <dev@lynne.ee>
Date: Fri, 15 Sep 2023 21:55:59 +0200
Subject: [PATCH 3/3] nlmeans_vulkan: parallelize workgroup invocations

---
 libavfilter/Makefile               |   3 +-
 libavfilter/vf_nlmeans_vulkan.c    | 333 ++++++++++++++---------------
 libavfilter/vulkan/prefix_sum.comp | 151 -------------
 3 files changed, 167 insertions(+), 320 deletions(-)
 delete mode 100644 libavfilter/vulkan/prefix_sum.comp

diff --git a/libavfilter/Makefile b/libavfilter/Makefile
index 9a100cd665..603b532ad0 100644
--- a/libavfilter/Makefile
+++ b/libavfilter/Makefile
@@ -395,8 +395,7 @@  OBJS-$(CONFIG_MULTIPLY_FILTER)               += vf_multiply.o
 OBJS-$(CONFIG_NEGATE_FILTER)                 += vf_negate.o
 OBJS-$(CONFIG_NLMEANS_FILTER)                += vf_nlmeans.o
 OBJS-$(CONFIG_NLMEANS_OPENCL_FILTER)         += vf_nlmeans_opencl.o opencl.o opencl/nlmeans.o
-OBJS-$(CONFIG_NLMEANS_VULKAN_FILTER)         += vf_nlmeans_vulkan.o vulkan.o vulkan_filter.o \
-                                                vulkan/prefix_sum.o
+OBJS-$(CONFIG_NLMEANS_VULKAN_FILTER)         += vf_nlmeans_vulkan.o vulkan.o vulkan_filter.o
 OBJS-$(CONFIG_NNEDI_FILTER)                  += vf_nnedi.o
 OBJS-$(CONFIG_NOFORMAT_FILTER)               += vf_format.o
 OBJS-$(CONFIG_NOISE_FILTER)                  += vf_noise.o
diff --git a/libavfilter/vf_nlmeans_vulkan.c b/libavfilter/vf_nlmeans_vulkan.c
index 9741dd67ac..6046ff598c 100644
--- a/libavfilter/vf_nlmeans_vulkan.c
+++ b/libavfilter/vf_nlmeans_vulkan.c
@@ -38,9 +38,10 @@  typedef struct NLMeansVulkanContext {
     VkSampler sampler;
 
     AVBufferPool *integral_buf_pool;
-    AVBufferPool *state_buf_pool;
     AVBufferPool *ws_buf_pool;
 
+    FFVkBuffer xyoffsets_buf;
+
     int pl_weights_rows;
     FFVulkanPipeline pl_weights;
     FFVkSPIRVShader shd_weights;
@@ -66,107 +67,97 @@  typedef struct NLMeansVulkanContext {
 
 extern const char *ff_source_prefix_sum_comp;
 
-static void insert_first(FFVkSPIRVShader *shd, int r, int horiz, int plane, int comp)
+static void insert_first(FFVkSPIRVShader *shd, int r, const char *off, int horiz, int plane, int comp)
 {
-    GLSLF(2,     s1    = texture(input_img[%i], ivec2(x + %i, y + %i))[%i];
-          ,plane, horiz ? r : 0, !horiz ? r : 0, comp);
-
-    if (TYPE_ELEMS == 4) {
-        GLSLF(2, s2[0] = texture(input_img[%i], ivec2(x + %i + xoffs[0], y + %i + yoffs[0]))[%i];
-              ,plane, horiz ? r : 0, !horiz ? r : 0, comp);
-        GLSLF(2, s2[1] = texture(input_img[%i], ivec2(x + %i + xoffs[1], y + %i + yoffs[1]))[%i];
-              ,plane, horiz ? r : 0, !horiz ? r : 0, comp);
-        GLSLF(2, s2[2] = texture(input_img[%i], ivec2(x + %i + xoffs[2], y + %i + yoffs[2]))[%i];
-              ,plane, horiz ? r : 0, !horiz ? r : 0, comp);
-        GLSLF(2, s2[3] = texture(input_img[%i], ivec2(x + %i + xoffs[3], y + %i + yoffs[3]))[%i];
-              ,plane, horiz ? r : 0, !horiz ? r : 0, comp);
-    } else {
-        for (int i = 0; i < 16; i++) {
-            GLSLF(2, s2[%i][%i] = texture(input_img[%i], ivec2(x + %i + xoffs[%i], y + %i + yoffs[%i]))[%i];
-                  ,i / 4, i % 4, plane, horiz ? r : 0, i, !horiz ? r : 0, i, comp);
-        }
-    }
-
-    GLSLC(2, s2 = (s1 - s2) * (s1 - s2);                                       );
+    GLSLF(4, s1    = texture(input_img[%i], ivec2(x + %i + %s, y + %i + %s))[%i];
+          ,plane, horiz ? r : 0, horiz ? off : "0", !horiz ? r : 0, !horiz ? off : "0", comp);
+
+    GLSLF(4, s2[0] = texture(input_img[%i], ivec2(x + %i + %s + xoffs[0], y + %i + %s + yoffs[0]))[%i];
+          ,plane, horiz ? r : 0, horiz ? off : "0", !horiz ? r : 0, !horiz ? off : "0", comp);
+    GLSLF(4, s2[1] = texture(input_img[%i], ivec2(x + %i + %s + xoffs[1], y + %i + %s + yoffs[1]))[%i];
+          ,plane, horiz ? r : 0, horiz ? off : "0", !horiz ? r : 0, !horiz ? off : "0", comp);
+    GLSLF(4, s2[2] = texture(input_img[%i], ivec2(x + %i + %s + xoffs[2], y + %i + %s + yoffs[2]))[%i];
+          ,plane, horiz ? r : 0, horiz ? off : "0", !horiz ? r : 0, !horiz ? off : "0", comp);
+    GLSLF(4, s2[3] = texture(input_img[%i], ivec2(x + %i + %s + xoffs[3], y + %i + %s + yoffs[3]))[%i];
+          ,plane, horiz ? r : 0, horiz ? off : "0", !horiz ? r : 0, !horiz ? off : "0", comp);
+
+    GLSLC(4, s2 = (s1 - s2) * (s1 - s2);                                                    );
 }
 
 static void insert_horizontal_pass(FFVkSPIRVShader *shd, int nb_rows, int first, int plane, int comp)
 {
-    GLSLF(1, x = int(gl_GlobalInvocationID.x) * %i;                   ,nb_rows);
-    if (!first) {
-        GLSLC(1, controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup,
-                                gl_StorageSemanticsBuffer,
-                                gl_SemanticsAcquireRelease |
-                                gl_SemanticsMakeAvailable |
-                                gl_SemanticsMakeVisible);                     );
-    }
-    GLSLF(1, for (y = 0; y < height[%i]; y++) {                               ,plane);
-    GLSLC(2,     offset = uint64_t(int_stride)*y*T_ALIGN;                     );
-    GLSLC(2,     dst = DataBuffer(uint64_t(integral_data) + offset);          );
-    GLSLC(0,                                                                  );
-    if (first) {
-        for (int r = 0; r < nb_rows; r++) {
-            insert_first(shd, r, 1, plane, comp);
-            GLSLF(2, dst.v[x + %i] = s2;                                    ,r);
-            GLSLC(0,                                                          );
-        }
-    }
-    GLSLC(2,     barrier();                                                   );
-    GLSLC(2,     prefix_sum(dst, 1, dst, 1);                                  );
-    GLSLC(1, }                                                                );
-    GLSLC(0,                                                                  );
+    GLSLF(1, y = int(gl_GlobalInvocationID.x) * %i;                               ,nb_rows);
+    if (!first)
+        GLSLC(1, barrier();                                                       );
+    GLSLC(0,                                                                      );
+    GLSLF(1, if (y < height[%i]) {                                                ,plane);
+    GLSLC(2,     #pragma unroll(1)                                                );
+    GLSLF(2,     for (r = 0; r < %i; r++) {                                       ,nb_rows);
+    GLSLC(3,         prefix_sum = DTYPE(0);                                       );
+    GLSLC(3,         offset = uint64_t(int_stride)*(y + r)*T_ALIGN;               );
+    GLSLC(3,         dst = DataBuffer(uint64_t(integral_data) + offset);          );
+    GLSLC(0,                                                                      );
+    GLSLF(3,         for (x = 0; x < width[%i]; x++) {                            ,plane);
+    if (first)
+        insert_first(shd, 0, "r", 0, plane, comp);
+    else
+        GLSLC(4,         s2 = dst.v[x];                                           );
+    GLSLC(4,             dst.v[x] = s2 + prefix_sum;                              );
+    GLSLC(4,             prefix_sum += s2;                                        );
+    GLSLC(3,         }                                                            );
+    GLSLC(2,     }                                                                );
+    GLSLC(1, }                                                                    );
+    GLSLC(0,                                                                      );
 }
 
 static void insert_vertical_pass(FFVkSPIRVShader *shd, int nb_rows, int first, int plane, int comp)
 {
-    GLSLF(1, y = int(gl_GlobalInvocationID.x) * %i;                   ,nb_rows);
-    if (!first) {
-        GLSLC(1, controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup,
-                                gl_StorageSemanticsBuffer,
-                                gl_SemanticsAcquireRelease |
-                                gl_SemanticsMakeAvailable |
-                                gl_SemanticsMakeVisible);                     );
-    }
-    GLSLF(1, for (x = 0; x < width[%i]; x++) {                                ,plane);
-    GLSLC(2,     dst = DataBuffer(uint64_t(integral_data) + x*T_ALIGN);       );
-
-    for (int r = 0; r < nb_rows; r++) {
-        if (first) {
-            insert_first(shd, r, 0, plane, comp);
-            GLSLF(2, integral_data.v[(y + %i)*int_stride + x] = s2;         ,r);
-            GLSLC(0,                                                          );
-        }
-    }
-
-    GLSLC(2,     barrier();                                                   );
-    GLSLC(2,     prefix_sum(dst, int_stride, dst, int_stride);                );
-    GLSLC(1, }                                                                );
-    GLSLC(0,                                                                  );
+    GLSLF(1, x = int(gl_GlobalInvocationID.x) * %i;                               ,nb_rows);
+    GLSLC(1, #pragma unroll(1)                                                    );
+    GLSLF(1, for (r = 0; r < %i; r++)                                             ,nb_rows);
+    GLSLC(2,     psum[r] = DTYPE(0);                                              );
+    GLSLC(0,                                                                      );
+    if (!first)
+        GLSLC(1, barrier();                                                       );
+    GLSLC(0,                                                                      );
+    GLSLF(1, if (x < width[%i]) {                                                 ,plane);
+    GLSLF(2,     for (y = 0; y < height[%i]; y++) {                               ,plane);
+    GLSLC(3,         offset = uint64_t(int_stride)*y*T_ALIGN;                     );
+    GLSLC(3,         dst = DataBuffer(uint64_t(integral_data) + offset);          );
+    GLSLC(0,                                                                      );
+    GLSLC(3,         #pragma unroll(1)                                            );
+    GLSLF(3,         for (r = 0; r < %i; r++) {                                   ,nb_rows);
+    if (first)
+        insert_first(shd, 0, "r", 1, plane, comp);
+    else
+        GLSLC(4,         s2 = dst.v[x + r];                                       );
+    GLSLC(4,             dst.v[x + r] = s2 + psum[r];                             );
+    GLSLC(4,             psum[r] += s2;                                           );
+    GLSLC(3,         }                                                            );
+    GLSLC(2,     }                                                                );
+    GLSLC(1, }                                                                    );
+    GLSLC(0,                                                                      );
 }
 
 static void insert_weights_pass(FFVkSPIRVShader *shd, int nb_rows, int vert,
                                 int t, int dst_comp, int plane, int comp)
 {
-    GLSLF(1, p = patch_size[%i];                                     ,dst_comp);
+    GLSLF(1, p = patch_size[%i];                                              ,dst_comp);
     GLSLC(0,                                                                  );
-    GLSLC(1, controlBarrier(gl_ScopeWorkgroup, gl_ScopeWorkgroup,
-                            gl_StorageSemanticsBuffer,
-                            gl_SemanticsAcquireRelease |
-                            gl_SemanticsMakeAvailable |
-                            gl_SemanticsMakeVisible);                         );
     GLSLC(1, barrier();                                                       );
+    GLSLC(0,                                                                  );
     if (!vert) {
         GLSLF(1, for (y = 0; y < height[%i]; y++) {                           ,plane);
         GLSLF(2,     if (gl_GlobalInvocationID.x*%i >= width[%i])             ,nb_rows, plane);
         GLSLC(3,         break;                                               );
-        GLSLF(2,     for (r = 0; r < %i; r++) {                       ,nb_rows);
-        GLSLF(3,         x = int(gl_GlobalInvocationID.x) * %i + r;   ,nb_rows);
+        GLSLF(2,     for (r = 0; r < %i; r++) {                               ,nb_rows);
+        GLSLF(3,         x = int(gl_GlobalInvocationID.x) * %i + r;           ,nb_rows);
     } else {
         GLSLF(1, for (x = 0; x < width[%i]; x++) {                            ,plane);
         GLSLF(2,     if (gl_GlobalInvocationID.x*%i >= height[%i])            ,nb_rows, plane);
         GLSLC(3,         break;                                               );
-        GLSLF(2,     for (r = 0; r < %i; r++) {                       ,nb_rows);
-        GLSLF(3,         y = int(gl_GlobalInvocationID.x) * %i + r;   ,nb_rows);
+        GLSLF(2,     for (r = 0; r < %i; r++) {                               ,nb_rows);
+        GLSLF(3,         y = int(gl_GlobalInvocationID.x) * %i + r;           ,nb_rows);
     }
     GLSLC(0,                                                                  );
     GLSLC(3,         a = DTYPE(0);                                            );
@@ -223,16 +214,15 @@  static void insert_weights_pass(FFVkSPIRVShader *shd, int nb_rows, int vert,
 }
 
 typedef struct HorizontalPushData {
-    VkDeviceAddress integral_data;
-    VkDeviceAddress state_data;
-    int32_t  xoffs[TYPE_ELEMS];
-    int32_t  yoffs[TYPE_ELEMS];
     uint32_t width[4];
     uint32_t height[4];
     uint32_t ws_stride[4];
     int32_t  patch_size[4];
     float    strength[4];
+    VkDeviceAddress integral_base;
+    uint32_t integral_size;
     uint32_t int_stride;
+    uint32_t xyoffs_start;
 } HorizontalPushData;
 
 static av_cold int init_weights_pipeline(FFVulkanContext *vkctx, FFVkExecPool *exec,
@@ -249,26 +239,18 @@  static av_cold int init_weights_pipeline(FFVulkanContext *vkctx, FFVkExecPool *e
     FFVulkanDescriptorSetBinding *desc_set;
     int max_dim = FFMAX(width, height);
     uint32_t max_wg = vkctx->props.properties.limits.maxComputeWorkGroupSize[0];
-    int max_shm = vkctx->props.properties.limits.maxComputeSharedMemorySize;
     int wg_size, wg_rows;
 
     /* Round the max workgroup size to the previous power of two */
-    max_wg = 1 << (31 - ff_clz(max_wg));
     wg_size = max_wg;
     wg_rows = 1;
 
     if (max_wg > max_dim) {
-        wg_size = max_wg / (max_wg / max_dim);
+        wg_size = max_dim;
     } else if (max_wg < max_dim) {
-        /* First, make it fit */
+        /* Make it fit */
         while (wg_size*wg_rows < max_dim)
             wg_rows++;
-
-        /* Second, make sure there's enough shared memory */
-        while ((wg_size * TYPE_SIZE + TYPE_SIZE + 2*4) > max_shm) {
-            wg_size >>= 1;
-            wg_rows++;
-        }
     }
 
     RET(ff_vk_shader_init(pl, shd, "nlmeans_weights", VK_SHADER_STAGE_COMPUTE_BIT, 0));
@@ -278,33 +260,24 @@  static av_cold int init_weights_pipeline(FFVulkanContext *vkctx, FFVkExecPool *e
     if (t > 1)
         GLSLC(0, #extension GL_EXT_shader_atomic_float : require              );
     GLSLC(0, #extension GL_ARB_gpu_shader_int64 : require                     );
-    GLSLC(0, #pragma use_vulkan_memory_model                                  );
-    GLSLC(0, #extension GL_KHR_memory_scope_semantics : enable                );
     GLSLC(0,                                                                  );
-    GLSLF(0, #define N_ROWS %i                                       ,*nb_rows);
-    GLSLC(0, #define WG_SIZE (gl_WorkGroupSize.x)                             );
-    GLSLF(0, #define LG_WG_SIZE %i                ,ff_log2(shd->local_size[0]));
-    GLSLC(0, #define PARTITION_SIZE (N_ROWS*WG_SIZE)                          );
-    GLSLF(0, #define DTYPE %s                                       ,TYPE_NAME);
-    GLSLF(0, #define T_ALIGN %i                                     ,TYPE_SIZE);
+    GLSLF(0, #define DTYPE %s                                                 ,TYPE_NAME);
+    GLSLF(0, #define T_ALIGN %i                                               ,TYPE_SIZE);
     GLSLC(0,                                                                  );
-    GLSLC(0, layout(buffer_reference, buffer_reference_align = T_ALIGN) coherent buffer DataBuffer {  );
+    GLSLC(0, layout(buffer_reference, buffer_reference_align = T_ALIGN) buffer DataBuffer {  );
     GLSLC(1,     DTYPE v[];                                                   );
     GLSLC(0, };                                                               );
     GLSLC(0,                                                                  );
-    GLSLC(0, layout(buffer_reference) buffer StateData;                       );
-    GLSLC(0,                                                                  );
     GLSLC(0, layout(push_constant, std430) uniform pushConstants {            );
-    GLSLC(1,     coherent DataBuffer integral_data;                           );
-    GLSLC(1,     StateData  state;                                            );
-    GLSLF(1,     uint xoffs[%i];                                   ,TYPE_ELEMS);
-    GLSLF(1,     uint yoffs[%i];                                   ,TYPE_ELEMS);
     GLSLC(1,     uvec4 width;                                                 );
     GLSLC(1,     uvec4 height;                                                );
     GLSLC(1,     uvec4 ws_stride;                                             );
     GLSLC(1,     ivec4 patch_size;                                            );
     GLSLC(1,     vec4 strength;                                               );
+    GLSLC(1,     DataBuffer integral_base;                                    );
+    GLSLC(1,     uint integral_size;                                          );
     GLSLC(1,     uint int_stride;                                             );
+    GLSLC(1,     uint xyoffs_start;                                           );
     GLSLC(0, };                                                               );
     GLSLC(0,                                                                  );
 
@@ -370,7 +343,17 @@  static av_cold int init_weights_pipeline(FFVulkanContext *vkctx, FFVkExecPool *e
     };
     RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc_set, 1 + 2*desc->nb_components, 0, 0));
 
-    GLSLD(   ff_source_prefix_sum_comp                                        );
+    desc_set = (FFVulkanDescriptorSetBinding []) {
+        {
+            .name        = "xyoffsets_buffer",
+            .type        = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+            .mem_quali   = "readonly",
+            .stages      = VK_SHADER_STAGE_COMPUTE_BIT,
+            .buf_content = "int xyoffsets[];",
+        },
+    };
+    RET(ff_vk_pipeline_descriptor_set_add(vkctx, pl, shd, desc_set, 1, 1, 0));
+
     GLSLC(0,                                                                  );
     GLSLC(0, void main()                                                      );
     GLSLC(0, {                                                                );
@@ -378,11 +361,24 @@  static av_cold int init_weights_pipeline(FFVulkanContext *vkctx, FFVkExecPool *e
     GLSLC(1,     DataBuffer dst;                                              );
     GLSLC(1,     float s1;                                                    );
     GLSLC(1,     DTYPE s2;                                                    );
+    GLSLC(1,     DTYPE prefix_sum;                                            );
+    GLSLF(1,     DTYPE psum[%i];                                              ,*nb_rows);
     GLSLC(1,     int r;                                                       );
     GLSLC(1,     int x;                                                       );
     GLSLC(1,     int y;                                                       );
     GLSLC(1,     int p;                                                       );
     GLSLC(0,                                                                  );
+    GLSLC(1,     DataBuffer integral_data;                                    );
+    GLSLF(1,     int xoffs[%i];                                               ,TYPE_ELEMS);
+    GLSLF(1,     int yoffs[%i];                                               ,TYPE_ELEMS);
+    GLSLC(0,                                                                  );
+    GLSLC(1,     int invoc_idx = int(gl_WorkGroupID.z);                       );
+    GLSLC(1,     integral_data = DataBuffer(uint64_t(integral_base) + invoc_idx*integral_size);        );
+    for (int i = 0; i < TYPE_ELEMS*2; i += 2) {
+        GLSLF(1, xoffs[%i] = xyoffsets[xyoffs_start + 2*%i*invoc_idx + %i + 0];       ,i/2,TYPE_ELEMS,i);
+        GLSLF(1, yoffs[%i] = xyoffsets[xyoffs_start + 2*%i*invoc_idx + %i + 1];       ,i/2,TYPE_ELEMS,i);
+    }
+    GLSLC(0,                                                                  );
     GLSLC(1,     DTYPE a;                                                     );
     GLSLC(1,     DTYPE b;                                                     );
     GLSLC(1,     DTYPE c;                                                     );
@@ -405,7 +401,7 @@  static av_cold int init_weights_pipeline(FFVulkanContext *vkctx, FFVkExecPool *e
 
     for (int i = 0; i < desc->nb_components; i++) {
         int off = desc->comp[i].offset / (FFALIGN(desc->comp[i].depth, 8)/8);
-        if (width > height) {
+        if (width >= height) {
             insert_horizontal_pass(shd, *nb_rows, 1, desc->comp[i].plane, off);
             insert_vertical_pass(shd, *nb_rows, 0, desc->comp[i].plane, off);
             insert_weights_pass(shd, *nb_rows, 0, t, i, desc->comp[i].plane, off);
@@ -584,6 +580,7 @@  static av_cold int init_filter(AVFilterContext *ctx)
     FFVulkanContext *vkctx = &s->vkctx;
     const int planes = av_pix_fmt_count_planes(s->vkctx.output_format);
     FFVkSPIRVCompiler *spv;
+    int *offsets_buf;
 
     const AVPixFmtDescriptor *desc;
     desc = av_pix_fmt_desc_get(vkctx->output_format);
@@ -634,6 +631,20 @@  static av_cold int init_filter(AVFilterContext *ctx)
         }
     }
 
+    RET(ff_vk_create_buf(&s->vkctx, &s->xyoffsets_buf, 2*s->nb_offsets*sizeof(int32_t), NULL, NULL,
+                         VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT |
+                         VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
+                         VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT |
+                         VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT));
+    RET(ff_vk_map_buffer(&s->vkctx, &s->xyoffsets_buf, (uint8_t **)&offsets_buf, 0));
+
+    for (int i = 0; i < 2*s->nb_offsets; i += 2) {
+        offsets_buf[i + 0] = s->xoffsets[i >> 1];
+        offsets_buf[i + 1] = s->yoffsets[i >> 1];
+    }
+
+    RET(ff_vk_unmap_buffer(&s->vkctx, &s->xyoffsets_buf, 1));
+
     s->opts.t = FFMIN(s->opts.t, (FFALIGN(s->nb_offsets, TYPE_ELEMS) / TYPE_ELEMS));
     if (!vkctx->atomic_float_feats.shaderBufferFloat32AtomicAdd) {
         av_log(ctx, AV_LOG_WARNING, "Device doesn't support atomic float adds, "
@@ -641,11 +652,6 @@  static av_cold int init_filter(AVFilterContext *ctx)
         s->opts.t = 1;
     }
 
-    if (!vkctx->feats_12.vulkanMemoryModel) {
-        av_log(ctx, AV_LOG_ERROR, "Device doesn't support the Vulkan memory model!");
-        return AVERROR(EINVAL);;
-    }
-
     spv = ff_vk_spirv_init();
     if (!spv) {
         av_log(ctx, AV_LOG_ERROR, "Unable to initialize SPIR-V compiler!\n");
@@ -663,6 +669,10 @@  static av_cold int init_filter(AVFilterContext *ctx)
     RET(init_denoise_pipeline(vkctx, &s->e, &s->pl_denoise, &s->shd_denoise, s->sampler,
                               spv, desc, planes));
 
+    RET(ff_vk_set_descriptor_buffer(&s->vkctx, &s->pl_weights, NULL, 1, 0, 0,
+                                    s->xyoffsets_buf.address, s->xyoffsets_buf.size,
+                                    VK_FORMAT_UNDEFINED));
+
     av_log(ctx, AV_LOG_VERBOSE, "Filter initialized, %i x/y offsets, %i dispatches, %i parallel\n",
            s->nb_offsets, (FFALIGN(s->nb_offsets, TYPE_ELEMS) / TYPE_ELEMS) + 1, s->opts.t);
 
@@ -736,18 +746,16 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
     int plane_widths[4];
     int plane_heights[4];
 
+    int offsets_dispatched = 0;
+
     /* Integral */
-    AVBufferRef *state_buf;
-    FFVkBuffer *state_vk;
-    AVBufferRef *integral_buf;
+    AVBufferRef *integral_buf = NULL;
     FFVkBuffer *integral_vk;
     uint32_t int_stride;
     size_t int_size;
-    size_t state_size;
-    int t_offset = 0;
 
     /* Weights/sums */
-    AVBufferRef *ws_buf;
+    AVBufferRef *ws_buf = NULL;
     FFVkBuffer *ws_vk;
     VkDeviceAddress weights_addr[4];
     VkDeviceAddress sums_addr[4];
@@ -773,7 +781,6 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
     /* Integral image */
     int_stride = s->pl_weights.wg_size[0]*s->pl_weights_rows;
     int_size = int_stride * int_stride * TYPE_SIZE;
-    state_size = int_stride * 3 *TYPE_SIZE;
 
     /* Plane dimensions */
     for (int i = 0; i < desc->nb_components; i++) {
@@ -798,16 +805,6 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
         return err;
     integral_vk = (FFVkBuffer *)integral_buf->data;
 
-    err = ff_vk_get_pooled_buffer(&s->vkctx, &s->state_buf_pool, &state_buf,
-                                  VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
-                                  VK_BUFFER_USAGE_SHADER_DEVICE_ADDRESS_BIT,
-                                  NULL,
-                                  s->opts.t * state_size,
-                                  VK_MEMORY_PROPERTY_DEVICE_LOCAL_BIT);
-    if (err < 0)
-        return err;
-    state_vk = (FFVkBuffer *)state_buf->data;
-
     err = ff_vk_get_pooled_buffer(&s->vkctx, &s->ws_buf_pool, &ws_buf,
                                   VK_BUFFER_USAGE_STORAGE_BUFFER_BIT |
                                   VK_BUFFER_USAGE_TRANSFER_DST_BIT |
@@ -844,9 +841,12 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
     RET(ff_vk_exec_add_dep_frame(vkctx, exec, out,
                                  VK_PIPELINE_STAGE_2_ALL_COMMANDS_BIT,
                                  VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT));
+
     RET(ff_vk_exec_add_dep_buf(vkctx, exec, &integral_buf, 1, 0));
-    RET(ff_vk_exec_add_dep_buf(vkctx, exec, &state_buf,    1, 0));
+    integral_buf = NULL;
+
     RET(ff_vk_exec_add_dep_buf(vkctx, exec, &ws_buf,       1, 0));
+    ws_buf = NULL;
 
     /* Input frame prep */
     RET(ff_vk_create_imageviews(vkctx, exec, in_views, in));
@@ -869,6 +869,7 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
                         VK_IMAGE_LAYOUT_GENERAL,
                         VK_QUEUE_FAMILY_IGNORED);
 
+    nb_buf_bar = 0;
     buf_bar[nb_buf_bar++] = (VkBufferMemoryBarrier2) {
         .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2,
         .srcStageMask = ws_vk->stage,
@@ -881,6 +882,19 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
         .size = ws_vk->size,
         .offset = 0,
     };
+    buf_bar[nb_buf_bar++] = (VkBufferMemoryBarrier2) {
+        .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2,
+        .srcStageMask = integral_vk->stage,
+        .dstStageMask = VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT,
+        .srcAccessMask = integral_vk->access,
+        .dstAccessMask = VK_ACCESS_2_SHADER_STORAGE_READ_BIT |
+                         VK_ACCESS_2_SHADER_STORAGE_WRITE_BIT,
+        .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
+        .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
+        .buffer = integral_vk->buf,
+        .size = integral_vk->size,
+        .offset = 0,
+    };
 
     vk->CmdPipelineBarrier2(exec->buf, &(VkDependencyInfo) {
             .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO,
@@ -891,10 +905,13 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
         });
     ws_vk->stage = buf_bar[0].dstStageMask;
     ws_vk->access = buf_bar[0].dstAccessMask;
+    integral_vk->stage = buf_bar[1].dstStageMask;
+    integral_vk->access = buf_bar[1].dstAccessMask;
 
-    /* Weights/sums buffer zeroing */
+    /* Buffer zeroing */
     vk->CmdFillBuffer(exec->buf, ws_vk->buf, 0, ws_vk->size, 0x0);
 
+    nb_buf_bar = 0;
     buf_bar[nb_buf_bar++] = (VkBufferMemoryBarrier2) {
         .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2,
         .srcStageMask = ws_vk->stage,
@@ -948,29 +965,22 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
     /* Weights pipeline */
     ff_vk_exec_bind_pipeline(vkctx, exec, &s->pl_weights);
 
-    for (int i = 0; i < s->nb_offsets; i += TYPE_ELEMS) {
-        int *xoffs = s->xoffsets + i;
-        int *yoffs = s->yoffsets + i;
+    do {
+        int wg_invoc;
         HorizontalPushData pd = {
-            integral_vk->address + t_offset*int_size,
-            state_vk->address + t_offset*state_size,
-            { 0 },
-            { 0 },
             { plane_widths[0], plane_widths[1], plane_widths[2], plane_widths[3] },
             { plane_heights[0], plane_heights[1], plane_heights[2], plane_heights[3] },
             { ws_stride[0], ws_stride[1], ws_stride[2], ws_stride[3] },
             { s->patch[0], s->patch[1], s->patch[2], s->patch[3] },
             { s->strength[0], s->strength[1], s->strength[2], s->strength[2], },
+            integral_vk->address,
+            int_size,
             int_stride,
+            offsets_dispatched * 2,
         };
 
-        memcpy(pd.xoffs, xoffs, sizeof(pd.xoffs));
-        memcpy(pd.yoffs, yoffs, sizeof(pd.yoffs));
-
-        /* Put a barrier once we run out of parallelism buffers */
-        if (!t_offset) {
+        if (offsets_dispatched) {
             nb_buf_bar = 0;
-            /* Buffer prep/sync */
             buf_bar[nb_buf_bar++] = (VkBufferMemoryBarrier2) {
                 .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2,
                 .srcStageMask = integral_vk->stage,
@@ -984,39 +994,27 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
                 .size = integral_vk->size,
                 .offset = 0,
             };
-            buf_bar[nb_buf_bar++] = (VkBufferMemoryBarrier2) {
-                .sType = VK_STRUCTURE_TYPE_BUFFER_MEMORY_BARRIER_2,
-                .srcStageMask = state_vk->stage,
-                .dstStageMask = VK_PIPELINE_STAGE_2_COMPUTE_SHADER_BIT,
-                .srcAccessMask = state_vk->access,
-                .dstAccessMask = VK_ACCESS_2_SHADER_STORAGE_READ_BIT |
-                                 VK_ACCESS_2_SHADER_STORAGE_WRITE_BIT,
-                .srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
-                .dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED,
-                .buffer = state_vk->buf,
-                .size = state_vk->size,
-                .offset = 0,
-            };
 
             vk->CmdPipelineBarrier2(exec->buf, &(VkDependencyInfo) {
                     .sType = VK_STRUCTURE_TYPE_DEPENDENCY_INFO,
                     .pBufferMemoryBarriers = buf_bar,
                     .bufferMemoryBarrierCount = nb_buf_bar,
                 });
-            integral_vk->stage = buf_bar[0].dstStageMask;
-            integral_vk->access = buf_bar[0].dstAccessMask;
-            state_vk->stage = buf_bar[1].dstStageMask;
-            state_vk->access = buf_bar[1].dstAccessMask;
+            integral_vk->stage = buf_bar[1].dstStageMask;
+            integral_vk->access = buf_bar[1].dstAccessMask;
         }
-        t_offset = (t_offset + 1) % s->opts.t;
 
         /* Push data */
         ff_vk_update_push_exec(vkctx, exec, &s->pl_weights, VK_SHADER_STAGE_COMPUTE_BIT,
                                0, sizeof(pd), &pd);
 
+        wg_invoc = FFMIN((s->nb_offsets - offsets_dispatched)/TYPE_ELEMS, s->opts.t);
+
         /* End of horizontal pass */
-        vk->CmdDispatch(exec->buf, 1, 1, 1);
-    }
+        vk->CmdDispatch(exec->buf, 1, 1, wg_invoc);
+
+        offsets_dispatched += wg_invoc * TYPE_ELEMS;
+    } while (offsets_dispatched < s->nb_offsets);
 
     RET(denoise_pass(s, exec, ws_vk, ws_stride));
 
@@ -1033,6 +1031,8 @@  static int nlmeans_vulkan_filter_frame(AVFilterLink *link, AVFrame *in)
     return ff_filter_frame(outlink, out);
 
 fail:
+    av_buffer_unref(&integral_buf);
+    av_buffer_unref(&ws_buf);
     av_frame_free(&in);
     av_frame_free(&out);
     return err;
@@ -1051,7 +1051,6 @@  static void nlmeans_vulkan_uninit(AVFilterContext *avctx)
     ff_vk_shader_free(vkctx, &s->shd_denoise);
 
     av_buffer_pool_uninit(&s->integral_buf_pool);
-    av_buffer_pool_uninit(&s->state_buf_pool);
     av_buffer_pool_uninit(&s->ws_buf_pool);
 
     if (s->sampler)
diff --git a/libavfilter/vulkan/prefix_sum.comp b/libavfilter/vulkan/prefix_sum.comp
deleted file mode 100644
index 9147cd82fb..0000000000
--- a/libavfilter/vulkan/prefix_sum.comp
+++ /dev/null
@@ -1,151 +0,0 @@ 
-#extension GL_EXT_buffer_reference : require
-#extension GL_EXT_buffer_reference2 : require
-
-#define ACQUIRE gl_StorageSemanticsBuffer, gl_SemanticsAcquire
-#define RELEASE gl_StorageSemanticsBuffer, gl_SemanticsRelease
-
-// These correspond to X, A, P respectively in the prefix sum paper.
-#define FLAG_NOT_READY       0u
-#define FLAG_AGGREGATE_READY 1u
-#define FLAG_PREFIX_READY    2u
-
-layout(buffer_reference, buffer_reference_align = T_ALIGN) nonprivate buffer StateData {
-    DTYPE aggregate;
-    DTYPE prefix;
-    uint flag;
-};
-
-shared DTYPE sh_scratch[WG_SIZE];
-shared DTYPE sh_prefix;
-shared uint  sh_part_ix;
-shared uint  sh_flag;
-
-void prefix_sum(DataBuffer dst, uint dst_stride, DataBuffer src, uint src_stride)
-{
-    DTYPE local[N_ROWS];
-    // Determine partition to process by atomic counter (described in Section 4.4 of prefix sum paper).
-    if (gl_GlobalInvocationID.x == 0)
-          sh_part_ix = gl_WorkGroupID.x;
-//        sh_part_ix = atomicAdd(part_counter, 1);
-
-    barrier();
-    uint part_ix = sh_part_ix;
-
-    uint ix = part_ix * PARTITION_SIZE + gl_LocalInvocationID.x * N_ROWS;
-
-    // TODO: gate buffer read? (evaluate whether shader check or CPU-side padding is better)
-    local[0] = src.v[ix*src_stride];
-    for (uint i = 1; i < N_ROWS; i++)
-        local[i] = local[i - 1] + src.v[(ix + i)*src_stride];
-
-    DTYPE agg = local[N_ROWS - 1];
-    sh_scratch[gl_LocalInvocationID.x] = agg;
-    for (uint i = 0; i < LG_WG_SIZE; i++) {
-        barrier();
-        if (gl_LocalInvocationID.x >= (1u << i))
-            agg += sh_scratch[gl_LocalInvocationID.x - (1u << i)];
-        barrier();
-
-        sh_scratch[gl_LocalInvocationID.x] = agg;
-    }
-
-    // Publish aggregate for this partition
-    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-        state[part_ix].aggregate = agg;
-        if (part_ix == 0)
-            state[0].prefix = agg;
-    }
-
-    // Write flag with release semantics
-    if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-        uint flag = part_ix == 0 ? FLAG_PREFIX_READY : FLAG_AGGREGATE_READY;
-        atomicStore(state[part_ix].flag, flag, gl_ScopeDevice, RELEASE);
-    }
-
-    DTYPE exclusive = DTYPE(0);
-    if (part_ix != 0) {
-        // step 4 of paper: decoupled lookback
-        uint look_back_ix = part_ix - 1;
-
-        DTYPE their_agg;
-        uint their_ix = 0;
-        while (true) {
-            // Read flag with acquire semantics.
-            if (gl_LocalInvocationID.x == WG_SIZE - 1)
-                sh_flag = atomicLoad(state[look_back_ix].flag, gl_ScopeDevice, ACQUIRE);
-
-            // The flag load is done only in the last thread. However, because the
-            // translation of memoryBarrierBuffer to Metal requires uniform control
-            // flow, we broadcast it to all threads.
-            barrier();
-
-            uint flag = sh_flag;
-            barrier();
-
-            if (flag == FLAG_PREFIX_READY) {
-                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                    DTYPE their_prefix = state[look_back_ix].prefix;
-                    exclusive = their_prefix + exclusive;
-                }
-                break;
-            } else if (flag == FLAG_AGGREGATE_READY) {
-                if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                    their_agg = state[look_back_ix].aggregate;
-                    exclusive = their_agg + exclusive;
-                }
-                look_back_ix--;
-                their_ix = 0;
-                continue;
-            } // else spins
-
-            if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-                // Unfortunately there's no guarantee of forward progress of other
-                // workgroups, so compute a bit of the aggregate before trying again.
-                // In the worst case, spinning stops when the aggregate is complete.
-                DTYPE m = src.v[(look_back_ix * PARTITION_SIZE + their_ix)*src_stride];
-                if (their_ix == 0)
-                    their_agg = m;
-                else
-                    their_agg += m;
-
-                their_ix++;
-                if (their_ix == PARTITION_SIZE) {
-                    exclusive = their_agg + exclusive;
-                    if (look_back_ix == 0) {
-                        sh_flag = FLAG_PREFIX_READY;
-                    } else {
-                        look_back_ix--;
-                        their_ix = 0;
-                    }
-                }
-            }
-            barrier();
-            flag = sh_flag;
-            barrier();
-            if (flag == FLAG_PREFIX_READY)
-                break;
-        }
-
-        // step 5 of paper: compute inclusive prefix
-        if (gl_LocalInvocationID.x == WG_SIZE - 1) {
-            DTYPE inclusive_prefix = exclusive + agg;
-            sh_prefix = exclusive;
-            state[part_ix].prefix = inclusive_prefix;
-        }
-
-        if (gl_LocalInvocationID.x == WG_SIZE - 1)
-            atomicStore(state[part_ix].flag, FLAG_PREFIX_READY, gl_ScopeDevice, RELEASE);
-    }
-
-    barrier();
-    if (part_ix != 0)
-        exclusive = sh_prefix;
-
-    DTYPE row = exclusive;
-    if (gl_LocalInvocationID.x > 0)
-        row += sh_scratch[gl_LocalInvocationID.x - 1];
-
-    // note - may overwrite
-    for (uint i = 0; i < N_ROWS; i++)
-        dst.v[(ix + i)*dst_stride] = row + local[i];
-}
-- 
2.42.0