radv: apply fneg/fabs modifiers to wmma

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34396>
This commit is contained in:
Georg Lehmann 2025-04-06 13:09:19 +02:00 committed by Marge Bot
parent 6d7e67d986
commit c3964e87f8
3 changed files with 113 additions and 0 deletions

View file

@ -71,6 +71,8 @@ bool radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_sta
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);
/* Folds fneg/fabs on the sources of cmat_muladd_amd intrinsics into the intrinsic's
 * neg_lo/neg_hi modifier indices. Returns the result of the underlying intrinsics pass.
 */
bool radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level);
bool radv_nir_lower_draw_id_to_zero(nir_shader *shader);
bool radv_nir_remap_color_attachment(nir_shader *shader, const struct radv_graphics_state_key *gfx_state);

View file

@ -543,3 +543,111 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
return nir_progress(progress, func->impl, 0);
}
/* Strip one ALU op from a strided set of scalars, all or nothing.
 *
 * Looks at comp[0], comp[stride], ..., comp[(num_comps - 1) * stride]. If every one of
 * them is produced by an ALU instruction whose opcode is alu_op, each of those entries is
 * replaced by source 0 of its ALU instruction and true is returned. If any entry does not
 * match, nothing is modified and false is returned.
 */
static bool
apply_component_mods(nir_scalar *comp, unsigned num_comps, unsigned stride, nir_op alu_op)
{
   /* Pass 1: verify every selected entry before touching any of them. */
   for (unsigned k = 0; k != num_comps; ++k) {
      const nir_scalar cur = comp[k * stride];
      const bool matches = nir_scalar_is_alu(cur) && nir_scalar_alu_op(cur) == alu_op;

      if (!matches)
         return false;
   }

   /* Pass 2: step through the ALU op to its first source. */
   for (unsigned k = 0; k != num_comps; ++k) {
      nir_scalar *slot = &comp[k * stride];

      *slot = nir_scalar_chase_alu_src(*slot, 0);
   }

   return true;
}
/* Apply neg_lo/neg_hi modifiers to A/B and neg/abs to C.
 *
 * Tries to fold fneg/fabs instructions feeding source src_idx of a cmat_muladd_amd
 * intrinsic into the intrinsic's NEG_LO_AMD / NEG_HI_AMD indices (one bit per source):
 *
 *  - src_idx 0/1 (A/B): even components are checked as one group and odd components as
 *    another. An fneg on every even component toggles the source's neg_lo bit, an fneg on
 *    every odd component toggles its neg_hi bit.
 *  - src_idx 2 (C): the neg_lo bit is used as "neg" and the neg_hi bit as "abs". A folded
 *    fneg toggles neg only while abs is not already set; a folded fabs sets abs.
 *
 * Returns true if the intrinsic was changed, in which case the source has been rewritten
 * to the value(s) found behind the folded instructions.
 */
static bool
opt_cmat_modifiers(nir_builder *b, nir_intrinsic_instr *intrin, enum amd_gfx_level gfx_level, unsigned src_idx)
{
   /* For a 16-bit C before GFX12 only every second component is inspected; the ones in
    * between are never read here and become undef if a new vector has to be built.
    */
   unsigned length_mul = src_idx == 2 && intrin->src[2].ssa->bit_size == 16 && gfx_level < GFX12 ? 2 : 1;

   nir_scalar comp[NIR_MAX_VEC_COMPONENTS] = {0};
   nir_def *src = intrin->src[src_idx].ssa;

   for (unsigned i = 0; i < src->num_components; i += length_mul)
      comp[i] = nir_scalar_resolved(src, i);

   unsigned neg_lo = nir_intrinsic_neg_lo_amd(intrin);
   unsigned neg_hi = nir_intrinsic_neg_hi_amd(intrin);

   bool progress = false;

   if (src_idx == 2) {
      unsigned num_comp = src->num_components / length_mul;

      if (apply_component_mods(comp, num_comp, length_mul, nir_op_fneg)) {
         /* With abs already set, an inner negate has no effect on the result, so neg is
          * only toggled when abs is clear.
          */
         neg_lo ^= (~neg_hi) & BITFIELD_BIT(src_idx);
         progress = true;
      }

      if (apply_component_mods(comp, num_comp, length_mul, nir_op_fabs)) {
         neg_hi |= BITFIELD_BIT(src_idx);
         progress = true;
      }
   } else {
      /* Even and odd components are handled as two independent groups. */
      unsigned num_comp = src->num_components / 2;

      if (apply_component_mods(comp, num_comp, 2, nir_op_fneg)) {
         neg_lo ^= BITFIELD_BIT(src_idx);
         progress = true;
      }

      if (apply_component_mods(comp + 1, num_comp, 2, nir_op_fneg)) {
         neg_hi ^= BITFIELD_BIT(src_idx);
         progress = true;
      }
   }

   if (!progress)
      return false;

   nir_intrinsic_set_neg_lo_amd(intrin, neg_lo);
   nir_intrinsic_set_neg_hi_amd(intrin, neg_hi);

   /* Avoid creating a new vec if we don't have to: the def behind component 0 can be used
    * directly when every inspected component is the same-numbered component of that def.
    * The def must also be exactly as wide as the original source; a wider or narrower def
    * would change the number of components the intrinsic source has.
    */
   nir_def *new_src = comp[0].def;
   bool reuse_def = new_src->num_components == src->num_components;

   for (unsigned i = 0; reuse_def && i < src->num_components; i += length_mul)
      reuse_def = comp[i].def == new_src && comp[i].comp == i;

   if (!reuse_def) {
      b->cursor = nir_before_instr(&intrin->instr);

      if (length_mul > 1) {
         /* Fill the components that were skipped above so nir_vec_scalars gets a fully
          * populated array.
          */
         nir_scalar undef = nir_get_scalar(nir_undef(b, 1, src->bit_size), 0);

         for (unsigned i = 0; i < src->num_components; i += length_mul) {
            for (unsigned j = 1; j < length_mul; j++)
               comp[i + j] = undef;
         }
      }

      new_src = nir_vec_scalars(b, comp, src->num_components);
   }

   nir_src_rewrite(&intrin->src[src_idx], new_src);
   return true;
}
/* Per-intrinsic callback for the cooperative matrix modifier optimization.
 *
 * data points to the enum amd_gfx_level of the target. Only cmat_muladd_amd intrinsics
 * whose first source is not 8-bit are processed; for those, modifier folding is attempted
 * on each of the three sources. Returns true if any source was changed.
 */
static bool
opt_cmat(nir_builder *b, nir_intrinsic_instr *intrin, void *data)
{
   if (intrin->intrinsic != nir_intrinsic_cmat_muladd_amd)
      return false;

   /* Nothing is folded when the first source is 8-bit. */
   if (intrin->src[0].ssa->bit_size == 8)
      return false;

   const enum amd_gfx_level *level = data;
   bool changed = false;

   /* Every source is visited even after one has already reported progress. */
   for (unsigned src_idx = 0; src_idx != 3; ++src_idx) {
      if (opt_cmat_modifiers(b, intrin, *level, src_idx))
         changed = true;
   }

   return changed;
}
/* Runs opt_cmat over every intrinsic in the shader, passing the target gfx_level as
 * callback data, and returns the result of the intrinsics pass.
 */
bool
radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level)
{
   /* gfx_level is a by-value parameter, so its address stays valid for the whole call. */
   void *cb_data = &gfx_level;
   const bool progress = nir_shader_intrinsics_pass(shader, opt_cmat, nir_metadata_control_flow, cb_data);

   return progress;
}

View file

@ -539,6 +539,9 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_graphics_stat
stage->nir, io_to_mem || lowered_ngg || stage->stage == MESA_SHADER_COMPUTE || stage->stage == MESA_SHADER_TASK,
gfx_level >= GFX8);
if (stage->nir->info.cs.has_cooperative_matrix)
NIR_PASS(_, stage->nir, radv_nir_opt_cooperative_matrix, gfx_level);
NIR_PASS(_, stage->nir, nir_lower_fp16_casts, nir_lower_fp16_split_fp64);
if (ac_nir_might_lower_bit_size(stage->nir)) {