radv/nir: add cooperative matrix lowering for GFX12

Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33378>
Samuel Pitoiset 2025-01-22 03:52:10 -08:00 committed by Marge Bot
parent ad611adeb7
commit b05a112d92
3 changed files with 127 additions and 27 deletions

@@ -69,7 +69,7 @@ void radv_nir_lower_io(struct radv_device *device, nir_shader *nir);
bool radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *stage);
-bool radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size);
+bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);
bool radv_nir_lower_draw_id_to_zero(nir_shader *shader);

@@ -7,26 +7,85 @@
#include "nir_builder.h"
#include "radv_nir.h"
+/* This pass lowers cooperative matrices.
+ *
+ * On GFX11, the A&B matrices need to be replicated: lanes 0..15 are replicated
+ * to 16..31, and for wave64 also into lanes 32..47 and 48..63. A&B matrices
+ * are always vectors of 16 elements.
+ *
+ * On GFX12, there is no data replication and the matrix layouts are described
+ * below:
+ *
+ * Wave32:
+ * A&B:
+ *        0..15  | 16..31 (lanes)
+ * v0 lo: row 0  | row 4
+ * v0 hi: row 1  | row 5
+ * v1 lo: row 2  | row 6
+ * v1 hi: row 3  | row 7
+ * v2 lo: row 8  | row 12
+ * v2 hi: row 9  | row 13
+ * v3 lo: row 10 | row 14
+ * v3 hi: row 11 | row 15
+ *
+ * C:
+ *        0..15 | 16..31 (lanes)
+ * v0 lo: row 0 | row 8
+ * v0 hi: row 1 | row 9
+ * v1 lo: row 2 | row 10
+ * v1 hi: row 3 | row 11
+ * v2 lo: row 4 | row 12
+ * v2 hi: row 5 | row 13
+ * v3 lo: row 6 | row 14
+ * v3 hi: row 7 | row 15
+ *
+ * Wave64:
+ * A&B:
+ *        0..15 | 16..31 | 32..47 | 48..63 (lanes)
+ * v0 lo: row 0 | row 4  | row 8  | row 12
+ * v0 hi: row 1 | row 5  | row 9  | row 13
+ * v1 lo: row 2 | row 6  | row 10 | row 14
+ * v1 hi: row 3 | row 7  | row 11 | row 15
+ *
+ * C:
+ *        0..15 | 16..31 | 32..47 | 48..63 (lanes)
+ * v0 lo: row 0 | row 8  | row 4  | row 12
+ * v0 hi: row 1 | row 9  | row 5  | row 13
+ * v1 lo: row 2 | row 10 | row 6  | row 14
+ * v1 hi: row 3 | row 11 | row 7  | row 15
+ */
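/* note: with the lowering below, element i of lane l ends up on row
 * base_row(l) + row_iter(i); e.g. for 16-bit A&B elements in wave32 this is
 * (l / 16) * 4 + i + (i & 4), which reproduces the first table above.
 */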
typedef struct {
+enum amd_gfx_level gfx_level;
unsigned wave_size;
} lower_cmat_params;
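/* note: a 16x16 matrix has 256 elements; on GFX12 they are spread evenly
 * across the wave, i.e. 8 per lane in wave32 and 4 per lane in wave64.
 */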
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
-return desc.use != GLSL_CMAT_USE_ACCUMULATOR
-? 16
-: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
+if (params->gfx_level >= GFX12) {
+assert(desc.cols == 16 && desc.rows == 16);
+return 256 / params->wave_size;
+} else {
+return desc.use != GLSL_CMAT_USE_ACCUMULATOR
+? 16
+: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
+}
}
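/* note: the multiplier below is 1 on GFX12 because 16-bit elements are packed
 * two per VGPR (the lo/hi halves in the tables above), so the vector length
 * already matches the real element count.
 */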
-/* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement
-* that with a f16vec16. We then use the coefficient generated by this function to figure out how many elements we
-* really have.
-*/
static unsigned
-radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
+radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
-return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
+if (params->gfx_level >= GFX12) {
+return 1;
+} else {
+/* For C matrices we have 1 VGPR per element even if the element type is
+ * < 32 bits. So with 8 fp16 elements we implement that with a f16vec16.
+ * We then use the coefficient generated by this function to figure out
+ * how many elements we really have.
+ */
+return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
+}
}
static unsigned
@@ -99,8 +158,32 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
return orig_type;
}
+static nir_def *
+radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower_cmat_params *params,
+nir_def *local_idx)
+{
+nir_def *base_row;
+if (params->gfx_level >= GFX12) {
+base_row = nir_udiv_imm(b, local_idx, 16);
+if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) {
+/* Switch rows from lanes 16..31 to 32..47, offset right shift by -2
+ * to get implicit * 4.
+ */
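/* note: nir_bitfield_reverse maps base_row 0,1,2,3 (lane groups 0..15,
 * 16..31, 32..47, 48..63) to 0,8,4,12 after the shift by 28, matching the
 * wave64 C layout above.
 */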
+base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2);
+} else {
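/* note: A&B start 4 rows later per 16-lane group; C in wave32 starts 8 rows
 * later, as in the tables above.
 */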
+base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4);
+}
+} else {
+base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0);
+}
+return base_row;
+}
bool
-radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
+radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
{
bool progress = false;
@@ -108,6 +191,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
return false;
const lower_cmat_params params = {
+.gfx_level = gfx_level,
.wave_size = wave_size,
};
@@ -143,7 +227,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
switch (intr->intrinsic) {
case nir_intrinsic_cmat_length: {
struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
-unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc);
+unsigned len = radv_nir_cmat_length(desc, &params) / radv_nir_cmat_length_mul(desc, &params);
nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
nir_instr_remove(instr);
progress = true;
@@ -155,7 +239,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *src0 = radv_nir_load_cmat(&b, &params, intr->src[0].ssa);
nir_def *index = intr->src[1].ssa;
-index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
+index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));
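/* note: the user-visible element index is scaled by the padding multiplier so
 * it addresses the right component of the padded vector (a no-op on GFX12,
 * where the multiplier is 1).
 */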
nir_def *elem = nir_vector_extract(&b, src0, index);
@@ -169,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
nir_def *index = intr->src[3].ssa;
-index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
+index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, &params));
nir_def *elem = intr->src[1].ssa;
nir_def *r = nir_vector_insert(&b, src1, elem, index);
@@ -207,7 +291,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
unsigned length = radv_nir_cmat_length(desc, &params);
-unsigned mul = radv_nir_cmat_length_mul(desc);
+unsigned mul = radv_nir_cmat_length_mul(desc, &params);
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
if (mul > 1) {
@@ -217,12 +301,20 @@
}
unsigned idx_bits = deref->def.bit_size;
-nir_def *base_row =
-desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
+nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);
for (unsigned i = 0; i < length / mul; ++i) {
nir_def *col_offset = inner_idx;
-nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
+nir_def *row_offset;
+uint32_t row_iter;
+if (gfx_level >= GFX12) {
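/* note: in wave32 each lane holds eight A&B elements covering rows 0..3 and
 * 8..11 (or 4..7 and 12..15 for lanes 16..31), so i + (i & 4) skips the four
 * rows owned by the other lane group.
 */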
+row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
+} else {
+row_iter = i * lanes_per_iter / 16;
+}
+row_offset = nir_iadd_imm(&b, base_row, row_iter);
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
nir_def *tmp = col_offset;
@@ -263,7 +355,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *local_idx = nir_load_subgroup_invocation(&b);
-if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
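/* note: on GFX11 only lanes 0..15 carry unique A&B data (the other lanes are
 * replicas), so the access is restricted to them; on GFX12 every lane holds
 * unique data.
 */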
nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
@@ -274,19 +366,27 @@
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
unsigned length = radv_nir_cmat_length(desc, &params);
-unsigned mul = radv_nir_cmat_length_mul(desc);
+unsigned mul = radv_nir_cmat_length_mul(desc, &params);
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
nir_def *vars[16];
for (unsigned i = 0; i < length; ++i)
vars[i] = nir_channel(&b, src, i);
unsigned idx_bits = deref->def.bit_size;
-nir_def *base_row =
-desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
+nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);
for (unsigned i = 0; i < length / mul; ++i) {
nir_def *col_offset = inner_idx;
-nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
+nir_def *row_offset;
+uint32_t row_iter;
+if (gfx_level >= GFX12) {
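/* note: same i + (i & 4) row mapping as in the load path above. */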
+row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
+} else {
+row_iter = i * lanes_per_iter / 16;
+}
+row_offset = nir_iadd_imm(&b, base_row, row_iter);
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
nir_def *tmp = col_offset;
@@ -308,7 +408,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_store_deref(&b, iter_deref, vars[i * mul], 1);
}
-if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
nir_pop_if(&b, NULL);
nir_instr_remove(instr);
@@ -338,7 +438,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
nir_op op = nir_intrinsic_alu_op(intr);
-if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
+if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 16 &&
glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
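/* note: pre-GFX12, 16-bit accumulator elements each occupy a full VGPR, so
 * converting between 16- and 32-bit accumulators must repack the vector;
 * the packed GFX12 layout makes the repack unnecessary.
 */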
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
@@ -349,7 +449,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
nir_def *ret = nir_build_alu1(&b, op, src);
-if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
+if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 32 &&
glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < ret->num_components; ++i) {

@@ -419,7 +419,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
*/
NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
-NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, subgroup_size);
+NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, subgroup_size);
/* Split member structs. We do this before lower_io_to_temporaries so that
* it doesn't lower system values to temporaries by accident.