diff --git a/src/gallium/drivers/zink/zink_compiler.c b/src/gallium/drivers/zink/zink_compiler.c
index e678e9edf2f..275782c150d 100644
--- a/src/gallium/drivers/zink/zink_compiler.c
+++ b/src/gallium/drivers/zink/zink_compiler.c
@@ -5313,6 +5313,29 @@ zink_type_size(const struct glsl_type *type, bool bindless)
    return glsl_count_attribute_slots(type, false);
 }
 
+static nir_mem_access_size_align /* callback installed as lower_mem_access_options.callback in zink_shader_create */
+mem_access_size_align_cb(nir_intrinsic_op intrin, uint8_t bytes,
+                         uint8_t bit_size, uint32_t align,
+                         uint32_t align_offset, bool offset_is_const,
+                         const void *cb_data)
+{
+   align = nir_combined_align(align, align_offset); /* presumably the effective alignment of the access; confirm against nir.h */
+
+   assert(util_is_power_of_two_nonzero(align)); /* NOTE(review): this assert is the only consumer of align; the result below ignores it */
+
+   return (nir_mem_access_size_align){
+      .num_components = MIN2(bytes / (bit_size / 8), 4), /* how many bit_size-wide components fit in bytes, capped at 4 */
+      .bit_size = bit_size, /* incoming bit size is passed through unchanged */
+      .align = bit_size / 8, /* one component's size in bytes, regardless of the incoming align */
+   };
+}
+
+static uint8_t /* callback handed to nir_lower_alu_width; intrin and cb_data are unused */
+lower_vec816_alu(const nir_instr *instr, const void *cb_data)
+{
+   return 4; /* same answer for every instruction; presumably the widest ALU vector to keep -- confirm against nir_lower_alu_width */
+}
+
 struct zink_shader *
 zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
 {
@@ -5351,6 +5374,19 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir)
       lower_io_flags |= nir_lower_io_lower_64bit_to_32;
    NIR_PASS_V(nir, nir_lower_io, nir_var_shader_in, zink_type_size, lower_io_flags);
    nir->info.io_lowered = true;
+
+   if (nir->info.stage == MESA_SHADER_KERNEL) { /* extra lowering applied to kernel-stage shaders only */
+      nir_lower_mem_access_bit_sizes_options lower_mem_access_options = {
+         .modes = nir_var_all,
+         .may_lower_unaligned_stores_to_atomics = true,
+         .callback = mem_access_size_align_cb, /* always answers with at most 4 components of the original bit size */
+         .cb_data = screen, /* NOTE(review): mem_access_size_align_cb never reads cb_data */
+      };
+      NIR_PASS_V(nir, nir_lower_mem_access_bit_sizes, &lower_mem_access_options);
+      NIR_PASS_V(nir, nir_lower_alu_width, lower_vec816_alu, NULL); /* lower_vec816_alu returns 4 for every instruction */
+      NIR_PASS_V(nir, nir_lower_alu_vec8_16_srcs);
+   }
+
    optimize_nir(nir, NULL);
    nir_foreach_variable_with_modes(var, nir, nir_var_shader_in | nir_var_shader_out) {
       if (glsl_type_is_image(var->type) || glsl_type_is_sampler(var->type)) {