diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index c77cd52a681..b7287bf6f57 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4840,6 +4840,11 @@ typedef struct nir_lower_shader_calls_options {
 
    /* Stack alignment */
    unsigned stack_alignment;
+
+   /* Put loads from the stack as close as possible to where they're needed.
+    * You might want to disable combined_loads for best results.
+    */
+   bool localized_loads;
 } nir_lower_shader_calls_options;
 
 bool
diff --git a/src/compiler/nir/nir_lower_shader_calls.c b/src/compiler/nir/nir_lower_shader_calls.c
index 49871179908..31eb0e5c067 100644
--- a/src/compiler/nir/nir_lower_shader_calls.c
+++ b/src/compiler/nir/nir_lower_shader_calls.c
@@ -1685,6 +1685,96 @@ nir_opt_sort_and_pack_stack(nir_shader *shader,
    return true;
 }
 
+/* Find the last block dominating all the uses of an SSA value. */
+static nir_block *
+find_last_dominant_use_block(nir_function_impl *impl, nir_ssa_def *value)
+{
+   nir_foreach_block_reverse_safe(block, impl) {
+      bool fits = true;
+
+      /* Reached the value's defining block; the load cannot move past it. */
+      if (block == value->parent_instr->block)
+         return block;
+
+      nir_foreach_if_use(src, value) {
+         nir_block *block_before_if =
+            nir_cf_node_as_block(nir_cf_node_prev(&src->parent_if->cf_node));
+         if (!nir_block_dominates(block, block_before_if)) {
+            fits = false;
+            break;
+         }
+      }
+      if (!fits)
+         continue;
+
+      nir_foreach_use(src, value) {
+         if (src->parent_instr->type == nir_instr_type_phi &&
+             block == src->parent_instr->block) {
+            fits = false;
+            break;
+         }
+
+         if (!nir_block_dominates(block, src->parent_instr->block)) {
+            fits = false;
+            break;
+         }
+      }
+      if (!fits)
+         continue;
+
+      return block;
+   }
+   unreachable("Cannot find block");
+}
+
+/* Put the scratch loads in the branches where they're needed. */
+static bool
+nir_opt_stack_loads(nir_shader *shader)
+{
+   bool progress = false;
+
+   nir_foreach_function(func, shader) {
+      if (!func->impl)
+         continue;
+
+      nir_metadata_require(func->impl, nir_metadata_dominance |
+                                       nir_metadata_block_index);
+
+      bool func_progress = false;
+      nir_foreach_block_safe(block, func->impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_intrinsic)
+               continue;
+
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            if (intrin->intrinsic != nir_intrinsic_load_stack)
+               continue;
+
+            nir_ssa_def *value = &intrin->dest.ssa;
+            nir_block *new_block = find_last_dominant_use_block(func->impl, value);
+            if (new_block == block)
+               continue;
+
+            /* Move the scratch load into the new block, after the phis. */
+            nir_instr_remove(instr);
+            nir_instr_insert(nir_before_block_after_phis(new_block), instr);
+
+            func_progress = true;
+         }
+      }
+
+      nir_metadata_preserve(func->impl,
+                            func_progress ? (nir_metadata_block_index |
+                                             nir_metadata_dominance |
+                                             nir_metadata_loop_analysis) :
+                                            nir_metadata_all);
+
+      progress |= func_progress;
+   }
+
+   return progress;
+}
+
 /** Lower shader call instructions to split shaders.
  *
  * Shader calls can be split into an initial shader and a series of "resume"
@@ -1785,8 +1875,15 @@ nir_lower_shader_calls(nir_shader *shader,
    for (unsigned i = 0; i < num_calls; i++)
       NIR_PASS_V(resume_shaders[i], nir_opt_remove_respills);
 
-   NIR_PASS_V(shader, nir_lower_stack_to_scratch,
-              options->address_format);
+   if (options->localized_loads) {
+      /* Once loads have been combined, we can try to put them closer to
+       * where they're needed.
+       */
+      for (unsigned i = 0; i < num_calls; i++)
+         NIR_PASS_V(resume_shaders[i], nir_opt_stack_loads);
+   }
+
+   NIR_PASS_V(shader, nir_lower_stack_to_scratch, options->address_format);
    for (unsigned i = 0; i < num_calls; i++) {
       NIR_PASS_V(resume_shaders[i], nir_lower_stack_to_scratch,
                  options->address_format);
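
Caller-side note: a driver opts into this pass through the new flag on the options struct it already passes to nir_lower_shader_calls(). The sketch below is illustrative only and assumes the options-struct form of the entry point shown above; the address format, stack alignment, and mem_ctx are placeholder values, not part of this patch.

/* Hypothetical usage sketch (not part of this patch). */
const nir_lower_shader_calls_options opts = {
   .address_format = nir_address_format_64bit_global, /* placeholder */
   .stack_alignment = 16,                             /* placeholder */
   .localized_loads = true,                           /* new option */
};

nir_shader **resume_shaders = NULL;
uint32_t num_resume_shaders = 0;
nir_lower_shader_calls(shader, &opts,
                       &resume_shaders, &num_resume_shaders, mem_ctx);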