nir/from_ssa: only consider divergence if requested

This pass used to unconditionally use divergence information,
which forced the caller to either run nir_divergence_analysis first or
ensure that the divergence information was properly reset.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33009>
This commit is contained in:
Daniel Schürmann 2025-01-13 16:38:25 +01:00 committed by Marge Bot
parent e7214b9446
commit f3be7ce01b
16 changed files with 41 additions and 44 deletions

View file

@ -1822,7 +1822,7 @@ v3d_attempt_compile(struct v3d_compile *c)
NIR_PASS(_, c->s, nir_lower_bool_to_int32);
NIR_PASS(_, c->s, nir_convert_to_lcssa, true, true);
NIR_PASS_V(c->s, nir_divergence_analysis);
NIR_PASS(_, c->s, nir_convert_from_ssa, true);
NIR_PASS(_, c->s, nir_convert_from_ssa, true, true);
struct nir_schedule_options schedule_options = {
/* Schedule for about half our register space, to enable more

View file

@ -6780,9 +6780,11 @@ nir_rewrite_uses_to_load_reg(struct nir_builder *b, nir_def *old,
/* If phi_webs_only is true, only convert SSA values involved in phi nodes to
* registers. If false, convert all values (even those not involved in a phi
* node) to registers.
* If consider_divergence is true, this pass will use divergence information
* in order to not coalesce copies from uniform to divergent registers.
*/
bool nir_convert_from_ssa(nir_shader *shader,
bool phi_webs_only);
bool phi_webs_only, bool consider_divergence);
bool nir_lower_phis_to_regs_block(nir_block *block);
bool nir_lower_ssa_defs_to_regs_block(nir_block *block);

View file

@ -39,6 +39,7 @@ struct from_ssa_state {
bool phi_webs_only;
struct hash_table *merge_node_table;
nir_instr *instr;
bool consider_divergence;
bool progress;
};
@ -157,7 +158,7 @@ get_merge_node(nir_def *def, struct from_ssa_state *state)
merge_set *set = rzalloc(state->dead_ctx, merge_set);
exec_list_make_empty(&set->nodes);
set->size = 1;
set->divergent = def->divergent;
set->divergent = state->consider_divergence && def->divergent;
merge_node *node = ralloc(state->dead_ctx, merge_node);
node->set = set;
@ -374,7 +375,7 @@ get_parallel_copy_at_end_of_block(nir_block *block)
* time because of potential back-edges in the CFG.
*/
static bool
isolate_phi_nodes_block(nir_shader *shader, nir_block *block, void *dead_ctx)
isolate_phi_nodes_block(nir_shader *shader, nir_block *block, struct from_ssa_state *state)
{
/* If we don't have any phis, then there's nothing for us to do. */
nir_phi_instr *last_phi = nir_block_last_phi_instr(block);
@ -397,13 +398,13 @@ isolate_phi_nodes_block(nir_shader *shader, nir_block *block, void *dead_ctx)
get_parallel_copy_at_end_of_block(src->pred);
assert(pcopy);
nir_parallel_copy_entry *entry = rzalloc(dead_ctx,
nir_parallel_copy_entry *entry = rzalloc(state->dead_ctx,
nir_parallel_copy_entry);
entry->dest_is_reg = false;
nir_def_init(&pcopy->instr, &entry->dest.def,
phi->def.num_components, phi->def.bit_size);
entry->dest.def.divergent = nir_src_is_divergent(&src->src);
entry->dest.def.divergent = state->consider_divergence && nir_src_is_divergent(&src->src);
/* We're adding a source to a live instruction so we need to use
* nir_instr_init_src()
@ -416,13 +417,13 @@ isolate_phi_nodes_block(nir_shader *shader, nir_block *block, void *dead_ctx)
nir_src_rewrite(&src->src, &entry->dest.def);
}
nir_parallel_copy_entry *entry = rzalloc(dead_ctx,
nir_parallel_copy_entry *entry = rzalloc(state->dead_ctx,
nir_parallel_copy_entry);
entry->dest_is_reg = false;
nir_def_init(&block_pcopy->instr, &entry->dest.def,
phi->def.num_components, phi->def.bit_size);
entry->dest.def.divergent = phi->def.divergent;
entry->dest.def.divergent = state->consider_divergence && phi->def.divergent;
nir_def_rewrite_uses(&phi->def, &entry->dest.def);
@ -806,14 +807,14 @@ copy_value_is_divergent(struct copy_value v)
}
static void
copy_values(nir_builder *b, struct copy_value dest, struct copy_value src)
copy_values(struct from_ssa_state *state, struct copy_value dest, struct copy_value src)
{
nir_def *val = src.is_reg ? nir_load_reg(b, src.ssa) : src.ssa;
nir_def *val = src.is_reg ? nir_load_reg(&state->builder, src.ssa) : src.ssa;
assert(!copy_value_is_divergent(src) || copy_value_is_divergent(dest));
assert(!state->consider_divergence || !copy_value_is_divergent(src) || copy_value_is_divergent(dest));
assert(dest.is_reg);
nir_store_reg(b, val, dest.ssa);
nir_store_reg(&state->builder, val, dest.ssa);
}
static void
@ -924,7 +925,7 @@ resolve_parallel_copy(nir_parallel_copy_instr *pcopy,
while (ready_idx >= 0) {
int b = ready[ready_idx--];
int a = pred[b];
copy_values(&state->builder, values[b], values[loc[a]]);
copy_values(state, values[b], values[loc[a]]);
/* b has been filled, mark it as not needing to be copied */
pred[b] = -1;
@ -934,8 +935,8 @@ resolve_parallel_copy(nir_parallel_copy_instr *pcopy,
* divergent), then we can't guarantee we won't need the convergent
* version of it again.
*/
if (copy_value_is_divergent(values[a]) ==
copy_value_is_divergent(values[b])) {
if (!state->consider_divergence ||
copy_value_is_divergent(values[a]) == copy_value_is_divergent(values[b])) {
/* If a needs to be filled... */
if (pred[a] != -1) {
/* If any other copies want a they can find it at b */
@ -984,13 +985,14 @@ resolve_parallel_copy(nir_parallel_copy_instr *pcopy,
} else {
reg = decl_reg_for_ssa_def(&state->builder, values[b].ssa);
}
set_reg_divergent(reg, copy_value_is_divergent(values[b]));
if (state->consider_divergence)
set_reg_divergent(reg, copy_value_is_divergent(values[b]));
values[num_vals] = (struct copy_value){
.is_reg = true,
.ssa = reg,
};
copy_values(&state->builder, values[num_vals], values[b]);
copy_values(state, values[num_vals], values[b]);
loc[b] = num_vals;
ready[++ready_idx] = b;
num_vals++;
@ -1029,7 +1031,7 @@ resolve_parallel_copies_block(nir_block *block, struct from_ssa_state *state)
static bool
nir_convert_from_ssa_impl(nir_function_impl *impl,
bool phi_webs_only)
bool phi_webs_only, bool consider_divergence)
{
nir_shader *shader = impl->function->shader;
@ -1039,6 +1041,7 @@ nir_convert_from_ssa_impl(nir_function_impl *impl,
state.dead_ctx = ralloc_context(NULL);
state.phi_webs_only = phi_webs_only;
state.merge_node_table = _mesa_pointer_hash_table_create(NULL);
state.consider_divergence = consider_divergence;
state.progress = false;
exec_list_make_empty(&state.dead_instrs);
@ -1047,7 +1050,7 @@ nir_convert_from_ssa_impl(nir_function_impl *impl,
}
nir_foreach_block(block, impl) {
isolate_phi_nodes_block(shader, block, state.dead_ctx);
isolate_phi_nodes_block(shader, block, &state);
}
/* Mark metadata as dirty before we ask for liveness analysis */
@ -1082,12 +1085,12 @@ nir_convert_from_ssa_impl(nir_function_impl *impl,
bool
nir_convert_from_ssa(nir_shader *shader,
bool phi_webs_only)
bool phi_webs_only, bool consider_divergence)
{
bool progress = false;
nir_foreach_function_impl(impl, shader) {
progress |= nir_convert_from_ssa_impl(impl, phi_webs_only);
progress |= nir_convert_from_ssa_impl(impl, phi_webs_only, consider_divergence);
}
return progress;

View file

@ -2955,7 +2955,7 @@ void
lp_build_nir_prepasses(struct nir_shader *nir)
{
NIR_PASS_V(nir, nir_convert_to_lcssa, true, true);
NIR_PASS_V(nir, nir_convert_from_ssa, true);
NIR_PASS_V(nir, nir_convert_from_ssa, true, false);
NIR_PASS_V(nir, nir_lower_locals_to_regs, 32);
NIR_PASS_V(nir, nir_remove_dead_derefs);
NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL);

View file

@ -3985,7 +3985,7 @@ const void *nir_to_tgsi_options(struct nir_shader *s,
NIR_PASS_V(s, nir_opt_move, move_all);
NIR_PASS_V(s, nir_convert_from_ssa, true);
NIR_PASS_V(s, nir_convert_from_ssa, true, false);
NIR_PASS_V(s, nir_lower_vec_to_regs, ntt_vec_to_mov_writemask_cb, NULL);
/* locals_to_reg_intrinsics will leave dead derefs that are good to clean up.

View file

@ -1068,7 +1068,7 @@ emit_shader(struct etna_compile *c, unsigned *num_temps, unsigned *num_consts)
}
/* call directly to avoid validation (load_const don't pass validation at this point) */
nir_convert_from_ssa(shader, true);
nir_convert_from_ssa(shader, true, false);
nir_trivialize_registers(shader);
etna_ra_assign(c, shader);

View file

@ -1131,7 +1131,7 @@ ir2_nir_compile(struct ir2_context *ctx, bool binning)
OPT_V(ctx->nir, nir_opt_algebraic_late);
OPT_V(ctx->nir, nir_lower_alu_to_scalar, ir2_alu_to_scalar_filter_cb, NULL);
OPT_V(ctx->nir, nir_convert_from_ssa, true);
OPT_V(ctx->nir, nir_convert_from_ssa, true, false);
OPT_V(ctx->nir, nir_move_vec_src_uses_to_dest, false);
OPT_V(ctx->nir, nir_lower_vec_to_regs, NULL, NULL);

View file

@ -154,7 +154,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
NIR_PASS_V(s, nir_copy_prop);
NIR_PASS_V(s, nir_opt_dce);
NIR_PASS_V(s, lima_nir_split_loads);
NIR_PASS_V(s, nir_convert_from_ssa, true);
NIR_PASS_V(s, nir_convert_from_ssa, true, false);
NIR_PASS_V(s, nir_opt_dce);
NIR_PASS_V(s, nir_remove_dead_variables, nir_var_function_temp, NULL);
nir_sweep(s);
@ -271,7 +271,7 @@ lima_program_optimize_fs_nir(struct nir_shader *s,
NIR_PASS_V(s, nir_copy_prop);
NIR_PASS_V(s, nir_opt_dce);
NIR_PASS_V(s, nir_convert_from_ssa, true);
NIR_PASS_V(s, nir_convert_from_ssa, true, false);
NIR_PASS_V(s, nir_remove_dead_variables, nir_var_function_temp, NULL);
NIR_PASS_V(s, nir_move_vec_src_uses_to_dest, false);

View file

@ -2345,7 +2345,7 @@ nir_to_rc(struct nir_shader *s, struct pipe_screen *screen)
if (s->info.stage == MESA_SHADER_VERTEX)
NIR_PASS_V(s, nir_opt_vectorize, ntr_should_vectorize_instr, NULL);
NIR_PASS_V(s, nir_convert_from_ssa, true);
NIR_PASS_V(s, nir_convert_from_ssa, true, false);
NIR_PASS_V(s, nir_lower_vec_to_regs, NULL, NULL);
/* locals_to_reg_intrinsics will leave dead derefs that are good to clean up.

View file

@ -865,7 +865,7 @@ r600_lower_and_optimize_nir(nir_shader *sh,
NIR_PASS_V(sh, nir_lower_bool_to_int32);
NIR_PASS_V(sh, nir_lower_locals_to_regs, 32);
NIR_PASS_V(sh, nir_convert_from_ssa, true);
NIR_PASS_V(sh, nir_convert_from_ssa, true, false);
NIR_PASS_V(sh, nir_opt_dce);
}

View file

@ -2334,7 +2334,7 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
NIR_PASS_V(c->s, nir_lower_bool_to_int32);
NIR_PASS_V(c->s, nir_convert_from_ssa, true);
NIR_PASS_V(c->s, nir_convert_from_ssa, true, false);
NIR_PASS_V(c->s, nir_trivialize_registers);
if (VC4_DBG(NIR)) {

View file

@ -3884,15 +3884,7 @@ compile_module(struct zink_screen *screen, struct zink_shader *zs, nir_shader *n
struct zink_shader_info *sinfo = &zs->sinfo;
prune_io(nir);
switch (nir->info.stage) {
case MESA_SHADER_VERTEX:
case MESA_SHADER_TESS_EVAL:
case MESA_SHADER_GEOMETRY:
NIR_PASS_V(nir, nir_divergence_analysis);
break;
default: break;
}
NIR_PASS_V(nir, nir_convert_from_ssa, true);
NIR_PASS_V(nir, nir_convert_from_ssa, true, false);
if (zink_debug & (ZINK_DEBUG_NIR | ZINK_DEBUG_SPIRV))
nir_index_ssa_defs(nir_shader_get_entrypoint(nir));
@ -6629,7 +6621,7 @@ zink_shader_tcs_init(struct zink_screen *screen, struct zink_shader *zs, nir_sha
optimize_nir(nir, NULL, true);
NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL);
NIR_PASS_V(nir, nir_convert_from_ssa, true);
NIR_PASS_V(nir, nir_convert_from_ssa, true, false);
*nir_ret = nir;
zink_shader_serialize_blob(nir, &zs->blob);

View file

@ -1934,7 +1934,7 @@ brw_postprocess_nir(nir_shader *nir, const struct brw_compiler *compiler,
NIR_PASS(_, nir, nir_convert_to_lcssa, true, true);
NIR_PASS_V(nir, nir_divergence_analysis);
OPT(nir_convert_from_ssa, true);
OPT(nir_convert_from_ssa, true, true);
OPT(nir_opt_dce);

View file

@ -1524,7 +1524,7 @@ elk_postprocess_nir(nir_shader *nir, const struct elk_compiler *compiler,
NIR_PASS(_, nir, nir_convert_to_lcssa, true, true);
NIR_PASS_V(nir, nir_divergence_analysis);
OPT(nir_convert_from_ssa, true);
OPT(nir_convert_from_ssa, true, true);
if (!is_scalar) {
OPT(nir_move_vec_src_uses_to_dest, true);

View file

@ -3506,7 +3506,7 @@ Converter::run()
NIR_PASS_V(nir, nir_lower_bit_size, Converter::lowerBitSizeCB, this);
NIR_PASS_V(nir, nir_divergence_analysis);
NIR_PASS_V(nir, nir_convert_from_ssa, true);
NIR_PASS_V(nir, nir_convert_from_ssa, true, true);
// Garbage collect dead instructions
nir_sweep(nir);

View file

@ -539,7 +539,7 @@ optimise_nir(nir_shader *nir, unsigned quirks, bool is_blend)
NIR_PASS(_, nir, nir_opt_move, move_all);
/* Take us out of SSA */
NIR_PASS(progress, nir, nir_convert_from_ssa, true);
NIR_PASS(progress, nir, nir_convert_from_ssa, true, false);
/* We are a vector architecture; write combine where possible */
NIR_PASS(progress, nir, nir_move_vec_src_uses_to_dest, false);