mesa/src/intel/compiler/jay/jay_nir.c
Kenneth Graunke 79fda8977e jay: Drop render target stores with unconditional discards
Surprisingly, this actually appears to come up.  Two Baldur's Gate 3
shaders optimized away to have unconditional "demote" in a shader with
no other side-effects, meaning no writes occur and we can eliminate the
entire program.  One of the shaders still did a fair amount of math to
produce color values that were never used.

We introduce a pass to detect store_render_target_intel intrinsics
where discard == true and eliminate them.  We then DCE and see if we
eliminated the entire program other than "demote" or "terminate" and
drop those too.  We then add back a Null RT store if needed.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41872>
2026-06-03 15:23:20 +00:00

505 lines
17 KiB
C

/*
* Copyright 2026 Intel Corporation
* SPDX-License-Identifier: MIT
*/
#include "compiler/brw/brw_eu.h"
#include "compiler/brw/brw_eu_defines.h"
#include "compiler/brw/brw_nir.h"
#include "compiler/brw/brw_private.h"
#include "compiler/intel_nir.h"
#include "jay_private.h"
#include "nir.h"
#include "nir_builder.h"
#include "nir_intrinsics.h"
#include "shader_enums.h"
#define JAY_NIR_SNAPSHOT(name) BRW_NIR_SNAPSHOT(name)
#define JAY_NIR_PASS(pass, ...) BRW_NIR_PASS(pass, ##__VA_ARGS__)
/*
* Jay-to-NIR relies on a careful indexing of defs: every 32-bit word has
* its own index. Vectors/64-bit use contiguous indices. We therefore run a
* modified version of nir_index_ssa_defs right before translating NIR->Jay.
*/
static bool
index_ssa_def_cb(nir_def *def, void *state)
{
unsigned *index = (unsigned *) state;
def->index = *index;
*index += DIV_ROUND_UP(def->num_components * MAX2(def->bit_size, 32), 32);
return true;
}
static void
nj_index_ssa_defs(nir_shader *nir)
{
nir_foreach_function_impl(impl, nir) {
/* The zero index means null in Jay, so start SSA indices at 1 */
unsigned index = 1;
nir_foreach_block_unstructured(block, impl) {
nir_foreach_instr(instr, block)
nir_foreach_def(instr, index_ssa_def_cb, &index);
}
impl->ssa_alloc = index;
}
}
static bool
lower_helper_invocation(nir_builder *b, nir_intrinsic_instr *intr, void *_)
{
if (intr->intrinsic != nir_intrinsic_load_helper_invocation)
return false;
/* TODO: Is this right for multisampling? */
b->cursor = nir_before_instr(&intr->instr);
nir_def *active =
nir_inot(b, nir_inverse_ballot(b, nir_load_dispatch_mask_intel(b)));
nir_def_replace(&intr->def, active);
return true;
}
static bool
lower_frag_coord(nir_builder *b, nir_intrinsic_instr *intr, void *simd_)
{
if (intr->intrinsic != nir_intrinsic_load_frag_coord &&
intr->intrinsic != nir_intrinsic_load_pixel_coord)
return false;
b->cursor = nir_before_instr(&intr->instr);
nir_def *c = nir_unpack_32_2x16(b, nir_load_pixel_coord_intel(b));
if (intr->intrinsic == nir_intrinsic_load_frag_coord) {
c = nir_vec4(b, nir_u2f32(b, nir_channel(b, c, 0)),
nir_u2f32(b, nir_channel(b, c, 1)), nir_load_frag_coord_z(b),
nir_frcp(b, nir_load_frag_coord_w_rcp(b)));
}
nir_def_replace(&intr->def, c);
return true;
}
static bool
jay_nir_lower_simd(nir_builder *b, nir_intrinsic_instr *intr, void *simd_)
{
b->cursor = nir_after_instr(&intr->instr);
unsigned *simd_width = simd_;
/* mask & -mask isolates the lowest set bit in the mask. */
if (intr->intrinsic == nir_intrinsic_elect) {
nir_def *mask = nir_ballot(b, 1, *simd_width, nir_imm_true(b));
mask = nir_iand(b, mask, nir_ineg(b, mask));
nir_def_replace(&intr->def, nir_inverse_ballot(b, mask));
return true;
}
/* Ballots must match the SIMD size */
if (intr->intrinsic == nir_intrinsic_ballot ||
intr->intrinsic == nir_intrinsic_ballot_relaxed) {
unsigned old_bitsize = intr->def.bit_size;
intr->def.bit_size = *simd_width;
nir_def *u2uN = nir_u2uN(b, &intr->def, old_bitsize);
nir_def_rewrite_uses_after(&intr->def, u2uN);
return true;
}
/* Just a constant */
if (intr->intrinsic == nir_intrinsic_load_simd_width_intel) {
nir_def_replace(&intr->def, nir_imm_int(b, *simd_width));
return true;
}
/* Note: we don't treat read_invocation specially because there's little
* benefit but doing so would require expensive uniformizing in some cases.
*/
if (intr->intrinsic != nir_intrinsic_shuffle &&
intr->intrinsic != nir_intrinsic_read_invocation)
return false;
nir_def *data = intr->src[0].ssa;
assert(data->num_components == 1 && data->bit_size <= 32 && "scalarized");
nir_def *offset_B = nir_imul_imm(b, intr->src[1].ssa, 4);
nir_def_replace(&intr->def, nir_shuffle_intel(b, 1, data, offset_B));
return true;
}
struct frag_out_ctx {
nir_def *colour[8], *depth, *stencil, *sample_mask;
};
static bool
collect_fragment_output(nir_builder *b, nir_intrinsic_instr *intr, void *ctx_)
{
struct frag_out_ctx *ctx = ctx_;
if (intr->intrinsic != nir_intrinsic_store_output)
return false;
unsigned wrmask = nir_intrinsic_write_mask(intr);
assert(nir_intrinsic_component(intr) == 0 && "component should be lowered");
assert(util_is_power_of_two_nonzero(wrmask + 1) &&
"complex writemasks should be lowered");
/* TODO: Optimize with write mask? */
gl_frag_result loc = nir_intrinsic_io_semantics(intr).location;
assert(!nir_intrinsic_io_semantics(intr).dual_source_blend_index && "todo");
assert(loc != FRAG_RESULT_DUAL_SRC_BLEND && "todo");
nir_def **out;
if (loc == FRAG_RESULT_COLOR) {
out = &ctx->colour[0];
} else if (loc >= FRAG_RESULT_DATA0 && loc <= FRAG_RESULT_DATA7) {
out = &ctx->colour[loc - FRAG_RESULT_DATA0];
} else if (loc == FRAG_RESULT_DEPTH) {
out = &ctx->depth;
} else if (loc == FRAG_RESULT_STENCIL) {
out = &ctx->stencil;
} else if (loc == FRAG_RESULT_SAMPLE_MASK) {
out = &ctx->sample_mask;
} else {
UNREACHABLE("invalid location");
}
assert((*out) == NULL && "each location written exactly once");
*out = intr->src[0].ssa;
nir_instr_remove(&intr->instr);
return true;
}
static void
insert_rt_store(nir_builder *b,
signed target,
nir_def *colour,
nir_def *src0_colour,
nir_def *depth,
nir_def *stencil,
nir_def *sample_mask)
{
bool null_rt = target < 0;
colour = nir_pad_vec4(b, colour ?: nir_undef(b, 4, 32));
if (null_rt) {
/* Even if we don't write a RT, we still need to write alpha for
* alpha-to-coverage and alpha testing. Optimize the other channels out.
*/
colour = nir_vector_insert_imm(b, nir_undef(b, 4, 32),
nir_channel(b, colour, 3), 3);
}
nir_def *src0_alpha = nir_channel_or_undef(b, src0_colour ?: colour, 3);
nir_def *disable = b->shader->info.fs.uses_discard ?
nir_is_helper_invocation(b, 1) :
nir_imm_false(b);
nir_store_render_target_intel(b, colour, src0_alpha, sample_mask, depth,
stencil, disable, .target = target);
}
static void
lower_fragment_outputs(nir_function_impl *impl,
const struct intel_device_info *devinfo,
unsigned nr_color_regions,
unsigned dispatch_width)
{
struct frag_out_ctx ctx = { { NULL } };
nir_function_intrinsics_pass(impl, collect_fragment_output,
nir_metadata_control_flow, &ctx);
nir_builder b_ = nir_builder_at(nir_after_impl(impl));
nir_builder *b = &b_;
assert(nr_color_regions <= ARRAY_SIZE(ctx.colour));
signed last = -1;
for (signed i = nr_color_regions - 1; i >= 0; --i) {
if (ctx.colour[i]) {
last = i;
break;
}
}
nir_def *undef = nir_undef(b, 1, 32);
for (signed i = 0; i < last; ++i) {
if (ctx.colour[i]) {
insert_rt_store(b, i, ctx.colour[i], i > 0 ? ctx.colour[0] : NULL,
ctx.depth ?: undef, ctx.stencil ?: undef,
ctx.sample_mask ?: undef);
}
}
insert_rt_store(b, last, last >= 0 ? ctx.colour[last] : NULL,
last > 0 ? ctx.colour[0] : NULL, ctx.depth ?: undef,
ctx.stencil ?: undef, ctx.sample_mask ?: undef);
}
/**
* If we've optimized the entire program down to only "demote"
* or "terminate" with no other instructions, delete it entirely.
*
* We're already not writing outputs and ending the program.
*/
static void
delete_solo_discard(nir_shader *nir)
{
nir_function_impl *impl = nir_shader_get_entrypoint(nir);
nir_block *start_block = nir_start_block(impl);
nir_instr *instr = nir_block_first_instr(start_block);
if (start_block != nir_impl_last_block(impl) ||
!exec_list_is_singular(&start_block->instr_list) ||
instr->type != nir_instr_type_intrinsic)
return;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic == nir_intrinsic_demote ||
intrin->intrinsic == nir_intrinsic_terminate) {
nir_instr_remove(instr);
nir->info.fs.uses_discard = false;
}
}
/**
* Drop render target stores with unconditional discards.
*/
static bool
opt_unconditional_discards(nir_shader *nir)
{
nir_function_impl *impl = nir_shader_get_entrypoint(nir);
nir_block *block = nir_impl_last_block(impl);
bool progress = false;
bool any_remaining_rt_writes = false;
nir_foreach_instr_reverse_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
if (intr->intrinsic != nir_intrinsic_store_render_target_intel)
continue;
nir_scalar discard = nir_scalar_resolved(intr->src[5].ssa, 0);
if (nir_scalar_is_const(discard) && nir_scalar_as_uint(discard) != 0) {
/* Drop store with unconditional discard */
nir_instr_remove(instr);
progress = true;
} else {
/* This RT store might actually happen */
any_remaining_rt_writes = true;
}
}
if (progress) {
nir_opt_dce_impl(impl);
delete_solo_discard(nir);
}
/* If we eliminated all RT stores, add a Null RT store to end the thread. */
if (!any_remaining_rt_writes) {
nir_builder b = nir_builder_at(nir_after_impl(impl));
nir_def *undef = nir_undef(&b, 1, 32);
insert_rt_store(&b, -1, NULL, NULL, undef, undef, undef);
}
return nir_progress(progress, impl, nir_metadata_control_flow);
}
unsigned
jay_process_nir(const struct intel_device_info *devinfo,
nir_shader *nir,
union brw_any_prog_data *prog_data,
union brw_any_prog_key *key,
debug_archiver *archiver)
{
enum mesa_shader_stage stage = nir->info.stage;
struct brw_compiler compiler = { .devinfo = devinfo };
unsigned nr_packed_regs = 0;
prog_data->base.ray_queries = nir->info.ray_queries;
prog_data->base.stage = stage;
// TODO: Make the driver do this?
// prog_data->base.source_hash = params->source_hash;
prog_data->base.total_shared = nir->info.shared_size;
/* TODO: Real heuristic */
bool do_simd32 = INTEL_SIMD(FS, 32);
do_simd32 &= stage == MESA_SHADER_COMPUTE || stage == MESA_SHADER_FRAGMENT;
/* The 'Render Target Write message' section of the docs says:
*
* "Output Stencil is not supported with SIMD16 Render Target
* Write Messages."
*
* Likewise for Xe2 at SIMD32.
*/
if (stage == MESA_SHADER_FRAGMENT &&
(nir->info.outputs_written & BITFIELD64_BIT(FRAG_RESULT_STENCIL)))
do_simd32 = false;
unsigned simd_width = do_simd32 ? (nir->info.api_subgroup_size ?: 32) : 16;
brw_pass_tracker pt_ = {
.nir = nir,
.key = &key->base,
.dispatch_width = simd_width,
.compiler = &compiler,
.archiver = archiver,
}, *pt = &pt_;
JAY_NIR_SNAPSHOT("first");
if (stage == MESA_SHADER_VERTEX) {
/* We only expect slot compaction to be disabled when using device
* generated commands, to provide an independent 3DSTATE_VERTEX_ELEMENTS
* programming. This should always be enabled together with VF component
* packing to minimize the size of the payload.
*/
assert(!key->vs.no_vf_slot_compaction || key->vs.vf_component_packing);
/* When using Primitive Replication for multiview, each view gets its own
* position slot.
*/
const uint32_t pos_slots =
(nir->info.per_view_outputs & VARYING_BIT_POS) ?
MAX2(1, util_bitcount(key->base.view_mask)) :
1;
/* Only position is allowed to be per-view */
assert(!(nir->info.per_view_outputs & ~VARYING_BIT_POS));
brw_compute_vue_map(devinfo, &prog_data->vue.vue_map,
nir->info.outputs_written, key->base.vue_layout,
pos_slots);
brw_nir_apply_key(pt, &key->base, simd_width);
prog_data->vs.inputs_read = nir->info.inputs_read;
prog_data->vs.double_inputs_read = nir->info.vs.double_inputs;
prog_data->vs.no_vf_slot_compaction = key->vs.no_vf_slot_compaction;
brw_nir_lower_vs_inputs(nir);
brw_nir_lower_vue_outputs(nir);
JAY_NIR_SNAPSHOT("after_lower_io");
memset(prog_data->vs.vf_component_packing, 0,
sizeof(prog_data->vs.vf_component_packing));
if (key->vs.vf_component_packing) {
nr_packed_regs = brw_nir_pack_vs_input(nir, &prog_data->vs);
}
/* Get constant offsets out of the way for proper clip/cull handling */
JAY_NIR_PASS(nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
JAY_NIR_PASS(nir_opt_constant_folding);
JAY_NIR_PASS(brw_nir_lower_deferred_urb_writes, devinfo,
&prog_data->vue.vue_map, 0, 0);
} else if (stage == MESA_SHADER_FRAGMENT) {
assert(key->fs.mesh_input == INTEL_NEVER && "todo");
brw_nir_apply_key(pt, &key->base, 32);
brw_nir_lower_fs_inputs(nir, devinfo, &key->fs);
brw_nir_lower_fs_outputs(nir);
JAY_NIR_SNAPSHOT("after_lower_io");
JAY_NIR_PASS(nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
if (!brw_can_coherent_fb_fetch(devinfo))
JAY_NIR_PASS(brw_nir_lower_fs_load_output, &key->fs);
JAY_NIR_PASS(nir_opt_frag_coord_to_pixel_coord);
JAY_NIR_PASS(nir_shader_intrinsics_pass, lower_frag_coord,
nir_metadata_control_flow, NULL);
JAY_NIR_PASS(nir_opt_barycentric, true);
JAY_NIR_PASS(nir_opt_constant_folding);
lower_fragment_outputs(nir_shader_get_entrypoint(nir), devinfo,
key->fs.nr_color_regions, simd_width);
JAY_NIR_PASS(nir_lower_helper_writes, true);
JAY_NIR_PASS(nir_lower_is_helper_invocation);
JAY_NIR_PASS(nir_shader_intrinsics_pass, lower_helper_invocation,
nir_metadata_control_flow, NULL);
if (key->fs.alpha_to_coverage != INTEL_NEVER) {
/* Run constant fold optimization in order to get the correct source
* offset to determine render target 0 store instruction in
* emit_alpha_to_coverage pass.
*/
JAY_NIR_PASS(nir_opt_constant_folding);
JAY_NIR_PASS(brw_nir_lower_alpha_to_coverage);
}
/* We want to run the standard opt loop after lowering but before
* gathering prog data, so we have accurate information about which system
* values are actually used (vs DCE'd away).
*/
brw_nir_optimize(pt);
NIR_PASS(_, nir, opt_unconditional_discards);
// TODO
// JAY_NIR_PASS(brw_nir_move_interpolation_to_top);
/* Do this before lower_fs_config_intel so that the pass has the right
* information.
*/
jay_populate_prog_data(devinfo, nir, prog_data, key, 0);
JAY_NIR_PASS(brw_nir_lower_fs_config_intel, &key->fs, &prog_data->fs);
} else {
brw_nir_apply_key(pt, &key->base, simd_width);
}
brw_postprocess_nir_opts(pt);
JAY_NIR_PASS(nir_shader_intrinsics_pass, jay_nir_lower_simd,
nir_metadata_control_flow, &simd_width);
JAY_NIR_PASS(nir_opt_algebraic_late);
JAY_NIR_PASS(intel_nir_opt_peephole_imul32x16);
/* Late postprocess while remaining in SSA */
/* Run fsign lowering again after the last time brw_nir_optimize is called.
* As is the case with conversion lowering (below), brw_nir_optimize can
* create additional fsign instructions.
*/
JAY_NIR_PASS(jay_nir_lower_fsign);
JAY_NIR_PASS(jay_nir_lower_bool);
JAY_NIR_PASS(nir_opt_cse);
JAY_NIR_PASS(nir_opt_dce);
JAY_NIR_PASS(jay_nir_opt_sel_zero);
/* Run nir_split_conversions only after the last tiem
* brw_nir_optimize is called. Various optimizations invoked there can
* rematerialize the conversions that the lowering pass eliminates.
*/
const nir_split_conversions_options split_conv_opts = {
.callback = intel_nir_split_conversions_cb,
};
JAY_NIR_PASS(nir_split_conversions, &split_conv_opts);
/* Do this only after the last opt_gcm. GCM will undo this lowering. */
if (stage == MESA_SHADER_FRAGMENT) {
JAY_NIR_PASS(intel_nir_lower_non_uniform_barycentric_at_sample);
}
JAY_NIR_PASS(nir_opt_constant_folding);
JAY_NIR_PASS(nir_lower_load_const_to_scalar);
JAY_NIR_PASS(nir_lower_all_phis_to_scalar);
JAY_NIR_PASS(nir_opt_copy_prop);
JAY_NIR_PASS(nir_opt_dce);
/* Jay requires LCSSA for correctness reading convergent loop-dependent
* values outside of a divergent loop. Converting to LCSSA inserts the
* required divergent 1-source phi after the loop.
*/
JAY_NIR_PASS(nir_convert_to_lcssa, true, true);
/* Run divergence analysis at the end */
nir_sweep(nir);
nj_index_ssa_defs(nir);
nir_divergence_analysis(nir);
if (stage != MESA_SHADER_FRAGMENT)
jay_populate_prog_data(devinfo, nir, prog_data, key, nr_packed_regs);
return simd_width;
}