diff --git a/src/compiler/nir/nir.c b/src/compiler/nir/nir.c
index 357f16d6293..513fd04f36f 100644
--- a/src/compiler/nir/nir.c
+++ b/src/compiler/nir/nir.c
@@ -1555,6 +1555,7 @@ nir_def_init(nir_instr *instr, nir_def *def,
    def->num_components = num_components;
    def->bit_size = bit_size;
    def->divergent = true; /* This is the safer default */
+   def->loop_invariant = false;
 
    if (instr->block) {
       nir_function_impl *impl =
diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 94999f4cd48..7a781b7fefb 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -1059,6 +1059,13 @@ typedef struct nir_def {
     * invocations of the shader. This is set by nir_divergence_analysis.
     */
    bool divergent;
+
+   /**
+    * True if this SSA value is loop invariant w.r.t. the innermost parent
+    * loop. Defaults to false; set by nir_divergence_analysis and used to
+    * determine the divergence of a nir_src.
+    */
+   bool loop_invariant;
 } nir_def;
 
 struct nir_src;
diff --git a/src/compiler/nir/nir_divergence_analysis.c b/src/compiler/nir/nir_divergence_analysis.c
index 056ce6e7111..21269dc7b09 100644
--- a/src/compiler/nir/nir_divergence_analysis.c
+++ b/src/compiler/nir/nir_divergence_analysis.c
@@ -40,6 +40,7 @@ struct divergence_state {
    const gl_shader_stage stage;
    nir_shader *shader;
    nir_divergence_options options;
+   nir_loop *loop; /* innermost loop being visited, NULL outside loops */
 
    /* Whether the caller requested vertex divergence (meaning between vertices
     * of the same primitive) instead of subgroup invocation divergence
@@ -75,6 +76,26 @@ nir_src_is_divergent(nir_src *src)
    return src->ssa->divergent;
 }
 
+static inline bool
+src_invariant(nir_src *src, void *loop)
+{
+   nir_block *first_block = nir_loop_first_block(loop);
+
+   /* Invariant if SSA is defined before the current loop. */
+   if (src->ssa->parent_instr->block->index < first_block->index)
+      return true;
+
+   if (!src->ssa->loop_invariant)
+      return false;
+
+   /* The flag is relative to the def's innermost loop, which may be nested. */
+   nir_cf_node *cf_node = src->ssa->parent_instr->block->cf_node.parent;
+   while (cf_node->type != nir_cf_node_loop)
+      cf_node = cf_node->parent;
+
+   return nir_cf_node_as_loop(cf_node) == loop;
+}
+
 static bool
 visit_alu(nir_alu_instr *instr, struct divergence_state *state)
 {
@@ -986,12 +1007,41 @@ visit_jump(nir_jump_instr *jump, struct divergence_state *state)
 }
 
 static bool
-set_ssa_def_not_divergent(nir_def *def, UNUSED void *_state)
+set_ssa_def_not_divergent(nir_def *def, void *invariant)
 {
    def->divergent = false;
+   def->loop_invariant = *(bool *)invariant; /* points at a bool */
    return true;
 }
 
+static bool
+instr_is_loop_invariant(nir_instr *instr, struct divergence_state *state)
+{
+   if (!state->loop)
+      return false; /* not inside a loop */
+
+   switch (instr->type) {
+   case nir_instr_type_load_const:
+   case nir_instr_type_undef:
+   case nir_instr_type_debug_info:
+   case nir_instr_type_jump:
+      return true;
+   case nir_instr_type_intrinsic:
+      if (!nir_intrinsic_can_reorder(nir_instr_as_intrinsic(instr)))
+         return false;
+      FALLTHROUGH;
+   case nir_instr_type_alu:
+   case nir_instr_type_deref:
+   case nir_instr_type_tex:
+      return nir_foreach_src(instr, src_invariant, state->loop);
+   case nir_instr_type_phi:
+   case nir_instr_type_call:
+   case nir_instr_type_parallel_copy:
+   default:
+      unreachable("NIR divergence analysis: Unsupported instruction type.");
+   }
+}
+
 static bool
 update_instr_divergence(nir_instr *instr, struct divergence_state *state)
 {
@@ -1029,8 +1079,10 @@ visit_block(nir_block *block, struct divergence_state *state)
       if (instr->type == nir_instr_type_phi)
          continue;
 
-      if (state->first_visit)
-         nir_foreach_def(instr, set_ssa_def_not_divergent, NULL);
+      if (state->first_visit) {
+         bool invariant = instr_is_loop_invariant(instr, state);
+         nir_foreach_def(instr, set_ssa_def_not_divergent, &invariant);
+      }
 
       if (instr->type == nir_instr_type_jump) {
          has_changed |= visit_jump(nir_instr_as_jump(instr), state);
@@ -1161,9 +1213,13 @@ visit_if(nir_if *if_stmt, struct divergence_state *state)
    progress |= visit_cf_list(&if_stmt->else_list, &else_state);
 
    /* handle phis after the IF */
+   bool invariant = state->loop && src_invariant(&if_stmt->condition, state->loop);
    nir_foreach_phi(phi, nir_cf_node_cf_tree_next(&if_stmt->cf_node)) {
-      if (state->first_visit)
+      if (state->first_visit) {
          phi->def.divergent = false;
+         phi->def.loop_invariant =
+            invariant && nir_foreach_src(&phi->instr, src_invariant, state->loop);
+      }
       bool ignore_undef = state->options & nir_divergence_ignore_undef_if_phi_srcs;
       progress |= visit_if_merge_phi(phi, if_stmt->condition.ssa->divergent, ignore_undef);
    }
@@ -1196,6 +1252,7 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
       if (!state->first_visit && phi->def.divergent)
          continue;
 
+      phi->def.loop_invariant = false;
       nir_foreach_phi_src(src, phi) {
          if (src->pred == loop_preheader) {
             phi->def.divergent = src->src.ssa->divergent;
@@ -1207,6 +1264,7 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
 
    /* setup loop state */
    struct divergence_state loop_state = *state;
+   loop_state.loop = loop;
    loop_state.divergent_loop_cf = false;
    loop_state.divergent_loop_continue = false;
    loop_state.divergent_loop_break = false;
@@ -1229,8 +1287,10 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
 
    /* handle phis after the loop */
    nir_foreach_phi(phi, nir_cf_node_cf_tree_next(&loop->cf_node)) {
-      if (state->first_visit)
+      if (state->first_visit) {
          phi->def.divergent = false;
+         phi->def.loop_invariant = false;
+      }
       progress |= visit_loop_exit_phi(phi, loop_state.divergent_loop_break);
    }
 
@@ -1273,6 +1333,7 @@ nir_divergence_analysis_impl(nir_function_impl *impl, nir_divergence_options opt
       .stage = impl->function->shader->info.stage,
       .shader = impl->function->shader,
       .options = options,
+      .loop = NULL,
       .divergent_loop_cf = false,
       .divergent_loop_continue = false,
       .divergent_loop_break = false,
@@ -1305,6 +1366,7 @@ nir_vertex_divergence_analysis(nir_shader *shader)
       .stage = shader->info.stage,
       .shader = shader,
       .options = shader->options->divergence_analysis_options,
+      .loop = NULL,
       .vertex_divergence = true,
       .first_visit = true,
    };
diff --git a/src/compiler/nir/nir_serialize.c b/src/compiler/nir/nir_serialize.c
index 1fd05142e87..2735683dd08 100644
--- a/src/compiler/nir/nir_serialize.c
+++ b/src/compiler/nir/nir_serialize.c
@@ -479,10 +479,10 @@ read_src(read_ctx *ctx, nir_src *src)
 union packed_def {
    uint8_t u8;
    struct {
-      uint8_t _pad : 1;
       uint8_t num_components : 3;
       uint8_t bit_size : 3;
       uint8_t divergent : 1;
+      uint8_t loop_invariant : 1;
    };
 };
 
@@ -608,6 +608,7 @@ write_def(write_ctx *ctx, const nir_def *def, union packed_instr header,
       encode_num_components_in_3bits(def->num_components);
    pdef.bit_size = encode_bit_size_3bits(def->bit_size);
    pdef.divergent = def->divergent;
+   pdef.loop_invariant = def->loop_invariant;
    header.any.def = pdef.u8;
 
    /* Check if the current ALU instruction has the same header as the previous
@@ -670,6 +671,7 @@ read_def(read_ctx *ctx, nir_def *def, nir_instr *instr,
       num_components = decode_num_components_in_3bits(pdef.num_components);
    nir_def_init(instr, def, num_components, bit_size);
    def->divergent = pdef.divergent;
+   def->loop_invariant = pdef.loop_invariant;
    read_add_object(ctx, def);
 }
 
@@ -1254,6 +1256,7 @@ read_load_const(read_ctx *ctx, union packed_instr header)
       nir_load_const_instr_create(ctx->nir, header.load_const.last_component + 1,
                                   decode_bit_size_3bits(header.load_const.bit_size));
    lc->def.divergent = false;
+   lc->def.loop_invariant = true;
 
    switch (header.load_const.packing) {
    case load_const_scalar_hi_19bits:
@@ -1348,6 +1351,7 @@ read_ssa_undef(read_ctx *ctx, union packed_instr header)
                              decode_bit_size_3bits(header.undef.bit_size));
 
    undef->def.divergent = false;
+   undef->def.loop_invariant = true;
 
    read_add_object(ctx, &undef->def);
    return undef;