diff --git a/src/amd/common/nir/ac_nir_lower_mem_access_bit_sizes.c b/src/amd/common/nir/ac_nir_lower_mem_access_bit_sizes.c
index 6d4c25afa8b..868e16cfaf4 100644
--- a/src/amd/common/nir/ac_nir_lower_mem_access_bit_sizes.c
+++ b/src/amd/common/nir/ac_nir_lower_mem_access_bit_sizes.c
@@ -81,6 +81,39 @@ lower_mem_access_cb(nir_intrinsic_op intrin, uint8_t bytes, uint8_t bit_size, ui
       return res;
    }
 
+   if (is_smem) {
+      /* Round up subdword loads if unsupported. */
+      const bool supported_subdword = cb_data->gfx_level >= GFX12 && intrin != nir_intrinsic_load_push_constant;
+      if (bit_size < 32 && (bytes >= 3 || !supported_subdword))
+         bytes = align(bytes, 4);
+
+      /* Generally, require an alignment of 4. */
+      res.align = MIN2(4, bytes);
+      bit_size = MAX2(bit_size, res.align * 8);
+
+      /* Maximum SMEM load size is 512 bits (16 dwords). */
+      bytes = MIN2(bytes, 64);
+
+      /* Lower unsupported sizes. */
+      if (!util_is_power_of_two_nonzero(bytes) && (cb_data->gfx_level < GFX12 || bytes != 12)) {
+         const uint8_t larger = util_next_power_of_two(bytes);
+         const uint8_t smaller = larger / 2;
+         const bool is_buffer_load = intrin == nir_intrinsic_load_ubo ||
+                                     intrin == nir_intrinsic_load_ssbo ||
+                                     intrin == nir_intrinsic_load_constant;
+         const bool is_aligned = align_mul % smaller == 0;
+
+         /* Overfetch up to 1 dword if this is a bounds-checked buffer load or the access is aligned. */
+         bool overfetch = bytes + 4 >= larger && (is_buffer_load || is_aligned);
+         bytes = overfetch ? larger : smaller;
+         res.align = is_aligned ? smaller : res.align;
+      }
+      res.num_components = DIV_ROUND_UP(bytes, bit_size / 8);
+      res.bit_size = bit_size;
+      res.shift = nir_mem_access_shift_method_shift64;
+      return res;
+   }
+
    /* Make 8-bit accesses 16-bit if possible */
    if (is_load && bit_size == 8 && combined_align >= 2 && bytes % 2 == 0)
       bit_size = 16;
@@ -94,8 +127,6 @@ lower_mem_access_cb(nir_intrinsic_op intrin, uint8_t bytes, uint8_t bit_size, ui
    if (cb_data->use_llvm && access & (ACCESS_COHERENT | ACCESS_VOLATILE) &&
        (intrin == nir_intrinsic_load_global || intrin == nir_intrinsic_store_global))
       max_components = 1;
-   else if (is_smem)
-      max_components = MIN2(512 / bit_size, 16);
 
    res.num_components = MIN2(DIV_ROUND_UP(bytes, bit_size / 8), max_components);
    res.bit_size = bit_size;
@@ -105,11 +136,8 @@ lower_mem_access_cb(nir_intrinsic_op intrin, uint8_t bytes, uint8_t bit_size, ui
    if (!is_load)
       return res;
 
-   /* Lower 8/16-bit loads to 32-bit, unless it's a VMEM (or SMEM on GFX12+) scalar load. */
-
-   const bool supports_scalar_subdword =
-      !is_smem || (cb_data->gfx_level >= GFX12 && intrin != nir_intrinsic_load_push_constant);
-   const bool supported_subdword = res.num_components == 1 && supports_scalar_subdword &&
+   /* Lower 8/16-bit loads to 32-bit, unless it's a scalar load. */
+   const bool supported_subdword = res.num_components == 1 &&
                                    (!cb_data->use_llvm || intrin != nir_intrinsic_load_ubo);
 
    if (res.bit_size >= 32 || supported_subdword)
@@ -122,7 +150,7 @@ lower_mem_access_cb(nir_intrinsic_op intrin, uint8_t bytes, uint8_t bit_size, ui
    if (align_mul < 4) {
       /* If we split the load, only lower it to 32-bit if this is a SMEM load. */
       const unsigned chunk_bytes = align(bytes, 4) - max_pad;
-      if (!is_smem && chunk_bytes < bytes)
+      if (chunk_bytes < bytes)
         return res;
    }
 
@@ -135,7 +163,7 @@ lower_mem_access_cb(nir_intrinsic_op intrin, uint8_t bytes, uint8_t bit_size, ui
    res.num_components = MIN2(res.num_components, max_components);
    res.bit_size = 32;
    res.align = 4;
-   res.shift = is_smem ? res.shift : nir_mem_access_shift_method_bytealign_amd;
+   res.shift = nir_mem_access_shift_method_bytealign_amd;
 
    return res;
 }
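For reference, here is a minimal standalone sketch of the non-power-of-two splitting decision introduced in the new is_smem path, assuming the same overfetch rule as the hunk above. The helper names split_non_pow2 and next_pow2 and the sample sizes are illustrative only and are not part of the patch.

/* Standalone sketch: mirrors the "overfetch vs. split" decision taken for
 * non-power-of-two SMEM load sizes in the new is_smem path. Not part of the
 * patch; names and sample values are illustrative.
 */
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>

static uint8_t
next_pow2(uint8_t x)
{
   uint8_t p = 1;
   while (p < x)
      p *= 2;
   return p;
}

/* For a non-power-of-two size, either overfetch to the next power of two
 * (safe for bounds-checked buffer loads or sufficiently aligned accesses),
 * or fall back to the next smaller power of two and leave the remainder to
 * a follow-up load emitted by the pass.
 */
static uint8_t
split_non_pow2(uint8_t bytes, uint32_t align_mul, bool is_buffer_load)
{
   const uint8_t larger = next_pow2(bytes);
   const uint8_t smaller = larger / 2;
   const bool is_aligned = align_mul % smaller == 0;
   const bool overfetch = bytes + 4 >= larger && (is_buffer_load || is_aligned);
   return overfetch ? larger : smaller;
}

int
main(void)
{
   /* 20-byte SSBO load: 20 + 4 < 32, so no overfetch -> 16 bytes now,
    * 4 bytes in a follow-up load.
    */
   printf("20B ssbo      -> %u bytes\n", split_non_pow2(20, 4, true));
   /* 28-byte SSBO load: 28 + 4 >= 32 and bounds-checked -> overfetch to 32. */
   printf("28B ssbo      -> %u bytes\n", split_non_pow2(28, 4, true));
   /* 28-byte non-buffer SMEM load (e.g. load_push_constant), only 4-byte
    * aligned: neither condition holds -> split to 16.
    */
   printf("28B non-buffer -> %u bytes\n", split_non_pow2(28, 4, false));
   return 0;
}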