radv/nir/lower_cmat: share cmat_load/cmat_store code

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35633>
Authored by Georg Lehmann, 2025-06-19 12:37:06 +02:00; committed by Marge Bot
parent f3f67823c4
commit ed2ecf9ef8

@@ -418,89 +418,18 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
             progress = true;
             break;
          }
-         case nir_intrinsic_cmat_load: {
-            nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
-            struct glsl_cmat_description desc = *glsl_get_cmat_description(dst_deref->type);
-            enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
-            nir_deref_instr *deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
-            nir_def *stride = intr->src[2].ssa;
-            nir_def *local_idx = nir_load_subgroup_invocation(&b);
-            nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
-            /* A input is transposed */
-            if (desc.use == GLSL_CMAT_USE_A)
-               layout = layout == GLSL_MATRIX_LAYOUT_COLUMN_MAJOR ? GLSL_MATRIX_LAYOUT_ROW_MAJOR
-                                                                  : GLSL_MATRIX_LAYOUT_COLUMN_MAJOR;
-            unsigned length = radv_nir_cmat_length(desc, &params);
-            unsigned mul = radv_nir_cmat_length_mul(desc, &params);
-            unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
-            nir_def *vars[16];
-            if (mul > 1) {
-               for (unsigned i = 0; i < length; ++i)
-                  if (i % mul != 0)
-                     vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc));
-            }
-            unsigned idx_bits = deref->def.bit_size;
-            nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);
-            for (unsigned i = 0; i < length / mul; ++i) {
-               nir_def *col_offset = inner_idx;
-               nir_def *row_offset;
-               uint32_t row_iter;
-               if (gfx_level >= GFX12) {
-                  row_iter = i;
-               } else {
-                  row_iter = i * lanes_per_iter / 16;
-               }
-               row_offset = nir_iadd_imm(&b, base_row, row_iter);
-               if (layout == GLSL_MATRIX_LAYOUT_ROW_MAJOR) {
-                  nir_def *tmp = col_offset;
-                  col_offset = row_offset;
-                  row_offset = tmp;
-               }
-               col_offset = nir_imul(&b, col_offset, stride);
-               col_offset = nir_u2uN(&b, col_offset, idx_bits);
-               row_offset = nir_u2uN(&b, row_offset, idx_bits);
-               nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
-               iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes,
-                                                 glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
-               iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
-               vars[i * mul] = nir_load_deref(&b, iter_deref);
-            }
-            nir_def *mat = nir_vec(&b, vars, length);
-            nir_store_deref(&b, dst_deref, mat, nir_component_mask(mat->num_components));
-            nir_instr_remove(instr);
-            progress = true;
-            break;
-         }
+         case nir_intrinsic_cmat_load:
          case nir_intrinsic_cmat_store: {
+            const bool is_load = intr->intrinsic == nir_intrinsic_cmat_load;
+            nir_deref_instr *cmat_deref = nir_instr_as_deref(intr->src[!is_load].ssa->parent_instr);
+            struct glsl_cmat_description desc = *glsl_get_cmat_description(cmat_deref->type);
             enum glsl_matrix_layout layout = nir_intrinsic_matrix_layout(intr);
-            nir_deref_instr *deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
-            nir_def *src = intr->src[1].ssa;
+            nir_deref_instr *deref = nir_instr_as_deref(intr->src[is_load].ssa->parent_instr);
             nir_def *stride = intr->src[2].ssa;
-            nir_deref_instr *src_deref = nir_instr_as_deref(src->parent_instr);
-            struct glsl_cmat_description desc = *glsl_get_cmat_description(src_deref->type);
-            src = radv_nir_load_cmat(&b, &params, src);
             nir_def *local_idx = nir_load_subgroup_invocation(&b);
-            if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
-               nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
             nir_def *inner_idx = nir_iand_imm(&b, local_idx, 15);
             /* A input is transposed */
@@ -512,8 +441,20 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
             unsigned mul = radv_nir_cmat_length_mul(desc, &params);
             unsigned lanes_per_iter = desc.use == GLSL_CMAT_USE_ACCUMULATOR ? params.wave_size : 16;
             nir_def *vars[16];
-            for (unsigned i = 0; i < length; ++i)
-               vars[i] = nir_channel(&b, src, i);
+            if (is_load) {
+               if (mul > 1) {
+                  for (unsigned i = 0; i < length; ++i)
+                     if (i % mul != 0)
+                        vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc));
+               }
+            } else {
+               if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+                  nir_push_if(&b, nir_ilt_imm(&b, local_idx, 16));
+               nir_def *src = radv_nir_load_cmat(&b, &params, &cmat_deref->def);
+               for (unsigned i = 0; i < length; ++i)
+                  vars[i] = nir_channel(&b, src, i);
+            }
             unsigned idx_bits = deref->def.bit_size;
             nir_def *base_row = radv_get_base_row(&b, desc, &params, local_idx);
@@ -547,12 +488,19 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
                                                  glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
                iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
-               nir_store_deref(&b, iter_deref, vars[i * mul], 1);
+               if (is_load) {
+                  vars[i * mul] = nir_load_deref(&b, iter_deref);
+               } else {
+                  nir_store_deref(&b, iter_deref, vars[i * mul], 1);
+               }
             }
-            if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR)
+            if (is_load) {
+               nir_def *mat = nir_vec(&b, vars, length);
+               nir_store_deref(&b, cmat_deref, mat, nir_component_mask(mat->num_components));
+            } else if (gfx_level < GFX12 && desc.use != GLSL_CMAT_USE_ACCUMULATOR) {
                nir_pop_if(&b, NULL);
+            }
             nir_instr_remove(instr);
             progress = true;
             break;
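
The net effect of the diff is that cmat_load now falls through into the cmat_store case, and a single is_load flag selects the direction at the three points where the two intrinsics actually differ (operand indices, per-element access, and the epilogue), while the offset and deref construction is shared. A minimal, self-contained C sketch of that pattern follows; the enum, function, and parameter names here are illustrative placeholders, not the Mesa/NIR API:

    #include <stdbool.h>
    #include <stddef.h>

    enum cmat_op { CMAT_LOAD, CMAT_STORE };

    /* Stand-ins for the NIR matrix value and the memory it maps to. */
    static void
    lower_cmat_load_store(enum cmat_op op, float *matrix, float *memory, size_t length)
    {
       const bool is_load = op == CMAT_LOAD; /* one flag selects the direction */

       for (size_t i = 0; i < length; ++i) {
          /* The shared per-element work (in the real pass: offset math and
           * deref construction) happens once here, for both directions. */
          if (is_load)
             matrix[i] = memory[i]; /* cmat_load: memory -> matrix */
          else
             memory[i] = matrix[i]; /* cmat_store: matrix -> memory */
       }
    }

The design choice mirrors the commit title: instead of two near-identical case bodies that must be kept in sync, only the innermost load/store and the surrounding bookkeeping branch on is_load.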