nak/nir: Re-materialize load_const instructions in use blocks

This is useful both for correctness (to ensure that things we think are
constant stay constant) and it improves performance a bit by reducing
register pressure and avoiding spilling.

Pipeline-db stats:

    CodeSize: 29665072 -> 29437344 (-0.77%); split: -0.92%, +0.16%
    Number of GPRs: 157124 -> 156082 (-0.66%)
    SLM Size: 148900 -> 146436 (-1.65%)
    Static cycle count: 6840286 -> 6805711 (-0.51%); split: -0.98%, +0.47%
    Spills to memory: 177779 -> 173337 (-2.50%)
    Fills from memory: 177779 -> 173337 (-2.50%)
    Spills to reg: 17692 -> 16731 (-5.43%)
    Fills from reg: 12013 -> 11897 (-0.97%)
    Max warps/SM: 309128 -> 309456 (+0.11%)

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33771>
This commit is contained in:
Faith Ekstrand 2025-02-26 16:00:42 -06:00 committed by Marge Bot
parent 8de37b142e
commit 8fffcdb18b
4 changed files with 113 additions and 0 deletions

View file

@ -33,6 +33,7 @@ libnak_c_files = files(
'nak_nir_lower_tex.c',
'nak_nir_lower_vtg_io.c',
'nak_nir_mark_lcssa_invariants.c',
'nak_nir_rematerialize_load_const.c',
'nak_nir_split_64bit_conversions.c',
)

View file

@ -1058,6 +1058,15 @@ nak_postprocess_nir(nir_shader *nir,
if (nak->sm < 70)
OPT(nir, nak_nir_split_64bit_conversions);
/* Re-materialize load_const instructions in the blocks that use them.
* This is both a register pressure optimization and a ensures correctness
* in the presence of all of the control flow modifications we're about to
* do. Without this, we can't rely on anything to be constant in NIR to
* NAK translation.
*/
if (OPT(nir, nak_nir_rematerialize_load_const))
OPT(nir, nir_opt_dce);
bool lcssa_progress = nir_convert_to_lcssa(nir, false, false);
if (nak->sm >= 75) {

View file

@ -0,0 +1,102 @@
/*
* Copyright © 2025 Collabora, Ltd.
* SPDX-License-Identifier: MIT
*/
#include "nak_private.h"
#include "nir_builder.h"
#include "util/hash_table.h"
struct remat_ctx {
struct hash_table *remap;
nir_block *block;
nir_builder b;
};
static bool
rematerialize_load_const(nir_src *src, void *_ctx)
{
struct remat_ctx *ctx = _ctx;
if (!nir_src_is_const(*src))
return true;
struct hash_entry *entry = _mesa_hash_table_search(ctx->remap, src->ssa);
if (entry != NULL) {
nir_src_rewrite(src, entry->data);
return true;
}
nir_load_const_instr *old_lc =
nir_instr_as_load_const(src->ssa->parent_instr);
nir_load_const_instr *new_lc =
nir_instr_as_load_const(nir_instr_clone(ctx->b.shader, &old_lc->instr));
nir_builder_instr_insert(&ctx->b, &new_lc->instr);
_mesa_hash_table_insert(ctx->remap, &old_lc->def, &new_lc->def);
nir_src_rewrite(src, &new_lc->def);
return true;
}
static bool
rematerialize_load_const_impl(nir_function_impl *impl)
{
bool progress = false;
struct remat_ctx ctx = {
.remap = _mesa_pointer_hash_table_create(NULL),
.b = nir_builder_create(impl),
};
nir_foreach_block(block, impl) {
_mesa_hash_table_clear(ctx.remap, NULL);
ctx.block = block;
nir_foreach_instr(instr, block) {
if (instr->type == nir_instr_type_phi)
continue;
ctx.b.cursor = nir_before_instr(instr);
nir_foreach_src(instr, rematerialize_load_const, &ctx);
}
ctx.b.cursor = nir_after_block_before_jump(block);
for (unsigned i = 0; i < ARRAY_SIZE(block->successors); i++) {
nir_block *succ = block->successors[i];
if (succ == NULL)
continue;
nir_foreach_instr(instr, succ) {
if (instr->type != nir_instr_type_phi)
break;
nir_phi_instr *phi = nir_instr_as_phi(instr);
nir_foreach_phi_src(src, phi) {
if (src->pred != block)
continue;
rematerialize_load_const(&src->src, &ctx);
}
}
}
}
_mesa_hash_table_destroy(ctx.remap, NULL);
return nir_progress(progress, impl, nir_metadata_control_flow |
nir_metadata_divergence);
}
bool
nak_nir_rematerialize_load_const(nir_shader *nir)
{
bool progress = false;
nir_foreach_function_impl(impl, nir)
progress |= rematerialize_load_const_impl(impl);
return progress;
}

View file

@ -249,6 +249,7 @@ enum nak_fs_out {
#define NAK_FS_OUT_COLOR(n) (NAK_FS_OUT_COLOR0 + (n) * 16)
bool nak_nir_rematerialize_load_const(nir_shader *nir);
bool nak_nir_mark_lcssa_invariants(nir_shader *nir);
bool nak_nir_split_64bit_conversions(nir_shader *nir);
bool nak_nir_lower_non_uniform_ldcx(nir_shader *nir);