mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-03-08 01:00:31 +01:00
radv/nir/lower_cmat: use cmat_mul instead of duplicating hw details for type conversion
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34382>
This commit is contained in:
parent
31a3430570
commit
bbc9bc9d24
1 changed file with 14 additions and 10 deletions
|
|
@@ -445,25 +445,29 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
nir_get_nir_type_for_glsl_base_type(dst_element_type),
|
||||
nir_rounding_mode_undef);
|
||||
|
||||
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 16 && radv_nir_cmat_bits(dst_desc) == 32 &&
|
||||
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
unsigned dst_mul = radv_nir_cmat_length_mul(dst_desc, &params);
|
||||
unsigned src_mul = radv_nir_cmat_length_mul(src_desc, &params);
|
||||
|
||||
if (src_mul > dst_mul) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
for (unsigned i = 0; i * 2 < src->num_components; ++i) {
|
||||
components[i] = nir_channel(&b, src, i * 2);
|
||||
unsigned scale = src_mul / dst_mul;
|
||||
for (unsigned i = 0; i * scale < src->num_components; ++i) {
|
||||
components[i] = nir_channel(&b, src, i * scale);
|
||||
}
|
||||
src = nir_vec(&b, components, src->num_components / 2);
|
||||
src = nir_vec(&b, components, src->num_components / scale);
|
||||
}
|
||||
|
||||
nir_def *ret = nir_build_alu1(&b, op, src);
|
||||
|
||||
if (gfx_level < GFX12 && radv_nir_cmat_bits(src_desc) == 32 && radv_nir_cmat_bits(dst_desc) == 16 &&
|
||||
dst_desc.use == GLSL_CMAT_USE_ACCUMULATOR) {
|
||||
if (dst_mul > src_mul) {
|
||||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
unsigned scale = dst_mul / src_mul;
|
||||
for (unsigned i = 0; i < ret->num_components; ++i) {
|
||||
components[i * 2] = nir_channel(&b, ret, i);
|
||||
components[i * 2 + 1] = nir_undef(&b, 1, 16);
|
||||
components[i * scale] = nir_channel(&b, ret, i);
|
||||
for (unsigned j = 1; j < scale; j++)
|
||||
components[i * scale + j] = nir_undef(&b, 1, ret->bit_size);
|
||||
}
|
||||
ret = nir_vec(&b, components, ret->num_components * 2);
|
||||
ret = nir_vec(&b, components, ret->num_components * scale);
|
||||
}
|
||||
|
||||
nir_store_deref(&b, dst_deref, ret, nir_component_mask(ret->num_components));
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue