mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-26 21:30:09 +01:00
radv/nir: add cooperative matrix lowering for GFX12
Signed-off-by: Samuel Pitoiset <samuel.pitoiset@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33378>
This commit is contained in:
parent
ad611adeb7
commit
b05a112d92
3 changed files with 127 additions and 27 deletions
|
|
@ -69,7 +69,7 @@ void radv_nir_lower_io(struct radv_device *device, nir_shader *nir);
|
|||
|
||||
bool radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_stage *stage);
|
||||
|
||||
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size);
|
||||
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);
|
||||
|
||||
bool radv_nir_lower_draw_id_to_zero(nir_shader *shader);
|
||||
|
||||
|
|
|
|||
|
|
@ -7,26 +7,85 @@
|
|||
#include "nir_builder.h"
|
||||
#include "radv_nir.h"
|
||||
|
||||
/* This pass lowers cooperative matrix.
|
||||
*
|
||||
 * On GFX11, the A&B matrices need to be replicated, lanes 0..15 are replicated
|
||||
* to 16..31 and for wave64 also into lanes 32..47 and 48..63. A&B matrices are
|
||||
* always vectors of 16 elements.
|
||||
*
|
||||
* On GFX12, there is no data replication and the matrices layout is described
|
||||
* as below:
|
||||
*
|
||||
* Wave32:
|
||||
* A&B:
|
||||
* 0..15 | 16..31 (lanes)
|
||||
* v0 lo: row 0 | row 4
|
||||
* v0 hi: row 1 | row 5
|
||||
* v1 lo: row 2 | row 6
|
||||
* v1 hi: row 3 | row 7
|
||||
* v2 lo: row 8 | row 12
|
||||
* v2 hi: row 9 | row 13
|
||||
* v3 lo: row 10 | row 14
|
||||
* v3 hi: row 11 | row 15
|
||||
*
|
||||
* C:
|
||||
* 0..15 | 16..31 (lanes)
|
||||
* v0 lo: row 0 | row 8
|
||||
* v0 hi: row 1 | row 9
|
||||
* v1 lo: row 2 | row 10
|
||||
* v1 hi: row 3 | row 11
|
||||
* v2 lo: row 4 | row 12
|
||||
* v2 hi: row 5 | row 13
|
||||
* v3 lo: row 6 | row 14
|
||||
* v3 hi: row 7 | row 15
|
||||
*
|
||||
* Wave64:
|
||||
* A&B:
|
||||
* 0..15 | 16..31 | 32..47 | 48..63 (lanes)
|
||||
* v0 lo: row 0 | row 4 | row 8 | row 12
|
||||
* v0 hi: row 1 | row 5 | row 9 | row 13
|
||||
* v1 lo: row 2 | row 6 | row 10 | row 14
|
||||
* v1 hi: row 3 | row 7 | row 11 | row 15
|
||||
*
|
||||
* C:
|
||||
* 0..15 | 16..31 | 32..47 | 48..63 (lanes)
|
||||
* v0 lo: row 0 | row 8 | row 4 | row 12
|
||||
* v0 hi: row 1 | row 9 | row 5 | row 13
|
||||
* v1 lo: row 2 | row 10 | row 6 | row 14
|
||||
* v1 hi: row 3 | row 11 | row 7 | row 15
|
||||
*/
|
||||
|
||||
typedef struct {
|
||||
enum amd_gfx_level gfx_level;
|
||||
unsigned wave_size;
|
||||
} lower_cmat_params;
|
||||
|
||||
static unsigned
|
||||
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
|
||||
{
|
||||
return desc.use != GLSL_CMAT_USE_ACCUMULATOR
|
||||
? 16
|
||||
: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
|
||||
if (params->gfx_level >= GFX12) {
|
||||
assert(desc.cols == 16 && desc.rows == 16);
|
||||
return 256 / params->wave_size;
|
||||
} else {
|
||||
return desc.use != GLSL_CMAT_USE_ACCUMULATOR
|
||||
? 16
|
||||
: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
|
||||
}
|
||||
}
|
||||
|
||||
/* for C matrices we have 1 VGPR per element even if the element type is < 32 bits. So with 8 fp16 elements we implement
|
||||
* that with a f16vec16. We then use the coefficient generated by this function to figure out how many elements we
|
||||
* really have.
|
||||
*/
|
||||
static unsigned
|
||||
radv_nir_cmat_length_mul(struct glsl_cmat_description desc)
|
||||
radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params)
|
||||
{
|
||||
return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
|
||||
if (params->gfx_level >= GFX12) {
|
||||
return 1;
|
||||
} else {
|
||||
/* For C matrices we have 1 VGPR per element even if the element type is
|
||||
* < 32 bits. So with 8 fp16 elements we implement that with a f16vec16.
|
||||
* We then use the coefficient generated by this function to figure out
|
||||
* how many elements we really have.
|
||||
*/
|
||||
return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned
|
||||
|
|
@ -99,8 +158,32 @@ radv_nir_translate_matrix_type(const struct glsl_type *orig_type, struct hash_ta
|
|||
return orig_type;
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower_cmat_params *params,
|
||||
nir_def *local_idx)
|
||||
{
|
||||
nir_def *base_row;
|
||||
|
||||
if (params->gfx_level >= GFX12) {
|
||||
base_row = nir_udiv_imm(b, local_idx, 16);
|
||||
|
||||
if (desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 64) {
|
||||
/* Switch rows from lanes 16..31 to 32..47, offset right shift by -2
|
||||
* to get implicit * 4.
|
||||
*/
|
||||
base_row = nir_ushr_imm(b, nir_bitfield_reverse(b, base_row), 30 - 2);
|
||||
} else {
|
||||
base_row = nir_imul_imm(b, base_row, desc.use == GLSL_CMAT_USE_ACCUMULATOR && params->wave_size == 32 ? 8 : 4);
|
||||
}
|
||||
} else {
|
||||
base_row = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(b, local_idx, 16) : nir_imm_int(b, 0);
|
||||
}
|
||||
|
||||
return base_row;
|
||||
}
|
||||
|
||||
bool
|
||||
radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
||||
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
|
||||
{
|
||||
bool progress = false;
|
||||
|
||||
|
|
@ -108,6 +191,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
return false;
|
||||
|
||||
const lower_cmat_params params = {
|
||||
.gfx_level = gfx_level,
|
||||
.wave_size = wave_size,
|
||||
};
|
||||
|
||||
|
|
@ -143,7 +227,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
switch (intr->intrinsic) {
|
||||
case nir_intrinsic_cmat_length: {
|
||||
struct glsl_cmat_description desc = nir_intrinsic_cmat_desc(intr);
|
||||
unsigned len = radv_nir_cmat_length(desc, ¶ms) / radv_nir_cmat_length_mul(desc);
|
||||
unsigned len = radv_nir_cmat_length(desc, ¶ms) / radv_nir_cmat_length_mul(desc, ¶ms);
|
||||
nir_def_rewrite_uses(&intr->def, nir_imm_int(&b, len));
|
||||
nir_instr_remove(instr);
|
||||
progress = true;
|
||||
|
|
@ -155,7 +239,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_def *src0 = radv_nir_load_cmat(&b, ¶ms, intr->src[0].ssa);
|
||||
|
||||
nir_def *index = intr->src[1].ssa;
|
||||
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
|
||||
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, ¶ms));
|
||||
|
||||
nir_def *elem = nir_vector_extract(&b, src0, index);
|
||||
|
||||
|
|
@ -169,7 +253,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
|
||||
nir_def *index = intr->src[3].ssa;
|
||||
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc));
|
||||
index = nir_imul_imm(&b, index, radv_nir_cmat_length_mul(desc, ¶ms));
|
||||
|
||||
nir_def *elem = intr->src[1].ssa;
|
||||
nir_def *r = nir_vector_insert(&b, src1, elem, index);
|
||||
|
|
@ -207,7 +291,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
|
||||
unsigned length = radv_nir_cmat_length(desc, ¶ms);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc, ¶ms);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
|
||||
nir_def *vars[16];
|
||||
if (mul > 1) {
|
||||
|
|
@ -217,12 +301,20 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
}
|
||||
|
||||
unsigned idx_bits = deref->def.bit_size;
|
||||
nir_def *base_row =
|
||||
desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
|
||||
nir_def *base_row = radv_get_base_row(&b, desc, ¶ms, local_idx);
|
||||
|
||||
for (unsigned i = 0; i < length / mul; ++i) {
|
||||
nir_def *col_offset = inner_idx;
|
||||
nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
|
||||
nir_def *row_offset;
|
||||
uint32_t row_iter;
|
||||
|
||||
if (gfx_level >= GFX12) {
|
||||
row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
|
||||
} else {
|
||||
row_iter = i * lanes_per_iter / 16;
|
||||
}
|
||||
|
||||
row_offset = nir_iadd_imm(&b, base_row, row_iter);
|
||||
|
||||
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
|
||||
nir_def *tmp = col_offset;
|
||||
|
|
@ -263,7 +355,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
|
||||
nir_def *local_idx = nir_load_subgroup_invocation(&b);
|
||||
|
||||
if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
|
||||
|
||||
nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
|
||||
|
|
@ -274,19 +366,27 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
: GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
|
||||
|
||||
unsigned length = radv_nir_cmat_length(desc, ¶ms);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc);
|
||||
unsigned mul = radv_nir_cmat_length_mul(desc, ¶ms);
|
||||
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
|
||||
nir_def *vars[16];
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
vars[i] = nir_channel(&b, src, i);
|
||||
|
||||
unsigned idx_bits = deref->def.bit_size;
|
||||
nir_def *base_row =
|
||||
desc.use == GLSL_CMAT_USE_ACCUMULATOR ? nir_udiv_imm(&b, local_idx, 16) : nir_imm_int(&b, 0);
|
||||
nir_def *base_row = radv_get_base_row(&b, desc, ¶ms, local_idx);
|
||||
|
||||
for (unsigned i = 0; i < length / mul; ++i) {
|
||||
nir_def *col_offset = inner_idx;
|
||||
nir_def *row_offset = nir_iadd_imm(&b, base_row, i * lanes_per_iter / 16);
|
||||
nir_def *row_offset;
|
||||
uint32_t row_iter;
|
||||
|
||||
if (gfx_level >= GFX12) {
|
||||
row_iter = desc.use != GLSL_CMAT_USE_ACCUMULATOR && wave_size == 32 ? i + (i & 4) : i;
|
||||
} else {
|
||||
row_iter = i * lanes_per_iter / 16;
|
||||
}
|
||||
|
||||
row_offset = nir_iadd_imm(&b, base_row, row_iter);
|
||||
|
||||
if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
|
||||
nir_def *tmp = col_offset;
|
||||
|
|
@ -308,7 +408,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_store_deref(&b, iter_deref, vars[i * mul], 1);
|
||||
}
|
||||
|
||||
if (desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
|
||||
nir_pop_if(&b, NULL);
|
||||
|
||||
nir_instr_remove(instr);
|
||||
|
|
@ -338,7 +438,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
nir_def *src = radv_nir_load_cmat(&b, ¶ms, intr->src[1].ssa);
|
||||
nir_op op = nir_intrinsic_alu_op(intr);
|
||||
|
||||
if (glsl_base_type_bit_size(src_desc.element_type) == 16 &&
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 16 &&
|
||||
glsl_base_type_bit_size(desc.element_type) == 32 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
|
||||
|
|
@ -349,7 +449,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, unsigned wave_size)
|
|||
|
||||
nir_def *ret = nir_build_alu1(&b, op, src);
|
||||
|
||||
if (glsl_base_type_bit_size(src_desc.element_type) == 32 &&
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_desc.element_type) == 32 &&
|
||||
glsl_base_type_bit_size(desc.element_type) == 16 && desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i < ret->num_components; ++i) {
|
||||
|
|
|
|||
|
|
@ -419,7 +419,7 @@ radv_shader_spirv_to_nir(struct radv_device *device, const struct radv_shader_st
|
|||
*/
|
||||
NIR_PASS(_, nir, nir_lower_variable_initializers, ~0);
|
||||
|
||||
NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, subgroup_size);
|
||||
NIR_PASS(_, nir, radv_nir_lower_cooperative_matrix, pdev->info.gfx_level, subgroup_size);
|
||||
|
||||
/* Split member structs. We do this before lower_io_to_temporaries so that
|
||||
* it doesn't lower system values to temporaries by accident.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue