From ed2ecf9ef87d9944e4c461e2320f15af4adb0b95 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Thu, 19 Jun 2025 12:37:06 +0200 Subject: [PATCH] radv/nir/lower_cmat: share cmat_load/cmat_store code Reviewed-by: Rhys Perry Part-of: --- .../nir/radv_nir_lower_cooperative_matrix.c | 112 +++++------------- 1 file changed, 30 insertions(+), 82 deletions(-) diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index 5830194e187..de965b87ce8 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -418,89 +418,18 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev progress = true; break; } - case nir_intrinsic_cmat_load: { - nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr); - struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type); - enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr); - - nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr); - nir_def *stride = intr->src[2].ssa; - - nir_def *local_idx = nir_load_subgroup_invocation(&b); - nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15); - - /* A input is transposed */ - if (desc.use == GLSL_CMAT_USE_A) - layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR - : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR; - - unsigned length = radv_nir_cmat_length(desc, &params); - unsigned mul = radv_nir_cmat_length_mul(desc, &params); - unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? 
params.wave_size : 16; - nir_def *vars[16]; - if (mul > 1) { - for (unsigned i = 0; i < length; ++i) - if (i % mul != 0) - vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc)); - } - - unsigned idx_bits = deref->def.bit_size; - nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx); - - for (unsigned i = 0; i < length / mul; ++i) { - nir_def *col_offset = inner_idx; - nir_def *row_offset; - uint32_t row_iter; - - if (gfx_level >= GFX12) { - row_iter = i; - } else { - row_iter = i * lanes_per_iter / 16; - } - - row_offset = nir_iadd_imm(&b, base_row, row_iter); - - if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) { - nir_def *tmp = col_offset; - col_offset = row_offset; - row_offset = tmp; - } - - col_offset = nir_imul(&b, col_offset, stride); - - col_offset = nir_u2uN(&b, col_offset, idx_bits); - row_offset = nir_u2uN(&b, row_offset, idx_bits); - - nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset); - iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes, - glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8); - iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset); - - vars[i * mul] = nir_load_deref(&b, iter_deref); - } - - nir_def *mat = nir_vec(&b, vars, length); - nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components)); - nir_instr_remove(instr); - progress = true; - break; - } + case nir_intrinsic_cmat_load: case nir_intrinsic_cmat_store: { + const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load; + + nir_deref_instr *cmat_deref = nir_instr_as_deref(intr->src[!is_load].ssa->parent_instr); + struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type); enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr); - nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr); - nir_def *src = intr->src[1].ssa; + nir_deref_instr *deref = nir_instr_as_deref(intr->src[is_load].ssa->parent_instr); nir_def *stride = 
intr->src[2].ssa; - nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr); - struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type); - src = radv_nir_load_cmat(&b, &params, src); - nir_def *local_idx = nir_load_subgroup_invocation(&b); - - if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) - nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16)); - nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15); /* A input is transposed */ @@ -512,8 +441,20 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev unsigned mul = radv_nir_cmat_length_mul(desc, &params); unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16; nir_def *vars[16]; - for (unsigned i = 0; i < length; ++i) - vars[i] = nir_channel(&b, src, i); + if (is_load) { + if (mul > 1) { + for (unsigned i = 0; i < length; ++i) + if (i % mul != 0) + vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc)); + } + } else { + if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) + nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16)); + + nir_def *src = radv_nir_load_cmat(&b, &params, &cmat_deref->def); + for (unsigned i = 0; i < length; ++i) + vars[i] = nir_channel(&b, src, i); + } unsigned idx_bits = deref->def.bit_size; nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx); @@ -547,12 +488,19 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8); iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset); - nir_store_deref(&b, iter_deref, vars[i * mul], 1); + if (is_load) { + vars[i * mul] = nir_load_deref(&b, iter_deref); + } else { + nir_store_deref(&b, iter_deref, vars[i * mul], 1); + } } - if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) + if (is_load) { + nir_def *mat = nir_vec(&b, vars, length); + nir_store_deref(&b, cmat_deref, mat, nir_component_mask(mat->num_components)); + } else if 
(gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) { nir_pop_if(&b, NULL); - + } nir_instr_remove(instr); progress = true; break;