radv/gfx11.7: take GFX12 paths in radv_nir_lower_cooperative_matrix

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40917>
This commit is contained in:
Rhys Perry 2026-04-08 15:14:10 +01:00 committed by Marge Bot
parent ee78bea393
commit df9195ac34

View file

@ -13,7 +13,7 @@
* to 16..31 and for wave64 also into lanes 32..47 and 48..63. A&B matrices are
* always vectors of 16 elements.
*
* On GFX12, there is no data replication and the matrices layout is described
* On GFX11.7+, there is no data replication and the matrices layout is described
* as below:
*
* Wave32:
@ -54,7 +54,7 @@ radv_nir_cmat_bits(struct glsl_cmat_description desc)
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
assert(desc.cols == 16 && desc.rows == 16);
return 256 / params->wave_size;
} else if (desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
@ -67,7 +67,7 @@ radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params
static unsigned
radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
if (params->gfx_level >= GFX12 || desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
if (params->gfx_level >= GFX11_7 || desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
return 1;
} else {
/* For GFX11 C matrices we have 1 VGPR per element even if the element type is
@ -148,7 +148,7 @@ radv_get_base_row(nir_builder *b, struct glsl_cmat_description desc, const lower
{
nir_def *base_row;
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
base_row = nir_udiv_imm(b, local_idx, 16);
if (params->wave_size == 64) {
@ -259,7 +259,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
nir_def *local_idx = nir_load_subgroup_invocation(b);
nir_def *inner_idx = nir_iand_imm(b, local_idx, 15);
bool load_acc_as_b = is_load && params->gfx_level < GFX12 && desc.use == GLSL_CMAT_USE_ACCUMULATOR &&
bool load_acc_as_b = is_load && params->gfx_level < GFX11_7 && desc.use == GLSL_CMAT_USE_ACCUMULATOR &&
radv_nir_cmat_bits(desc) == 8 && params->wave_size == 32 &&
layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
if (load_acc_as_b)
@ -281,7 +281,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
vars[i] = nir_undef(b, 1, radv_nir_cmat_bits(desc));
}
} else {
if (params->gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
if (params->gfx_level < GFX11_7 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
nir_push_if(b, nir_ilt_imm(b, local_idx, 16));
nir_def *src = radv_nir_load_cmat(b, params, &cmat_deref->def);
@ -302,7 +302,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
if (layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR)
align_mul = MIN2(16, radv_nir_cmat_bits(desc) * desc.rows / 8);
if (params->gfx_level >= GFX12)
if (params->gfx_level >= GFX11_7)
align_mul /= params->wave_size / 16;
else if (desc.use == GLSL_CMAT_USE_ACCUMULATOR)
align_mul = 0;
@ -312,7 +312,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
nir_def *row_offset;
uint32_t row_iter;
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
row_iter = i;
} else {
row_iter = i * lanes_per_iter / 16;
@ -366,7 +366,7 @@ lower_cmat_load_store(nir_builder *b, nir_intrinsic_instr *intr, const lower_cma
}
nir_store_deref(b, cmat_deref, mat, nir_component_mask(mat->num_components));
} else if (params->gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
} else if (params->gfx_level < GFX11_7 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
nir_pop_if(b, NULL);
}
nir_instr_remove(&intr->instr);
@ -445,7 +445,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
{
if (src_use == dst_use)
return src;
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR)
return src;
if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B)
@ -467,7 +467,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
components[i] = nir_channel(b, src, i);
if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B) {
assert(params->gfx_level < GFX12);
assert(params->gfx_level < GFX11_7);
nir_def *tmp[NIR_MAX_VEC_COMPONENTS];
if (src->bit_size == 32) {
@ -521,7 +521,7 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
assert(num_comps == 16);
} else if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
assert(params->gfx_level < GFX12);
assert(params->gfx_level < GFX11_7);
assert(num_comps == 16);
if (src->bit_size == 32) {
for (unsigned keep32 = 0; keep32 < ((params->wave_size == 64) ? 2 : 1); keep32++) {
@ -585,9 +585,9 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_
}
}
assert(num_comps == 16 || params->gfx_level >= GFX12);
assert(num_comps == 16 || params->gfx_level >= GFX11_7);
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
/* One component contains 2/4 rows in wave32/64, so we must transpose inside it. */
for (int cross32 = params->wave_size == 64; cross32 >= 0; cross32--) {
uint64_t even = cross32 ? 0xf0f0f0f00f0f0f0f : 0xff0000ffff0000ff;
@ -901,7 +901,7 @@ lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call, const lowe
nir_def *low16 = nir_inverse_ballot_imm(b, 0xffff0000ffff, params->wave_size);
for (unsigned m = 0; m < 4; m++) {
for (unsigned i = 0; i < length / mul / 2; i++) {
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
/* The neighboring row is in the VGPR next to us */
nir_call(b, fnptr, &qd_tmp_deref->def, src_components[m][i * 2], src_components[m][i * 2 + 1]);
src_components[m][i] = nir_load_deref(b, qd_tmp_deref);
@ -946,7 +946,7 @@ lower_cmat_reduce_2x2_call(nir_builder *b, nir_cmat_call_instr *call, const lowe
for (unsigned i = 0; i < length / mul; i++)
vars[i] = nir_lane_permute_16_amd(b, vars[i], perm_low, perm_high);
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
/* For GFX12, we still have to swap the row(s) in the upper half coming from the bottom two
 * matrices with the low row(s) from the other two matrices.
*/
@ -1021,7 +1021,7 @@ lower_cmat_per_element_op(nir_builder *b, nir_cmat_call_instr *call, const lower
nir_call_instr *new_call = nir_call_instr_create(b->shader, fnptr);
uint32_t row_iter;
if (params->gfx_level >= GFX12) {
if (params->gfx_level >= GFX11_7) {
row_iter = i;
} else {
row_iter = i * lanes_per_iter / 16;
@ -1205,7 +1205,7 @@ apply_component_mods(nir_scalar *comp, unsigned num_comps, unsigned stride, nir_
static bool
opt_cmat_modifiers(nir_builder *b, nir_intrinsic_instr *intrin, enum amd_gfx_level gfx_level, unsigned src_idx)
{
unsigned length_mul = src_idx == 2 && intrin->src[2].ssa->bit_size == 16 && gfx_level < GFX12 ? 2 : 1;
unsigned length_mul = src_idx == 2 && intrin->src[2].ssa->bit_size == 16 && gfx_level < GFX11_7 ? 2 : 1;
nir_scalar comp[NIR_MAX_VEC_COMPONENTS] = {0};
nir_def *src = intrin->src[src_idx].ssa;