mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-03 20:48:08 +02:00
radv: apply fneg/fabs modifiers to wmma
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34396>
This commit is contained in:
parent
6d7e67d986
commit
c3964e87f8
3 changed files with 113 additions and 0 deletions
|
|
@ -71,6 +71,8 @@ bool radv_nir_lower_io_to_mem(struct radv_device *device, struct radv_shader_sta
|
|||
|
||||
bool radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level, unsigned wave_size);
|
||||
|
||||
/* Folds fneg/fabs on the sources of cmat_muladd_amd intrinsics into the
 * intrinsic's neg_lo/neg_hi modifiers. Returns true if the shader changed.
 */
bool radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level);
|
||||
|
||||
bool radv_nir_lower_draw_id_to_zero(nir_shader *shader);
|
||||
|
||||
bool radv_nir_remap_color_attachment(nir_shader *shader, const struct radv_graphics_state_key *gfx_state);
|
||||
|
|
|
|||
|
|
@ -543,3 +543,111 @@ radv_nir_lower_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_lev
|
|||
|
||||
return nir_progress(progress, func->impl, 0);
|
||||
}
|
||||
|
||||
static bool
|
||||
apply_component_mods(nir_scalar *comp, unsigned num_comps, unsigned stride, nir_op alu_op)
|
||||
{
|
||||
for (unsigned i = 0; i < num_comps; i++) {
|
||||
nir_scalar s = comp[i * stride];
|
||||
if (!nir_scalar_is_alu(s) || nir_scalar_alu_op(s) != alu_op)
|
||||
return false;
|
||||
}
|
||||
|
||||
for (unsigned i = 0; i < num_comps; i++)
|
||||
comp[i * stride] = nir_scalar_chase_alu_src(comp[i * stride], 0);
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Apply neg_lo/neg_hi modifiers to A/B and neg/abs to C. */
|
||||
static bool
|
||||
opt_cmat_modifiers(nir_builder *b, nir_intrinsic_instr *intrin, enum amd_gfx_level gfx_level, unsigned src_idx)
|
||||
{
|
||||
unsigned length_mul = src_idx == 2 && intrin->src[2].ssa->bit_size == 16 && gfx_level < GFX12 ? 2 : 1;
|
||||
nir_scalar comp[NIR_MAX_VEC_COMPONENTS] = {0};
|
||||
nir_def *src = intrin->src[src_idx].ssa;
|
||||
|
||||
for (unsigned i = 0; i < src->num_components; i += length_mul)
|
||||
comp[i] = nir_scalar_resolved(src, i);
|
||||
|
||||
unsigned neg_lo = nir_intrinsic_neg_lo_amd(intrin);
|
||||
unsigned neg_hi = nir_intrinsic_neg_hi_amd(intrin);
|
||||
|
||||
bool progress = false;
|
||||
if (src_idx == 2) {
|
||||
unsigned num_comp = src->num_components / length_mul;
|
||||
if (apply_component_mods(comp, num_comp, length_mul, nir_op_fneg)) {
|
||||
neg_lo ^= (~neg_hi) & BITFIELD_BIT(src_idx);
|
||||
progress = true;
|
||||
}
|
||||
if (apply_component_mods(comp, num_comp, length_mul, nir_op_fabs)) {
|
||||
neg_hi |= BITFIELD_BIT(src_idx);
|
||||
progress = true;
|
||||
}
|
||||
} else {
|
||||
unsigned num_comp = src->num_components / 2;
|
||||
if (apply_component_mods(comp, num_comp, 2, nir_op_fneg)) {
|
||||
neg_lo ^= BITFIELD_BIT(src_idx);
|
||||
progress = true;
|
||||
}
|
||||
if (apply_component_mods(comp + 1, num_comp, 2, nir_op_fneg)) {
|
||||
neg_hi ^= BITFIELD_BIT(src_idx);
|
||||
progress = true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!progress)
|
||||
return false;
|
||||
|
||||
nir_intrinsic_set_neg_lo_amd(intrin, neg_lo);
|
||||
nir_intrinsic_set_neg_hi_amd(intrin, neg_hi);
|
||||
|
||||
/* Avoid creating a new vec if we don't have to. */
|
||||
nir_def *new_src = comp[0].def;
|
||||
for (unsigned i = 0; i < src->num_components; i += length_mul) {
|
||||
if (comp[i].def != new_src || comp[i].comp != i) {
|
||||
new_src = NULL;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if (!new_src) {
|
||||
b->cursor = nir_before_instr(&intrin->instr);
|
||||
if (length_mul > 1) {
|
||||
nir_scalar undef = nir_get_scalar(nir_undef(b, 1, src->bit_size), 0);
|
||||
for (unsigned i = 0; i < src->num_components; i += length_mul) {
|
||||
for (unsigned j = 1; j < length_mul; j++)
|
||||
comp[i + j] = undef;
|
||||
}
|
||||
}
|
||||
|
||||
new_src = nir_vec_scalars(b, comp, src->num_components);
|
||||
}
|
||||
|
||||
nir_src_rewrite(&intrin->src[src_idx], new_src);
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool
|
||||
opt_cmat(nir_builder *b, nir_intrinsic_instr *intrin, void *data)
|
||||
{
|
||||
enum amd_gfx_level gfx_level = *(enum amd_gfx_level *)data;
|
||||
|
||||
if (intrin->intrinsic != nir_intrinsic_cmat_muladd_amd)
|
||||
return false;
|
||||
|
||||
bool progress = false;
|
||||
|
||||
if (intrin->src[0].ssa->bit_size != 8) {
|
||||
for (unsigned i = 0; i < 3; i++)
|
||||
progress |= opt_cmat_modifiers(b, intrin, gfx_level, i);
|
||||
}
|
||||
|
||||
return progress;
|
||||
}
|
||||
|
||||
bool
|
||||
radv_nir_opt_cooperative_matrix(nir_shader *shader, enum amd_gfx_level gfx_level)
|
||||
{
|
||||
return nir_shader_intrinsics_pass(shader, opt_cmat, nir_metadata_control_flow, &gfx_level);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -539,6 +539,9 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_graphics_stat
|
|||
stage->nir, io_to_mem || lowered_ngg || stage->stage == MESA_SHADER_COMPUTE || stage->stage == MESA_SHADER_TASK,
|
||||
gfx_level >= GFX8);
|
||||
|
||||
if (stage->nir->info.cs.has_cooperative_matrix)
|
||||
NIR_PASS(_, stage->nir, radv_nir_opt_cooperative_matrix, gfx_level);
|
||||
|
||||
NIR_PASS(_, stage->nir, nir_lower_fp16_casts, nir_lower_fp16_split_fp64);
|
||||
|
||||
if (ac_nir_might_lower_bit_size(stage->nir)) {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue