radv/nir/lower_cmat: use radv_nir_cmat_bits consistently

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34382>
This commit is contained in:
Georg Lehmann 2025-04-04 13:26:00 +02:00 committed by Marge Bot
parent 300b6f7371
commit 31a3430570

View file

@ -60,6 +60,12 @@ typedef struct {
unsigned wave_size;
} lower_cmat_params;
static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
return glsl_base_type_bit_size(desc.element_type);
}
static unsigned
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
{
@ -69,7 +75,7 @@ radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params
} else {
return desc.use != GLSL_CMAT_USE_ACCUMULATOR
? 16
: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
: (desc.cols * desc.rows / params->wave_size * 32 / radv_nir_cmat_bits(desc));
}
}
@ -84,23 +90,16 @@ radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_par
* We then use the coefficient generated by this function to figure out
* how many elements we really have.
*/
return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / radv_nir_cmat_bits(desc)) : 1;
}
}
static unsigned
radv_nir_cmat_bits(struct glsl_cmat_description desc)
{
return glsl_base_type_bit_size(desc.element_type);
}
static nir_def *
radv_nir_load_cmat(nir_builder *b, const lower_cmat_params *params, nir_def *src)
{
nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), glsl_base_type_bit_size(desc.element_type), src,
0);
return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), radv_nir_cmat_bits(desc), src, 0);
}
static const struct glsl_type *
@ -297,7 +296,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
if (mul > 1) {
for (unsigned i = 0; i < length; ++i)
if (i % mul != 0)
vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc));
}
unsigned idx_bits = deref->def.bit_size;
@ -328,9 +327,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
row_offset = nir_u2uN(&b, row_offset, idx_bits);
nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
iter_deref =
nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
glsl_base_type_bit_size(desc.element_type) / 8);
iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes,
glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
vars[i * mul] = nir_load_deref(&b, iter_deref);
@ -400,9 +398,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
row_offset = nir_u2uN(&b, row_offset, idx_bits);
nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
iter_deref =
nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
glsl_base_type_bit_size(desc.element_type) / 8);
iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes,
glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
nir_store_deref(&b, iter_deref, vars[i * mul], 1);
@ -448,8 +445,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_get_nir_type_for_glsl_base_type(dst_element_type),
nir_rounding_mode_undef);
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 16 &&
glsl_base_type_bit_size(dst_element_type) == 32 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 16 && radv_nir_cmat_bits(dst_desc) == 32 &&
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
components[i] = nir_channel(&b, src, i * 2);
@ -459,8 +456,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
nir_def *ret = nir_build_alu1(&b, op, src);
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 32 &&
glsl_base_type_bit_size(dst_element_type) == 16 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 32 && radv_nir_cmat_bits(dst_desc) == 16 &&
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
nir_def *components[NIR_MAX_VEC_COMPONENTS];
for (unsigned i = 0; i < ret->num_components; ++i) {
components[i * 2] = nir_channel(&b, ret, i);