diff --git a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
index c31ca27db60..7558eb0c449 100644
--- a/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
+++ b/src/amd/vulkan/nir/radv_nir_lower_cooperative_matrix.c
@@ -204,6 +204,128 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en
    return nir_build_alu1(b, op, src);
 }
 
+static nir_def *
+convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_cmat_use dst_use,
+            const lower_cmat_params *params)
+{
+   if (src_use == dst_use)
+      return src;
+   if (params->gfx_level >= GFX12) {
+      if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR)
+         return src;
+      if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B)
+         return src;
+   }
+
+   if (src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
+      src = convert_use(b, src, GLSL_CMAT_USE_A, GLSL_CMAT_USE_B, params);
+      return convert_use(b, src, GLSL_CMAT_USE_B, GLSL_CMAT_USE_ACCUMULATOR, params);
+   } else if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_A) {
+      src = convert_use(b, src, GLSL_CMAT_USE_ACCUMULATOR, GLSL_CMAT_USE_B, params);
+      return convert_use(b, src, GLSL_CMAT_USE_B, GLSL_CMAT_USE_A, params);
+   }
+
+   nir_def *components[NIR_MAX_VEC_COMPONENTS] = {NULL};
+
+   unsigned num_comps = src->num_components;
+   for (unsigned i = 0; i < num_comps; i++)
+      components[i] = nir_channel(b, src, i);
+
+   if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B) {
+      assert(params->gfx_level < GFX12);
+      nir_def *tmp[NIR_MAX_VEC_COMPONENTS];
+
+      if (params->wave_size == 64) {
+         nir_def *low_lanes = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, UINT32_MAX, 64));
+         for (int i = 0; i < num_comps; i++) {
+            nir_def *comp = components[i];
+            nir_def *half_swap = nir_rotate(b, comp, nir_imm_int(b, 32), .cluster_size = 64);
+
+            tmp[i * 2] = nir_bcsel(b, low_lanes, comp, half_swap);
+            tmp[i * 2 + 1] = nir_bcsel(b, low_lanes, half_swap, comp);
+         }
+         num_comps *= 2;
+         memcpy(components, tmp, sizeof(components));
+      }
+
+      for (int i = 0; i < num_comps; i++) {
+         unsigned broadcast_low16 = 0xf;
+         unsigned broadcast_high16 = 0xf | (0x10 << 10);
+         tmp[i * 2] = nir_masked_swizzle_amd(b, components[i], .swizzle_mask = broadcast_low16, .fetch_inactive = 1);
+         tmp[i * 2 + 1] =
+            nir_masked_swizzle_amd(b, components[i], .swizzle_mask = broadcast_high16, .fetch_inactive = 1);
+      }
+
+      num_comps *= 2;
+      memcpy(components, tmp, sizeof(components));
+      assert(num_comps == 16);
+   } else if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
+      assert(params->gfx_level < GFX12);
+      assert(num_comps == 16);
+      for (unsigned keep32 = 0; keep32 < ((params->wave_size == 64) ? 2 : 1); keep32++) {
+         nir_def *ballot = nir_imm_intN_t(b, keep32 ? UINT32_MAX : 0xffff0000ffffull, params->wave_size);
+         nir_def *keep = nir_inverse_ballot(b, 1, ballot);
+         for (unsigned i = 0; i < num_comps / 2; i++) {
+            components[i] = nir_bcsel(b, keep, components[i * 2], components[i * 2 + 1]);
+         }
+         num_comps /= 2;
+      }
+   } else if ((src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_B) ||
+              (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_A)) {
+      /* Transpose is a mess... */
+      for (unsigned x_mask = 1; x_mask < num_comps; x_mask *= 2) {
+         /* Use separate masks to always keep the masked_swizzle on the first source of v_cndmask. */
+         uint64_t mask = 0;
+         for (unsigned i = 0; i < 64; i += 2 * x_mask) {
+            mask |= BITFIELD64_MASK(x_mask) << i;
+         }
+
+         nir_def *even = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, mask, params->wave_size));
+         nir_def *odd = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, mask << x_mask, params->wave_size));
+
+         for (unsigned i = 0; i < num_comps; i += 2 * x_mask) {
+            for (unsigned j = 0; j < x_mask; j++) {
+               unsigned pos0 = i + j;
+               unsigned pos1 = pos0 + x_mask;
+               nir_def *comp0 = components[pos0];
+               nir_def *comp1 = components[pos1];
+
+               nir_def *comp0x =
+                  nir_masked_swizzle_amd(b, comp0, .swizzle_mask = 0x1f | (x_mask << 10), .fetch_inactive = 1);
+               nir_def *comp1x =
+                  nir_masked_swizzle_amd(b, comp1, .swizzle_mask = 0x1f | (x_mask << 10), .fetch_inactive = 1);
+
+               components[pos0] = nir_bcsel(b, even, comp0, comp1x);
+               components[pos1] = nir_bcsel(b, odd, comp1, comp0x);
+            }
+         }
+      }
+
+      assert(num_comps == 16 || params->gfx_level >= GFX12);
+
+      if (params->gfx_level >= GFX12) {
+         if (params->wave_size == 64) {
+            nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xf0f0f0f00f0f0f0f, params->wave_size));
+            for (unsigned i = 0; i < num_comps; i++) {
+               nir_def *comp = components[i];
+               nir_def *compx = nir_rotate(b, comp, nir_imm_int(b, 32));
+               compx = nir_masked_swizzle_amd(b, compx, .swizzle_mask = 0x1f | (0x4 << 10), .fetch_inactive = 1);
+               components[i] = nir_bcsel(b, cond, comp, compx);
+            }
+         }
+
+         nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xff0000ffff0000ff, params->wave_size));
+         for (unsigned i = 0; i < num_comps; i++) {
+            nir_def *comp = components[i];
+            nir_def *compx = nir_masked_swizzle_amd(b, comp, .swizzle_mask = 0x1f | (0x18 << 10), .fetch_inactive = 1);
+            components[i] = nir_bcsel(b, cond, comp, compx);
+         }
+      }
+   }
+
+   return nir_vec(b, components, num_comps);
+}
+
 bool
 radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
 {
@@ -461,6 +583,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_lev
             progress = true;
             break;
          }
+         case nir_intrinsic_cmat_transpose:
          case nir_intrinsic_cmat_convert: {
            nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
            nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
@@ -468,13 +591,38 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_lev
            struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);
            nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);
 
-           const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
-           const bool sat = nir_intrinsic_saturate(intr);
+           bool sat = false;
+           const bool transpose = intr->intrinsic == nir_intrinsic_cmat_transpose;
 
-           enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
-              dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
-           enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
-              src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
+           enum glsl_cmat_use dst_use = dst_desc.use;
+           enum glsl_cmat_use src_use = src_desc.use;
+
+           enum glsl_base_type dst_element_type = dst_desc.element_type;
+           enum glsl_base_type src_element_type = src_desc.element_type;
+
+           if (transpose) {
+              /* NV_cmat2 only supports acc -> b transpose, but we can handle any transpose except acc -> acc. */
+              if (dst_use == GLSL_CMAT_USE_A) {
+                 dst_use = GLSL_CMAT_USE_B;
+              } else if (dst_use == GLSL_CMAT_USE_B) {
+                 dst_use = GLSL_CMAT_USE_A;
+              } else if (dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
+                 if (src_use == GLSL_CMAT_USE_A)
+                    src_use = GLSL_CMAT_USE_B;
+                 else if (src_use == GLSL_CMAT_USE_B)
+                    src_use = GLSL_CMAT_USE_A;
+                 else
+                    unreachable("unsupported transpose");
+              }
+           } else {
+              sat = nir_intrinsic_saturate(intr);
+              nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
+
+              dst_element_type =
+                 glsl_apply_signedness_to_base_type(dst_element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
+              src_element_type =
+                 glsl_apply_signedness_to_base_type(src_element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
+           }
 
            unsigned dst_mul = radv_nir_cmat_length_mul(dst_desc, &params);
            unsigned src_mul = radv_nir_cmat_length_mul(src_desc, &params);
@@ -488,6 +636,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_lev
               src = nir_vec(&b, components, src->num_components / scale);
            }
 
+           src = convert_use(&b, src, src_use, dst_use, &params);
+
            nir_def *ret = convert_base_type(&b, src, src_element_type, dst_element_type, sat);
 
            if (dst_mul > src_mul) {
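
For reference, the x_mask loop in the A <-> B path of convert_use() has the shape of a butterfly transpose: each stage pairs elements whose component index and lane index differ in one bit (the lane XOR comes from the masked swizzle, the component XOR from pos0/pos1) and exchanges them with a per-lane nir_bcsel. On a plain 16x16 array the same index math transposes the tile, as in this standalone sketch (illustrative only, not taken from the patch; the array layout, names, and the fixed 16x16 size are assumptions):

#include <assert.h>
#include <stdint.h>
#include <string.h>

#define N 16

/* One stage per power of two: keep the element when bit s of the two indices
 * agrees, otherwise fetch the partner with both indices XORed by s.  After
 * all stages, position (c, l) holds the value that started at (l, c). */
static void butterfly_transpose(uint32_t m[N][N])
{
   for (unsigned s = 1; s < N; s *= 2) {
      uint32_t next[N][N];
      for (unsigned c = 0; c < N; c++) {
         for (unsigned l = 0; l < N; l++)
            next[c][l] = ((c ^ l) & s) ? m[c ^ s][l ^ s] : m[c][l];
      }
      memcpy(m, next, sizeof(next));
   }
}

int main(void)
{
   uint32_t m[N][N];
   for (unsigned c = 0; c < N; c++)
      for (unsigned l = 0; l < N; l++)
         m[c][l] = c * N + l;

   butterfly_transpose(m);

   /* Every element moved from (l, c) to (c, l). */
   for (unsigned c = 0; c < N; c++)
      for (unsigned l = 0; l < N; l++)
         assert(m[c][l] == l * N + c);
   return 0;
}

The wave64 and GFX12-specific fix-ups in the patch have the same shape, a per-lane nir_bcsel between the original value and a rotated or swizzled copy, just with different lane masks.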