radv/nir/lower_cmat: implement use conversions/transpose

This could potentially be improved using packed 32-bit subgroup ops,
but what we actually care about (gfx12 ACC -> B) is free.
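
A rough sketch of the packed idea, for illustration only (nir_pack_32_2x16,
nir_unpack_32_2x16 and nir_vec2 are existing NIR builder helpers, but the
loop context and the swizzle value here are hypothetical):

   /* Hypothetical: move two 16-bit components per lane op by packing
    * them into one 32-bit value around the swizzle. */
   nir_def *packed = nir_pack_32_2x16(b, nir_vec2(b, components[i], components[i + 1]));
   packed = nir_masked_swizzle_amd(b, packed, .swizzle_mask = swizzle, .fetch_inactive = 1);
   nir_def *unpacked = nir_unpack_32_2x16(b, packed);
   components[i] = nir_channel(b, unpacked, 0);
   components[i + 1] = nir_channel(b, unpacked, 1);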

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34793>
Author: Georg Lehmann, 2025-05-02 14:01:54 +02:00 (committed by Marge Bot)
Parent: bdd2c7b9f2
Commit: 249ccc6b4c


@@ -204,6 +204,128 @@ convert_base_type(nir_builder *b, nir_def *src, enum glsl_base_type src_type, en
   return nir_build_alu1(b, op, src);
}

static nir_def *
convert_use(nir_builder *b, nir_def *src, enum glsl_cmat_use src_use, enum glsl_cmat_use dst_use,
            const lower_cmat_params *params)
{
   if (src_use == dst_use)
      return src;
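
   /* On GFX12 the B and accumulator layouts already match, so these
    * conversions are free (see the commit message). */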
   if (params->gfx_level >= GFX12) {
      if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR)
         return src;
      if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B)
         return src;
   }
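
   /* There is no direct A <-> ACC path below; lower it as two hops
    * through B. */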
   if (src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
      src = convert_use(b, src, GLSL_CMAT_USE_A, GLSL_CMAT_USE_B, params);
      return convert_use(b, src, GLSL_CMAT_USE_B, GLSL_CMAT_USE_ACCUMULATOR, params);
   } else if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_A) {
      src = convert_use(b, src, GLSL_CMAT_USE_ACCUMULATOR, GLSL_CMAT_USE_B, params);
      return convert_use(b, src, GLSL_CMAT_USE_B, GLSL_CMAT_USE_A, params);
   }

   nir_def *components[NIR_MAX_VEC_COMPONENTS] = {NULL};
   unsigned num_comps = src->num_components;
   for (unsigned i = 0; i < num_comps; i++)
      components[i] = nir_channel(b, src, i);

   if (src_use == GLSL_CMAT_USE_ACCUMULATOR && dst_use == GLSL_CMAT_USE_B) {
      assert(params->gfx_level < GFX12);

      nir_def *tmp[NIR_MAX_VEC_COMPONENTS];
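      /* In wave64 each accumulator component spans all 64 lanes. Duplicate
       * each component's low and high 32 lanes across both wave halves
       * (rotate by 32 plus select), doubling the component count so the
       * 32-lane swizzles below work as in wave32. */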
      if (params->wave_size == 64) {
         nir_def *low_lanes = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, UINT32_MAX, 64));
         for (int i = 0; i < num_comps; i++) {
            nir_def *comp = components[i];
            nir_def *half_swap = nir_rotate(b, comp, nir_imm_int(b, 32), .cluster_size = 64);
            tmp[i * 2] = nir_bcsel(b, low_lanes, comp, half_swap);
            tmp[i * 2 + 1] = nir_bcsel(b, low_lanes, half_swap, comp);
         }
         num_comps *= 2;
         memcpy(components, tmp, sizeof(components));
      }
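
      /* B replicates its data per 16 lanes: duplicate every component,
       * broadcasting the low 16 lanes of each 32-lane half into one copy
       * and the high 16 lanes into the other. */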
      for (int i = 0; i < num_comps; i++) {
         unsigned broadcast_low16 = 0xf;
         unsigned broadcast_high16 = 0xf | (0x10 << 10);
         tmp[i * 2] = nir_masked_swizzle_amd(b, components[i], .swizzle_mask = broadcast_low16, .fetch_inactive = 1);
         tmp[i * 2 + 1] =
            nir_masked_swizzle_amd(b, components[i], .swizzle_mask = broadcast_high16, .fetch_inactive = 1);
      }
      num_comps *= 2;
      memcpy(components, tmp, sizeof(components));

      assert(num_comps == 16);
   } else if (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
      assert(params->gfx_level < GFX12);
      assert(num_comps == 16);
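
      /* Inverse of the ACC -> B doubling above: merge each pair of
       * components, keeping the even copy in the lanes it was broadcast
       * from and the odd copy elsewhere. */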
      for (unsigned keep32 = 0; keep32 < ((params->wave_size == 64) ? 2 : 1); keep32++) {
         nir_def *ballot = nir_imm_intN_t(b, keep32 ? UINT32_MAX : 0xffff0000ffffull, params->wave_size);
         nir_def *keep = nir_inverse_ballot(b, 1, ballot);
         for (unsigned i = 0; i < num_comps / 2; i++) {
            components[i] = nir_bcsel(b, keep, components[i * 2], components[i * 2 + 1]);
         }
         num_comps /= 2;
      }
   } else if ((src_use == GLSL_CMAT_USE_A && dst_use == GLSL_CMAT_USE_B) ||
              (src_use == GLSL_CMAT_USE_B && dst_use == GLSL_CMAT_USE_A)) {
      /* Transpose is a mess... */
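      /* Butterfly network: for each power-of-two stride x_mask, swap the
       * off-diagonal quadrants of 2x2 blocks of (lane, component) indices
       * via an xor lane swizzle of distance x_mask plus a select. */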
      for (unsigned x_mask = 1; x_mask < num_comps; x_mask *= 2) {
         /* Use separate masks to always keep the masked_swizzle on the first source of v_cndmask. */
         uint64_t mask = 0;
         for (unsigned i = 0; i < 64; i += 2 * x_mask) {
            mask |= BITFIELD64_MASK(x_mask) << i;
         }
         nir_def *even = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, mask, params->wave_size));
         nir_def *odd = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, mask << x_mask, params->wave_size));

         for (unsigned i = 0; i < num_comps; i += 2 * x_mask) {
            for (unsigned j = 0; j < x_mask; j++) {
               unsigned pos0 = i + j;
               unsigned pos1 = pos0 + x_mask;
               nir_def *comp0 = components[pos0];
               nir_def *comp1 = components[pos1];
               nir_def *comp0x =
                  nir_masked_swizzle_amd(b, comp0, .swizzle_mask = 0x1f | (x_mask << 10), .fetch_inactive = 1);
               nir_def *comp1x =
                  nir_masked_swizzle_amd(b, comp1, .swizzle_mask = 0x1f | (x_mask << 10), .fetch_inactive = 1);

               components[pos0] = nir_bcsel(b, even, comp0, comp1x);
               components[pos1] = nir_bcsel(b, odd, comp1, comp0x);
            }
         }
      }

      assert(num_comps == 16 || params->gfx_level >= GFX12);
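
      /* GFX12 lays A/B out differently, so finish with extra cross-lane
       * swaps: in wave64 a rotate by 32 combined with an xor-4 swizzle,
       * then an xor-0x18 swizzle, each masked by a lane select. */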
      if (params->gfx_level >= GFX12) {
         if (params->wave_size == 64) {
            nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xf0f0f0f00f0f0f0f, params->wave_size));
            for (unsigned i = 0; i < num_comps; i++) {
               nir_def *comp = components[i];
               nir_def *compx = nir_rotate(b, comp, nir_imm_int(b, 32));
               compx = nir_masked_swizzle_amd(b, compx, .swizzle_mask = 0x1f | (0x4 << 10), .fetch_inactive = 1);
               components[i] = nir_bcsel(b, cond, comp, compx);
            }
         }

         nir_def *cond = nir_inverse_ballot(b, 1, nir_imm_intN_t(b, 0xff0000ffff0000ff, params->wave_size));
         for (unsigned i = 0; i < num_comps; i++) {
            nir_def *comp = components[i];
            nir_def *compx = nir_masked_swizzle_amd(b, comp, .swizzle_mask = 0x1f | (0x18 << 10), .fetch_inactive = 1);
            components[i] = nir_bcsel(b, cond, comp, compx);
         }
      }
   }

   return nir_vec(b, components, num_comps);
}

bool
radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size)
{
@@ -461,6 +583,7 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_lev
            progress = true;
            break;
         }
+        case nir_intrinsic_cmat_transpose:
         case nir_intrinsic_cmat_convert: {
            nir_deref_instr *dst_deref = nir_instr_as_deref(intr->src[0].ssa->parent_instr);
            nir_deref_instr *src_deref = nir_instr_as_deref(intr->src[1].ssa->parent_instr);
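
The hunk below is where the transpose gets folded into the conversion path:
because the A and B layouts are transposes of each other, a cmat_transpose is
lowered by flipping one side's use (the destination's A <-> B, or the source's
when the destination is an accumulator) and then running the normal use
conversion.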
@@ -468,13 +591,38 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_lev
         struct glsl_cmat_description src_desc = *glsl_get_cmat_description(src_deref->type);

         nir_def *src = radv_nir_load_cmat(&b, &params, intr->src[1].ssa);

-        const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
-        const bool sat = nir_intrinsic_saturate(intr);
+        bool sat = false;
+        const bool transpose = intr->intrinsic == nir_intrinsic_cmat_transpose;

-        enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
-           dst_desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
-        enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
-           src_desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
+        enum glsl_cmat_use dst_use = dst_desc.use;
+        enum glsl_cmat_use src_use = src_desc.use;
+        enum glsl_base_type dst_element_type = dst_desc.element_type;
+        enum glsl_base_type src_element_type = src_desc.element_type;
+        if (transpose) {
+           /* NV_cmat2 only supports acc -> b transpose, but we can handle any transpose except acc -> acc. */
+           if (dst_use == GLSL_CMAT_USE_A) {
+              dst_use = GLSL_CMAT_USE_B;
+           } else if (dst_use == GLSL_CMAT_USE_B) {
+              dst_use = GLSL_CMAT_USE_A;
+           } else if (dst_use == GLSL_CMAT_USE_ACCUMULATOR) {
+              if (src_use == GLSL_CMAT_USE_A)
+                 src_use = GLSL_CMAT_USE_B;
+              else if (src_use == GLSL_CMAT_USE_B)
+                 src_use = GLSL_CMAT_USE_A;
+              else
+                 unreachable("unsupported transpose");
+           }
+        } else {
+           sat = nir_intrinsic_saturate(intr);
+           nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intr);
+           dst_element_type =
+              glsl_apply_signedness_to_base_type(dst_element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
+           src_element_type =
+              glsl_apply_signedness_to_base_type(src_element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
+        }

         unsigned dst_mul = radv_nir_cmat_length_mul(dst_desc, &params);
         unsigned src_mul = radv_nir_cmat_length_mul(src_desc, &params);
@@ -488,6 +636,8 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_lev
            src = nir_vec(&b, components, src->num_components / scale);
         }

+        src = convert_use(&b, src, src_use, dst_use, &params);
         nir_def *ret = convert_base_type(&b, src, src_element_type, dst_element_type, sat);
         if (dst_mul > src_mul) {
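
For reference, the swizzle_mask values used throughout (0xf, 0xf | (0x10 << 10),
0x1f | (x_mask << 10), ...) follow the ds_swizzle bit-mode encoding as we read
it: bits [4:0] are an AND mask, bits [9:5] an OR mask, and bits [14:10] an XOR
mask, applied to the lane id within each 32-lane group. A minimal sketch of the
resulting lane mapping (the helper is illustrative, not part of the patch):

   /* Illustrative only: which lane a nir_masked_swizzle_amd with the given
    * swizzle_mask reads from. Encoding per the ds_swizzle bit mode
    * (assumption); swizzles never cross 32-lane halves. */
   static unsigned masked_swizzle_src_lane(unsigned lane, unsigned swizzle_mask)
   {
      const unsigned and_mask = swizzle_mask & 0x1f;
      const unsigned or_mask = (swizzle_mask >> 5) & 0x1f;
      const unsigned xor_mask = (swizzle_mask >> 10) & 0x1f;
      return (lane & ~31u) | ((((lane & 31) & and_mask) | or_mask) ^ xor_mask);
   }

For example, swizzle_mask = 0xf sends every lane to lane & 0xf (the
broadcast_low16 case), and 0xf | (0x10 << 10) additionally flips bit 4, so all
lanes fetch from the high 16 lanes of their half (broadcast_high16).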