nir/divergence: calculate divergence without requiring LCSSA form

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30787>
This commit is contained in:
Daniel Schürmann 2024-09-03 16:14:25 +02:00 committed by Marge Bot
parent d34d2f8fa8
commit 0eff03d385

View file

@ -109,6 +109,12 @@ nir_src_is_divergent(nir_src *src)
return false;
}
/* Returns whether `src` is divergent by forwarding to
 * nir_src_is_divergent().
 *
 * `src` is taken by value; the address of that local copy is what gets
 * passed on. `state` is accepted so that call sites can hand over the
 * current divergence state, but this body does not read it.
 */
static inline bool
src_divergent(nir_src src, struct divergence_state *state)
{
return nir_src_is_divergent(&src);
}
static inline bool
src_invariant(nir_src *src, void *loop)
{
@ -138,7 +144,7 @@ visit_alu(nir_alu_instr *instr, struct divergence_state *state)
unsigned num_src = nir_op_infos[instr->op].num_inputs;
for (unsigned i = 0; i < num_src; i++) {
if (instr->src[i].src.ssa->divergent) {
if (src_divergent(instr->src[i].src, state)) {
instr->def.divergent = true;
return true;
}
@ -355,7 +361,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
break;
case nir_intrinsic_load_input:
case nir_intrinsic_load_per_primitive_input:
is_divergent = instr->src[0].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state);
if (stage == MESA_SHADER_FRAGMENT) {
is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
@ -370,8 +376,8 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
}
break;
case nir_intrinsic_load_per_vertex_input:
is_divergent = instr->src[0].ssa->divergent ||
instr->src[1].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state) ||
src_divergent(instr->src[1], state);
if (stage == MESA_SHADER_TESS_CTRL)
is_divergent |= !(options & nir_divergence_single_patch_per_tcs_subgroup);
if (stage == MESA_SHADER_TESS_EVAL)
@ -380,12 +386,12 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
is_divergent = true;
break;
case nir_intrinsic_load_input_vertex:
is_divergent = instr->src[1].ssa->divergent;
is_divergent = src_divergent(instr->src[1], state);
assert(stage == MESA_SHADER_FRAGMENT);
is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
break;
case nir_intrinsic_load_output:
is_divergent = instr->src[0].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state);
switch (stage) {
case MESA_SHADER_TESS_CTRL:
is_divergent |= !(options & nir_divergence_single_patch_per_tcs_subgroup);
@ -406,16 +412,16 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_load_per_vertex_output:
/* TCS and NV_mesh_shader only (EXT_mesh_shader does not allow loading outputs). */
assert(stage == MESA_SHADER_TESS_CTRL || stage == MESA_SHADER_MESH);
is_divergent = instr->src[0].ssa->divergent ||
instr->src[1].ssa->divergent ||
is_divergent = src_divergent(instr->src[0], state) ||
src_divergent(instr->src[1], state) ||
(stage == MESA_SHADER_TESS_CTRL &&
!(options & nir_divergence_single_patch_per_tcs_subgroup));
break;
case nir_intrinsic_load_per_primitive_output:
/* NV_mesh_shader only (EXT_mesh_shader does not allow loading outputs). */
assert(stage == MESA_SHADER_MESH);
is_divergent = instr->src[0].ssa->divergent ||
instr->src[1].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state) ||
src_divergent(instr->src[1], state);
break;
case nir_intrinsic_load_layer_id:
case nir_intrinsic_load_front_face:
@ -434,7 +440,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
break;
case nir_intrinsic_load_fs_input_interp_deltas:
assert(stage == MESA_SHADER_FRAGMENT);
is_divergent = instr->src[0].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state);
is_divergent |= !(options & nir_divergence_single_prim_per_subgroup);
break;
case nir_intrinsic_load_instance_id:
@ -494,7 +500,8 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_inclusive_scan:
case nir_intrinsic_inclusive_scan_clusters_ir3: {
nir_op op = nir_intrinsic_reduction_op(instr);
is_divergent = instr->src[0].ssa->divergent || state->vertex_divergence;
is_divergent = src_divergent(instr->src[0], state) ||
state->vertex_divergence;
if (op != nir_op_umin && op != nir_op_imin && op != nir_op_fmin &&
op != nir_op_umax && op != nir_op_imax && op != nir_op_fmax &&
op != nir_op_iand && op != nir_op_ior)
@ -513,20 +520,23 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_load_ubo_vec4:
case nir_intrinsic_ldc_nv:
case nir_intrinsic_ldcx_nv:
is_divergent = (instr->src[0].ssa->divergent && (nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
instr->src[1].ssa->divergent;
is_divergent = (src_divergent(instr->src[0], state) &&
(nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
src_divergent(instr->src[1], state);
break;
case nir_intrinsic_load_ssbo:
case nir_intrinsic_load_ssbo_ir3:
is_divergent = (instr->src[0].ssa->divergent && (nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
instr->src[1].ssa->divergent ||
is_divergent = (src_divergent(instr->src[0], state) &&
(nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
src_divergent(instr->src[1], state) ||
load_may_tear(state, instr);
break;
case nir_intrinsic_load_shared:
case nir_intrinsic_load_shared_ir3:
is_divergent = instr->src[0].ssa->divergent || (options & nir_divergence_uniform_load_tears);
is_divergent = src_divergent(instr->src[0], state) ||
(options & nir_divergence_uniform_load_tears);
break;
case nir_intrinsic_load_global:
@ -540,7 +550,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
unsigned num_srcs = nir_intrinsic_infos[instr->intrinsic].num_srcs;
for (unsigned i = 0; i < num_srcs; i++) {
if (instr->src[i].ssa->divergent) {
if (src_divergent(instr->src[i], state)) {
is_divergent = true;
break;
}
@ -550,7 +560,8 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_get_ssbo_size:
case nir_intrinsic_deref_buffer_array_length:
is_divergent = instr->src[0].ssa->divergent && (nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM);
is_divergent = src_divergent(instr->src[0], state) &&
(nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM);
break;
case nir_intrinsic_image_samples_identical:
@ -559,16 +570,19 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_image_fragment_mask_load_amd:
case nir_intrinsic_image_deref_fragment_mask_load_amd:
case nir_intrinsic_bindless_image_fragment_mask_load_amd:
is_divergent = (instr->src[0].ssa->divergent && (nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
instr->src[1].ssa->divergent ||
is_divergent = (src_divergent(instr->src[0], state) &&
(nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
src_divergent(instr->src[1], state) ||
load_may_tear(state, instr);
break;
case nir_intrinsic_image_texel_address:
case nir_intrinsic_image_deref_texel_address:
case nir_intrinsic_bindless_image_texel_address:
is_divergent = (instr->src[0].ssa->divergent && (nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
instr->src[1].ssa->divergent || instr->src[2].ssa->divergent;
is_divergent = (src_divergent(instr->src[0], state) &&
(nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
src_divergent(instr->src[1], state) ||
src_divergent(instr->src[2], state);
break;
case nir_intrinsic_image_load:
@ -577,13 +591,16 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_image_sparse_load:
case nir_intrinsic_image_deref_sparse_load:
case nir_intrinsic_bindless_image_sparse_load:
is_divergent = (instr->src[0].ssa->divergent && (nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
instr->src[1].ssa->divergent || instr->src[2].ssa->divergent || instr->src[3].ssa->divergent ||
is_divergent = (src_divergent(instr->src[0], state) &&
(nir_intrinsic_access(instr) & ACCESS_NON_UNIFORM)) ||
src_divergent(instr->src[1], state) ||
src_divergent(instr->src[2], state) ||
src_divergent(instr->src[3], state) ||
load_may_tear(state, instr);
break;
case nir_intrinsic_optimization_barrier_vgpr_amd:
is_divergent = instr->src[0].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state);
break;
/* Intrinsics with divergence depending on sources */
@ -661,7 +678,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_bindless_resource_ir3: {
unsigned num_srcs = nir_intrinsic_infos[instr->intrinsic].num_srcs;
for (unsigned i = 0; i < num_srcs; i++) {
if (instr->src[i].ssa->divergent) {
if (src_divergent(instr->src[i], state)) {
is_divergent = true;
break;
}
@ -678,7 +695,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
nir_resource_intel_non_uniform) != 0) {
unsigned num_srcs = nir_intrinsic_infos[instr->intrinsic].num_srcs;
for (unsigned i = 0; i < num_srcs; i++) {
if (instr->src[i].ssa->divergent) {
if (src_divergent(instr->src[i], state)) {
is_divergent = true;
break;
}
@ -687,8 +704,8 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
break;
case nir_intrinsic_shuffle:
is_divergent = instr->src[0].ssa->divergent &&
instr->src[1].ssa->divergent;
is_divergent = src_divergent(instr->src[0], state) &&
src_divergent(instr->src[1], state);
break;
/* Intrinsics which are always divergent */
@ -895,17 +912,17 @@ visit_tex(nir_tex_instr *instr, struct divergence_state *state)
case nir_tex_src_sampler_deref:
case nir_tex_src_sampler_handle:
case nir_tex_src_sampler_offset:
is_divergent |= instr->src[i].src.ssa->divergent &&
is_divergent |= src_divergent(instr->src[i].src, state) &&
instr->sampler_non_uniform;
break;
case nir_tex_src_texture_deref:
case nir_tex_src_texture_handle:
case nir_tex_src_texture_offset:
is_divergent |= instr->src[i].src.ssa->divergent &&
is_divergent |= src_divergent(instr->src[i].src, state) &&
instr->texture_non_uniform;
break;
default:
is_divergent |= instr->src[i].src.ssa->divergent;
is_divergent |= src_divergent(instr->src[i].src, state);
break;
}
}
@ -993,15 +1010,15 @@ visit_deref(nir_shader *shader, nir_deref_instr *deref,
break;
case nir_deref_type_array:
case nir_deref_type_ptr_as_array:
is_divergent = deref->arr.index.ssa->divergent;
is_divergent = src_divergent(deref->arr.index, state);
FALLTHROUGH;
case nir_deref_type_struct:
case nir_deref_type_array_wildcard:
is_divergent |= deref->parent.ssa->divergent;
is_divergent |= src_divergent(deref->parent, state);
break;
case nir_deref_type_cast:
is_divergent = !nir_variable_mode_is_uniform(deref->var->data.mode) ||
deref->parent.ssa->divergent;
src_divergent(deref->parent, state);
break;
}
@ -1149,7 +1166,7 @@ visit_if_merge_phi(nir_phi_instr *phi, bool if_cond_divergent, bool ignore_undef
unsigned defined_srcs = 0;
nir_foreach_phi_src(src, phi) {
/* if any source value is divergent, the resulting value is divergent */
if (src->src.ssa->divergent) {
if (nir_src_is_divergent(&src->src)) {
phi->def.divergent = true;
return true;
}
@ -1181,7 +1198,7 @@ visit_loop_header_phi(nir_phi_instr *phi, nir_block *preheader, bool divergent_c
nir_def *same = NULL;
nir_foreach_phi_src(src, phi) {
/* if any source value is divergent, the resulting value is divergent */
if (src->src.ssa->divergent) {
if (nir_src_is_divergent(&src->src)) {
phi->def.divergent = true;
return true;
}
@ -1223,7 +1240,7 @@ visit_loop_exit_phi(nir_phi_instr *phi, bool divergent_break)
/* if any source value is divergent, the resulting value is divergent */
nir_foreach_phi_src(src, phi) {
if (src->src.ssa->divergent) {
if (nir_src_is_divergent(&src->src)) {
phi->def.divergent = true;
return true;
}
@ -1236,13 +1253,14 @@ static bool
visit_if(nir_if *if_stmt, struct divergence_state *state)
{
bool progress = false;
bool cond_divergent = src_divergent(if_stmt->condition, state);
struct divergence_state then_state = *state;
then_state.divergent_loop_cf |= if_stmt->condition.ssa->divergent;
then_state.divergent_loop_cf |= cond_divergent;
progress |= visit_cf_list(&if_stmt->then_list, &then_state);
struct divergence_state else_state = *state;
else_state.divergent_loop_cf |= if_stmt->condition.ssa->divergent;
else_state.divergent_loop_cf |= cond_divergent;
progress |= visit_cf_list(&if_stmt->else_list, &else_state);
/* handle phis after the IF */
@ -1254,7 +1272,7 @@ visit_if(nir_if *if_stmt, struct divergence_state *state)
invariant && nir_foreach_src(&phi->instr, src_invariant, state->loop);
}
bool ignore_undef = state->options & nir_divergence_ignore_undef_if_phi_srcs;
progress |= visit_if_merge_phi(phi, if_stmt->condition.ssa->divergent, ignore_undef);
progress |= visit_if_merge_phi(phi, cond_divergent, ignore_undef);
}
/* join loop divergence information from both branch legs */
@ -1288,7 +1306,7 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
phi->def.loop_invariant = false;
nir_foreach_phi_src(src, phi) {
if (src->pred == loop_preheader) {
phi->def.divergent = src->src.ssa->divergent;
phi->def.divergent = nir_src_is_divergent(&src->src);
break;
}
}
@ -1318,6 +1336,9 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
loop_state.first_visit = false;
} while (repeat);
loop->divergent_continue = loop_state.divergent_loop_continue;
loop->divergent_break = loop_state.divergent_loop_break;
/* handle phis after the loop */
nir_foreach_phi(phi, nir_cf_node_cf_tree_next(&loop->cf_node)) {
if (state->first_visit) {
@ -1327,9 +1348,6 @@ visit_loop(nir_loop *loop, struct divergence_state *state)
progress |= visit_loop_exit_phi(phi, loop_state.divergent_loop_break);
}
loop->divergent_continue = loop_state.divergent_loop_continue;
loop->divergent_break = loop_state.divergent_loop_break;
return progress;
}