From b05a112d92b5e83093e7c8fcf86e1ae3eb23c202 Mon Sep 17 00:00:00 2001
From: Samuel Pitoiset
Date: Wed, 22 Jan 2025 03:52:10 -0800
Subject: [PATCH] radv/nir: add cooperative matrix lowering for GFX12

Signed-off-by: Samuel Pitoiset
Part-of:
---
 src/amd/vulkan/nir/radv_nir.h                 |   2 +-
 .../nir/radv_nir_lower_cooperative_matrix.c   | 150 +++++++++++++++---
 src/amd/vulkan/radv_shader.c                  |   2 +-
 3 files changed, 127 insertions(+), 27 deletions(-)

diff --git a/src/amd/vulkan/nir/radv_nir.h b/src/amd/vulkan/nir/radv_nir.h
index 3e1175b5bba..246d9f269ff 100644
--- a/src/amd/vulkan/nir/radv_nir.h
+++ b/src/amd/vulkan/nir/radv_nir.h
@@ -69,7 +69,7 @@ void radv_nir_lower_io(struct radv_device *device, nir_shader *nir);

 bool radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *stage);

-bool radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size);
+bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);

 bool radv_nir_lower_draw_id_to_zero(nir_shader *shader);

diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index 884ceaa97ca..b72d0697335 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -7,26 +7,85 @@
 #include "nir_builder.h"
 #include "radv_nir.h"

+/* This pass lowers cooperative matrices.
+ *
+ * On GFX11, the A&B matrices need to be replicated: lanes 0..15 are replicated
+ * to lanes 16..31, and for wave64 also to lanes 32..47 and 48..63. A&B matrices
+ * are always vectors of 16 elements.
+ *
+ * On GFX12, there is no data replication and the matrix layouts are described
+ * below:
+ *
+ * Wave32:
+ * A&B:
+ *         0..15  | 16..31 (lanes)
+ * v0 lo:  row 0  | row 4
+ * v0 hi:  row 1  | row 5
+ * v1 lo:  row 2  | row 6
+ * v1 hi:  row 3  | row 7
+ * v2 lo:  row 8  | row 12
+ * v2 hi:  row 9  | row 13
+ * v3 lo:  row 10 | row 14
+ * v3 hi:  row 11 | row 15
+ *
+ * C:
+ *         0..15  | 16..31 (lanes)
+ * v0 lo:  row 0  | row 8
+ * v0 hi:  row 1  | row 9
+ * v1 lo:  row 2  | row 10
+ * v1 hi:  row 3  | row 11
+ * v2 lo:  row 4  | row 12
+ * v2 hi:  row 5  | row 13
+ * v3 lo:  row 6  | row 14
+ * v3 hi:  row 7  | row 15
+ *
+ * Wave64:
+ * A&B:
+ *         0..15  | 16..31 | 32..47 | 48..63 (lanes)
+ * v0 lo:  row 0  | row 4  | row 8  | row 12
+ * v0 hi:  row 1  | row 5  | row 9  | row 13
+ * v1 lo:  row 2  | row 6  | row 10 | row 14
+ * v1 hi:  row 3  | row 7  | row 11 | row 15
+ *
+ * C:
+ *         0..15  | 16..31 | 32..47 | 48..63 (lanes)
+ * v0 lo:  row 0  | row 8  | row 4  | row 12
+ * v0 hi:  row 1  | row 9  | row 5  | row 13
+ * v1 lo:  row 2  | row 10 | row 6  | row 14
+ * v1 hi:  row 3  | row 11 | row 7  | row 15
+ */
+
 typedef struct {
+   enum amd_gfx_level gfx_level;
    unsigned wave_size;
 } lower_cmat_params;

 static unsigned
 radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
 {
-   return desc.use != GLSL_CMAT_USE_ACCUMULATOR
-             ? 16
-             : (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
+   if (params->gfx_level >= GFX12) {
+      assert(desc.cols == 16 && desc.rows == 16);
+      return 256 / params->wave_size;
+   } else {
+      return desc.use != GLSL_CMAT_USE_ACCUMULATOR
+                ? 16
+                : (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
+   }
 }

-/* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement
- * that with a f16vec16. We then use the coefficient generated by this function to figure out how many elements we
- * really have.
- */
 static unsigned
-radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
+radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params)
 {
-   return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
+   if (params->gfx_level >= GFX12) {
+      return 1;
+   } else {
+      /* For C matrices we have 1 VGPR per element even if the element type is
+       * < 32 bits. So with 8 fp16 elements we implement that with a f16vec16.
+       * We then use the coefficient generated by this function to figure out
+       * how many elements we really have.
+       */
+      return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
+   }
 }

 static unsigned
@@ -99,8 +158,32 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
    return orig_type;
 }

+static nir_def *
+radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower_cmat_params *params,
+                  nir_def *local_idx)
+{
+   nir_def *base_row;
+
+   if (params->gfx_level >= GFX12) {
+      base_row = nir_udiv_imm(b, local_idx, 16);
+
+      if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) {
+         /* Swap the rows of lanes 16..31 and 32..47; offset the right shift
+          * by -2 to get an implicit * 4.
+          */
+         base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2);
+      } else {
+         base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4);
+      }
+   } else {
+      base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0);
+   }
+
+   return base_row;
+}
+
 bool
-radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
+radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
 {
    bool progress = false;

@@ -108,6 +191,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
       return false;

    const lower_cmat_params params = {
+      .gfx_level = gfx_level,
       .wave_size = wave_size,
    };

@@ -143,7 +227,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
         switch (intr->intrinsic) {
         case nir_intrinsic_cmat_length: {
            struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
-           unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc);
+           unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc, &params);
            nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
            nir_instr_remove(instr);
            progress = true;
@@ -155,7 +239,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)

            nir_def *src0 = radv_nir_load_cmat(&b, &params, intr->src[0].ssa);
            nir_def *index = intr->src[1].ssa;
-           index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
+           index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));

            nir_def *elem = nir_vector_extract(&b, src0, index);

@@ -169,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
            nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
            struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
            nir_def *index = intr->src[3].ssa;
-           index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
+           index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));

            nir_def *elem = intr->src[1].ssa;
            nir_def *r = nir_vector_insert(&b, src1, elem, index);
@@ -207,7 +291,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
                                            : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

            unsigned length = radv_nir_cmat_length(desc, &params);
-           unsigned mul = radv_nir_cmat_length_mul(desc);
+           unsigned mul = radv_nir_cmat_length_mul(desc, &params);
            unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
            nir_def *vars[16];
            if (mul > 1) {
@@ -217,12 +301,20 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
            }

            unsigned idx_bits = deref->def.bit_size;
-           nir_def *base_row =
-              desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
+           nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);

            for (unsigned i = 0; i < length / mul; ++i) {
               nir_def *col_offset = inner_idx;
-              nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
+              nir_def *row_offset;
+              uint32_t row_iter;
+
+              if (gfx_level >= GFX12) {
+                 row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
+              } else {
+                 row_iter = i * lanes_per_iter / 16;
+              }
+
+              row_offset = nir_iadd_imm(&b, base_row, row_iter);

               if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                  nir_def *tmp = col_offset;
@@ -263,7 +355,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)

            nir_def *local_idx = nir_load_subgroup_invocation(&b);

-           if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+           if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
               nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));

            nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
@@ -274,19 +366,27 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
                                            : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;

            unsigned length = radv_nir_cmat_length(desc, &params);
-           unsigned mul = radv_nir_cmat_length_mul(desc);
+           unsigned mul = radv_nir_cmat_length_mul(desc, &params);
            unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
            nir_def *vars[16];
            for (unsigned i = 0; i < length; ++i)
               vars[i] = nir_channel(&b, src, i);

            unsigned idx_bits = deref->def.bit_size;
-           nir_def *base_row =
-              desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
+           nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);

            for (unsigned i = 0; i < length / mul; ++i) {
               nir_def *col_offset = inner_idx;
-              nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
+              nir_def *row_offset;
+              uint32_t row_iter;
+
+              if (gfx_level >= GFX12) {
+                 row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
+              } else {
+                 row_iter = i * lanes_per_iter / 16;
+              }
+
+              row_offset = nir_iadd_imm(&b, base_row, row_iter);

               if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
                  nir_def *tmp = col_offset;
@@ -308,7 +408,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
               nir_store_deref(&b, iter_deref, vars[i * mul], 1);
            }

-           if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+           if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
               nir_pop_if(&b, NULL);

            nir_instr_remove(instr);
@@ -338,7 +438,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
            nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
            nir_op op = nir_intrinsic_alu_op(intr);

-           if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
+           if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 16 &&
                glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
               nir_def *components[NIR_MAX_VEC_COMPONENTS];
               for (unsigned i = 0; i * 2 < src->num_components; ++i) {
@@ -349,7 +449,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)

            nir_def *ret = nir_build_alu1(&b, op, src);

-           if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
+           if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 32 &&
                glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
               nir_def *components[NIR_MAX_VEC_COMPONENTS];
               for (unsigned i = 0; i < ret->num_components; ++i) {
diff --git a/src/amd/vulkan/radv_shader.c b/src/amd/vulkan/radv_shader.c
index 922bdaa56db..4095d9256f3 100644
--- a/src/amd/vulkan/radv_shader.c
+++ b/src/amd/vulkan/radv_shader.c
@@ -419,7 +419,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
     */
    NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);

-   NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, subgroup_size);
+   NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, subgroup_size);

    /* Split member structs. We do this before lower_io_to_temporaries so that
     * it doesn't lower system values to temporaries by accident.
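
For reference, the GFX12 lane-to-row mapping implemented by radv_get_base_row()
and the row_iter computation above can be sanity-checked on the CPU with a small
standalone sketch. It is not part of the patch; the helper name and its
parameters are illustrative only.

/* Returns the matrix row held by `lane` for vector element `elem` on GFX12,
 * mirroring the layout tables in the pass comment.
 */
#include <assert.h>
#include <stdbool.h>
#include <stdio.h>

static unsigned
gfx12_cmat_row(bool is_accumulator, unsigned wave_size, unsigned lane, unsigned elem)
{
   assert(wave_size == 32 || wave_size == 64);

   /* Each group of 16 lanes holds a different set of rows. */
   unsigned group = lane / 16;
   unsigned base_row;

   if (is_accumulator && wave_size == 64) {
      /* Bit-reverse the 2-bit group index and multiply by 4: groups 0,1,2,3
       * map to base rows 0,8,4,12, i.e. the rows of lane groups 1 and 2 are
       * swapped (see the C wave64 table).
       */
      base_row = ((group & 1) << 3) | ((group & 2) << 1);
   } else {
      /* C matrices in wave32 advance by 8 rows per lane group, A&B by 4. */
      base_row = group * (is_accumulator && wave_size == 32 ? 8 : 4);
   }

   /* A&B in wave32 skip ahead by 4 rows after the first 4 elements, which is
    * the "i + (i & 4)" term above; every other case walks rows linearly.
    */
   unsigned row_iter = (!is_accumulator && wave_size == 32) ? elem + (elem & 4) : elem;

   return base_row + row_iter;
}

int
main(void)
{
   /* Wave32 A&B table check: lane 16, element 4 (v2 lo) maps to row 12. */
   printf("row = %u\n", gfx12_cmat_row(false, 32, 16, 4));
   return 0;
}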