radv: use load_deref_transpose_amd for transposed cooperative matrix loads

This requires that BDA is used or robust buffer access is disabled for
cooperative matrix loads.

No fossil-db changes.

fossil-db (gfx1201, dEQP-VK.compute.pipeline.cooperative_matrix.*):
Totals from 603 (37.15% of 1623) affected shaders:
Instrs: 57422 -> 51212 (-10.81%); split: -11.99%, +1.17%
CodeSize: 357444 -> 310688 (-13.08%); split: -13.70%, +0.62%
VGPRs: 16668 -> 13188 (-20.88%); split: -21.53%, +0.65%
Latency: 492820 -> 469600 (-4.71%); split: -4.82%, +0.11%
InvThroughput: 63548 -> 55754 (-12.26%); split: -13.09%, +0.82%
VClause: 1624 -> 1620 (-0.25%); split: -2.71%, +2.46%
Copies: 2965 -> 3175 (+7.08%); split: -15.41%, +22.50%
PreSGPRs: 6966 -> 5450 (-21.76%); split: -21.91%, +0.14%
PreVGPRs: 7049 -> 5978 (-15.19%); split: -15.39%, +0.20%
VALU: 27454 -> 24315 (-11.43%); split: -14.74%, +3.30%
SALU: 5996 -> 6997 (+16.69%)
VMEM: 6748 -> 4656 (-31.00%)
SMEM: 1225 -> 1577 (+28.73%)
VOPD: 45 -> 39 (-13.33%); split: +4.44%, -17.78%

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41653>
This commit is contained in:
Rhys Perry 2026-05-08 15:31:50 +01:00 committed by Marge Bot
parent 0cdb7594d7
commit ca0496bc26
3 changed files with 85 additions and 8 deletions

View file

@ -66,7 +66,8 @@ void radv_nir_lower_io(nir_shader *nir);
bool radv_nir_lower_io_to_mem(const struct radv_compiler_info *compiler_info, struct radv_shader_stage *stage);
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level,
struct radv_shader_stage *stage, unsigned wave_size);
bool radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level);

View file

@ -6,6 +6,7 @@
#include "nir_builder.h"
#include "radv_nir.h"
#include "radv_shader.h"
/* This pass lowers cooperative matrix.
*
@ -43,6 +44,8 @@
typedef struct {
enum amd_gfx_level gfx_level;
unsigned wave_size;
bool ubo_robustness;
bool ssbo_robustness;
} lower_cmat_params;
static unsigned
@ -241,6 +244,39 @@ lower_cmat_construct(nir_builder *b, nir_intrinsic_instr *intr, const lower_cmat
return true;
}
static void
get_load_tr_row_col(nir_builder *b, unsigned bit_size, nir_def **row, nir_def **col)
{
nir_def *lane_id = nir_load_subgroup_invocation(b);
/* In wave64, the instruction only cares about the address for lanes 0-31. */
if (bit_size == 16) {
/*
* lane: 0..7 | 8..15 | 16..23 | 24..31
* row: 0..7 | 0..7 | 8..15 | 8..15
* column: 0 | 8 | 0 | 8
*/
*row = nir_imul_imm(b, nir_udiv_imm(b, lane_id, 16), 8);
*row = nir_iadd(b, *row, nir_iand_imm(b, lane_id, 7));
nir_def *odd8 = nir_inverse_ballot_imm(b, UINT64_C(0xff00ff00ff00ff00), b->shader->info.api_subgroup_size);
*col = nir_bcsel(b, odd8, nir_imm_int(b, 8), nir_imm_int(b, 0));
} else {
/*
* lane: 0..3 | 4..7 | 8..11 | 12..15 | 16..19 | 20..23 | 24..27 | 28..31
* row: 0..3 | 0..3 | 4..7 | 4..7 | 8..11 | 8..11 | 12..15 | 12..15
* column: 0 | 8 | 0 | 8 | 0 | 8 | 0 | 8
*/
assert(bit_size == 8);
*row = nir_imul_imm(b, nir_udiv_imm(b, lane_id, 8), 4);
*row = nir_iadd(b, *row, nir_iand_imm(b, lane_id, 3));
nir_def *odd4 = nir_inverse_ballot_imm(b, UINT64_C(0xf0f0f0f0f0f0f0f0), b->shader->info.api_subgroup_size);
*col = nir_bcsel(b, odd4, nir_imm_int(b, 8), nir_imm_int(b, 0));
}
}
static bool
lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cmat_params *params)
{
@ -249,16 +285,15 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
nir_deref_instr *cmat_deref = nir_src_as_deref(intr->src[!is_load]);
struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type);
enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
unsigned length = radv_nir_cmat_length(desc, params);
nir_deref_instr *deref = nir_src_as_deref(intr->src[is_load]);
nir_def *stride = intr->src[2].ssa;
const uint32_t ptr_stride = glsl_get_bit_size(deref->type) / 8 * glsl_get_vector_elements(deref->type);
const unsigned idx_bits = deref->def.bit_size;
deref = nir_build_deref_cast(b, &deref->def, deref->modes, deref->type, ptr_stride);
nir_def *local_idx = nir_load_subgroup_invocation(b);
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
bool load_acc_as_b = is_load && params->gfx_level < GFX11_7 && desc.use == GLSL_CMAT_USE_ACCUMULATOR &&
radv_nir_cmat_bits(desc) == 8 && params->wave_size == 32 &&
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
@ -270,7 +305,46 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
layout =
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
unsigned length = radv_nir_cmat_length(desc, params);
bool use_tr_load = params->gfx_level >= GFX12 && layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR && is_load &&
radv_nir_cmat_bits(desc) < 32 &&
(nir_deref_mode_is(deref, nir_var_mem_global) ||
(nir_deref_mode_is(deref, nir_var_mem_ubo) && !params->ubo_robustness) ||
(nir_deref_mode_is(deref, nir_var_mem_ssbo) && !params->ssbo_robustness));
if (use_tr_load) {
assert(!load_acc_as_b);
const unsigned elem_bits = radv_nir_cmat_bits(desc);
nir_def *row, *col;
get_load_tr_row_col(b, elem_bits, &row, &col);
col = nir_u2uN(b, col, idx_bits);
row = nir_u2uN(b, nir_imul(b, row, stride), idx_bits);
deref = nir_build_deref_ptr_as_array(b, deref, row);
deref = nir_build_deref_cast(b, &deref->def, deref->modes, glsl_scalar_type(desc.element_type), elem_bits / 8);
deref = nir_build_deref_ptr_as_array(b, deref, col);
/* Convert buffer deref to a global one. */
if (nir_deref_mode_is_one_of(deref, nir_var_mem_ssbo | nir_var_mem_ubo)) {
nir_def *descriptor = nir_ssbo_descriptor_amd(b, &deref->def);
nir_def *addr_lo = nir_channel(b, descriptor, 0);
nir_def *addr_hi = nir_extract_i16(b, nir_channel(b, descriptor, 1), nir_imm_int(b, 0));
nir_def *addr = nir_pack_64_2x32_split(b, addr_lo, addr_hi);
nir_def *offset = nir_channel(b, &deref->def, 2);
addr = nir_iadd_nuw(b, addr, nir_u2u64(b, offset));
deref = nir_build_deref_cast(b, addr, nir_var_mem_global, deref->type, elem_bits / 8);
}
nir_def *mat = nir_load_deref_transpose_amd(b, length, elem_bits, &deref->def);
nir_store_deref(b, cmat_deref, mat, nir_component_mask(mat->num_components));
nir_instr_remove(&intr->instr);
return true;
}
nir_def *local_idx = nir_load_subgroup_invocation(b);
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
unsigned mul = radv_nir_cmat_length_mul(desc, params);
unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params->wave_size : 16;
nir_def *vars[16];
@ -289,7 +363,6 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
vars[i] = nir_channel(b, src, i);
}
unsigned idx_bits = deref->def.bit_size;
nir_def *base_row = radv_get_base_row(b, desc, params, local_idx);
/* VUID-RuntimeSpirv-OpCooperativeMatrixLoadKHR-08986:
@ -1064,7 +1137,8 @@ lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower
}
bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, struct radv_shader_stage *stage,
unsigned wave_size)
{
bool progress = false;
@ -1074,6 +1148,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
const lower_cmat_params params = {
.gfx_level = gfx_level,
.wave_size = wave_size,
.ubo_robustness = stage->key.coop_matrix_uniform_robustness,
.ssbo_robustness = stage->key.coop_matrix_storage_robustness,
};
struct nir_function *func = (struct nir_function *)exec_list_get_head_const(&shader->functions);

View file

@ -579,7 +579,7 @@ radv_shader_spirv_to_nir(const struct radv_compiler_info *compiler_info, struct
NIR_PASS(_, nir, nir_remove_dead_variables, nir_var_function_temp | nir_var_shader_temp, NULL);
}
NIR_PASS(progress, nir, radv_nir_lower_cooperative_matrix, compiler_info->ac->gfx_level,
NIR_PASS(progress, nir, radv_nir_lower_cooperative_matrix, compiler_info->ac->gfx_level, stage,
nir->info.max_subgroup_size);
if (progress) {
NIR_PASS(_, nir, nir_opt_dce);