mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-05 03:08:05 +02:00
radv/nir/lower_cmat: use radv_nir_cmat_bits consistently
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34382>
This commit is contained in:
parent
300b6f7371
commit
31a3430570
1 changed files with 18 additions and 21 deletions
|
|
@ -60,6 +60,12 @@ typedef struct {
|
|||
unsigned wave_size;
|
||||
} lower_cmat_params;
|
||||
|
||||
static unsigned
|
||||
radv_nir_cmat_bits(struct glsl_cmat_description desc)
|
||||
{
|
||||
return glsl_base_type_bit_size(desc.element_type);
|
||||
}
|
||||
|
||||
static unsigned
|
||||
radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params *params)
|
||||
{
|
||||
|
|
@ -69,7 +75,7 @@ radv_nir_cmat_length(struct glsl_cmat_description desc, const lower_cmat_params
|
|||
} else {
|
||||
return desc.use != GLSL_CMAT_USE_ACCUMULATOR
|
||||
? 16
|
||||
: (desc.cols * desc.rows / params->wave_size * 32 / glsl_base_type_bit_size(desc.element_type));
|
||||
: (desc.cols * desc.rows / params->wave_size * 32 / radv_nir_cmat_bits(desc));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -84,23 +90,16 @@ radv_nir_cmat_length_mul(struct glsl_cmat_description desc, const lower_cmat_par
|
|||
* We then use the coefficient generated by this function to figure out
|
||||
* how many elements we really have.
|
||||
*/
|
||||
return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / glsl_base_type_bit_size(desc.element_type)) : 1;
|
||||
return desc.use == GLSL_CMAT_USE_ACCUMULATOR ? (32 / radv_nir_cmat_bits(desc)) : 1;
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned
|
||||
radv_nir_cmat_bits(struct glsl_cmat_description desc)
|
||||
{
|
||||
return glsl_base_type_bit_size(desc.element_type);
|
||||
}
|
||||
|
||||
static nir_def *
|
||||
radv_nir_load_cmat(nir_builder *b, const lower_cmat_params *params, nir_def *src)
|
||||
{
|
||||
nir_deref_instr *deref = nir_instr_as_deref(src->parent_instr);
|
||||
struct glsl_cmat_description desc = *glsl_get_cmat_description(deref->type);
|
||||
return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), glsl_base_type_bit_size(desc.element_type), src,
|
||||
0);
|
||||
return nir_build_load_deref(b, radv_nir_cmat_length(desc, params), radv_nir_cmat_bits(desc), src, 0);
|
||||
}
|
||||
|
||||
static const struct glsl_type *
|
||||
|
|
@ -297,7 +296,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
if (mul > 1) {
|
||||
for (unsigned i = 0; i < length; ++i)
|
||||
if (i % mul != 0)
|
||||
vars[i] = nir_undef(&b, 1, glsl_base_type_bit_size(desc.element_type));
|
||||
vars[i] = nir_undef(&b, 1, radv_nir_cmat_bits(desc));
|
||||
}
|
||||
|
||||
unsigned idx_bits = deref->def.bit_size;
|
||||
|
|
@ -328,9 +327,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
row_offset = nir_u2uN(&b, row_offset, idx_bits);
|
||||
|
||||
nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
|
||||
iter_deref =
|
||||
nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
|
||||
glsl_base_type_bit_size(desc.element_type) / 8);
|
||||
iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes,
|
||||
glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
|
||||
iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
|
||||
|
||||
vars[i * mul] = nir_load_deref(&b, iter_deref);
|
||||
|
|
@ -400,9 +398,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
row_offset = nir_u2uN(&b, row_offset, idx_bits);
|
||||
|
||||
nir_deref_instr *iter_deref = nir_build_deref_ptr_as_array(&b, deref, col_offset);
|
||||
iter_deref =
|
||||
nir_build_deref_cast(&b, &iter_deref->def, deref->modes, glsl_scalar_type(desc.element_type),
|
||||
glsl_base_type_bit_size(desc.element_type) / 8);
|
||||
iter_deref = nir_build_deref_cast(&b, &iter_deref->def, deref->modes,
|
||||
glsl_scalar_type(desc.element_type), radv_nir_cmat_bits(desc) / 8);
|
||||
iter_deref = nir_build_deref_ptr_as_array(&b, iter_deref, row_offset);
|
||||
|
||||
nir_store_deref(&b, iter_deref, vars[i * mul], 1);
|
||||
|
|
@ -448,8 +445,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
nir_get_nir_type_for_glsl_base_type(dst_element_type),
|
||||
nir_rounding_mode_undef);
|
||||
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 16 &&
|
||||
glsl_base_type_bit_size(dst_element_type) == 32 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 16 && radv_nir_cmat_bits(dst_desc) == 32 &&
|
||||
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
|
||||
components[i] = nir_channel(&b, src, i * 2);
|
||||
|
|
@ -459,8 +456,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
|
||||
nir_def *ret = nir_build_alu1(&b, op, src);
|
||||
|
||||
if (gfx_level < GFX12 && glsl_base_type_bit_size(src_element_type) == 32 &&
|
||||
glsl_base_type_bit_size(dst_element_type) == 16 && dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 32 && radv_nir_cmat_bits(dst_desc) == 16 &&
|
||||
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i < ret->num_components; ++i) {
|
||||
components[i * 2] = nir_channel(&b, ret, i);
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue