diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c index c4517502ff8..1b7ed023e39 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c +++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c @@ -607,21 +607,17 @@ convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_ assert(num_comps == 16 || params->gfx_level >= GFX12); if (params->gfx_level >= GFX12) { - if (params->wave_size == 64) { - nir_def *cond = nir_inverse_ballot_imm(b, 0xf0f0f0f00f0f0f0f, params->wave_size); + /* One component contains 2/4 rows in wave32/64, so we must transpose inside it. */ + for (int cross32 = params->wave_size == 64; cross32 >= 0; cross32--) { + uint64_t even = cross32 ? 0xf0f0f0f00f0f0f0f : 0xff0000ffff0000ff; + nir_def *cond = nir_inverse_ballot_imm(b, even, params->wave_size); + unsigned x_mask = cross32 ? 0x24 : 0x18; for (unsigned i = 0; i < num_comps; i++) { nir_def *comp = components[i]; - nir_def *compx = shuffle_xor_imm(b, comp, 0x24); + nir_def *compx = shuffle_xor_imm(b, comp, x_mask); components[i] = nir_bcsel(b, cond, comp, compx); } } - - nir_def *cond = nir_inverse_ballot_imm(b, 0xff0000ffff0000ff, params->wave_size); - for (unsigned i = 0; i < num_comps; i++) { - nir_def *comp = components[i]; - nir_def *compx = shuffle_xor_imm(b, comp, 0x18); - components[i] = nir_bcsel(b, cond, comp, compx); - } } }