From 5717f13dffe51431975e22bbec4e9328514c2f1b Mon Sep 17 00:00:00 2001
From: Lionel Landwerlin
Date: Mon, 22 Aug 2022 10:33:40 +0300
Subject: [PATCH] nir/lower_shader_calls: add a pass to sort/pack values on
 the stack

The previous pass, which trims values stored on the stack, might have
left gaps on the stack (a vec4 turned into a vec3, for instance).

This pass reorders values on the stack by component bit size and by SSA
value number. Sorting by component bit size packs smaller values
together. The SSA value number also matters: if two calls spill the
same values, we can avoid re-emitting the spills when the values are
stored at the same location.

v2: Remove unused sorting function (Konstantin)

Signed-off-by: Lionel Landwerlin
Reviewed-by: Konstantin Seurer
Part-of:
---
 src/compiler/nir/nir_lower_shader_calls.c | 140 ++++++++++++++++++++++
 1 file changed, 140 insertions(+)

diff --git a/src/compiler/nir/nir_lower_shader_calls.c b/src/compiler/nir/nir_lower_shader_calls.c
index 8bf51c56acb..49871179908 100644
--- a/src/compiler/nir/nir_lower_shader_calls.c
+++ b/src/compiler/nir/nir_lower_shader_calls.c
@@ -1550,6 +1550,141 @@ nir_opt_trim_stack_values(nir_shader *shader)
    return progress;
 }
 
+struct scratch_item {
+   unsigned old_offset;
+   unsigned new_offset;
+   unsigned bit_size;
+   unsigned num_components;
+   unsigned value;
+   unsigned call_idx;
+};
+
+static int
+sort_scratch_item_by_size_and_value_id(const void *_item1, const void *_item2)
+{
+   const struct scratch_item *item1 = _item1;
+   const struct scratch_item *item2 = _item2;
+
+   /* By ascending value_id */
+   if (item1->bit_size == item2->bit_size)
+      return (int) item1->value - (int) item2->value;
+
+   /* By descending size */
+   return (int) item2->bit_size - (int) item1->bit_size;
+}
+
+static bool
+nir_opt_sort_and_pack_stack(nir_shader *shader,
+                            unsigned start_call_scratch,
+                            unsigned stack_alignment,
+                            unsigned num_calls)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   void *mem_ctx = ralloc_context(NULL);
+
+   struct hash_table_u64 *value_id_to_item =
+      _mesa_hash_table_u64_create(mem_ctx);
+   struct util_dynarray ops;
+   util_dynarray_init(&ops, mem_ctx);
+
+   for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
+      _mesa_hash_table_u64_clear(value_id_to_item);
+      util_dynarray_clear(&ops);
+
+      /* Find all the stack loads and their offsets. */
+      nir_foreach_block_safe(block, impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_intrinsic)
+               continue;
+
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            if (intrin->intrinsic != nir_intrinsic_load_stack)
+               continue;
+
+            if (nir_intrinsic_call_idx(intrin) != call_idx)
+               continue;
+
+            const unsigned value_id = nir_intrinsic_value_id(intrin);
+            nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+            assert(_mesa_hash_table_u64_search(value_id_to_item,
+                                               value_id) == NULL);
+
+            struct scratch_item item = {
+               .old_offset = nir_intrinsic_base(intrin),
+               .bit_size = def->bit_size,
+               .num_components = def->num_components,
+               .value = value_id,
+            };
+
+            util_dynarray_append(&ops, struct scratch_item, item);
+            _mesa_hash_table_u64_insert(value_id_to_item, value_id,
+                                        (void *)(uintptr_t)true);
+         }
+      }
+
+      /* Sort scratch items by component bit size and then by value id. */
+      qsort(util_dynarray_begin(&ops),
+            util_dynarray_num_elements(&ops, struct scratch_item),
+            sizeof(struct scratch_item),
+            sort_scratch_item_by_size_and_value_id);
+
+      /* Reorder things on the stack */
+      _mesa_hash_table_u64_clear(value_id_to_item);
+
+      unsigned scratch_size = start_call_scratch;
+      util_dynarray_foreach(&ops, struct scratch_item, item) {
+         item->new_offset = ALIGN(scratch_size, item->bit_size / 8);
+         scratch_size = item->new_offset +
+                        (item->bit_size * item->num_components) / 8;
+         _mesa_hash_table_u64_insert(value_id_to_item, item->value, item);
+      }
+      shader->scratch_size = ALIGN(scratch_size, stack_alignment);
+
+      /* Update offsets in the instructions */
+      nir_foreach_block_safe(block, impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_intrinsic)
+               continue;
+
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            switch (intrin->intrinsic) {
+            case nir_intrinsic_load_stack:
+            case nir_intrinsic_store_stack: {
+               if (nir_intrinsic_call_idx(intrin) != call_idx)
+                  continue;
+
+               struct scratch_item *item =
+                  _mesa_hash_table_u64_search(value_id_to_item,
+                                              nir_intrinsic_value_id(intrin));
+               assert(item);
+
+               nir_intrinsic_set_base(intrin, item->new_offset);
+               break;
+            }
+
+            case nir_intrinsic_rt_trace_ray:
+            case nir_intrinsic_rt_execute_callable:
+            case nir_intrinsic_rt_resume:
+               if (nir_intrinsic_call_idx(intrin) != call_idx)
+                  continue;
+               nir_intrinsic_set_stack_size(intrin, shader->scratch_size);
+               break;
+
+            default:
+               break;
+            }
+         }
+      }
+   }
+
+   ralloc_free(mem_ctx);
+
+   nir_shader_preserve_all_metadata(shader);
+
+   return true;
+}
+
 /** Lower shader call instructions to split shaders.
  *
  * Shader calls can be split into an initial shader and a series of "resume"
@@ -1609,12 +1744,17 @@ nir_lower_shader_calls(nir_shader *shader,
       NIR_PASS(progress, shader, nir_opt_cse);
    }
 
+   /* Save the start point of the call stack in scratch */
+   unsigned start_call_scratch = shader->scratch_size;
+
    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
               num_calls, options->stack_alignment);
 
    NIR_PASS_V(shader, nir_opt_remove_phis);
 
    NIR_PASS_V(shader, nir_opt_trim_stack_values);
+   NIR_PASS_V(shader, nir_opt_sort_and_pack_stack,
+              start_call_scratch, options->stack_alignment, num_calls);
 
    /* Make N copies of our shader */
    nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
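
To make the layout scheme concrete, here is a minimal standalone sketch of
the sort-and-pack arithmetic, runnable outside of Mesa. Everything in it is
illustrative: struct item, item_cmp and the sample values are made up for
the example, and ALIGN stands in for Mesa's power-of-two alignment macro.

#include <stdio.h>
#include <stdlib.h>

/* Align "x" up to "a", assuming "a" is a power of two (the bit sizes
 * 8/16/32/64 guarantee that here). */
#define ALIGN(x, a) (((x) + (a) - 1) & ~((a) - 1))

struct item {
   unsigned value_id;       /* SSA value id, stable tie-breaker */
   unsigned bit_size;       /* component size in bits */
   unsigned num_components; /* e.g. 3 for a vec3 left by trimming */
   unsigned offset;         /* computed byte offset on the stack */
};

/* Descending bit size, then ascending value id.  The value-id order means
 * two calls spilling the same values produce identical layouts. */
static int
item_cmp(const void *_a, const void *_b)
{
   const struct item *a = _a, *b = _b;
   if (a->bit_size == b->bit_size)
      return (int)a->value_id - (int)b->value_id;
   return (int)b->bit_size - (int)a->bit_size;
}

int
main(void)
{
   struct item items[] = {
      { .value_id = 7, .bit_size = 16, .num_components = 1 },
      { .value_id = 3, .bit_size = 32, .num_components = 3 }, /* trimmed vec3 */
      { .value_id = 5, .bit_size = 32, .num_components = 1 },
   };
   const unsigned count = sizeof(items) / sizeof(items[0]);

   qsort(items, count, sizeof(items[0]), item_cmp);

   /* Pack each item at the next offset aligned to its component size. */
   unsigned offset = 0; /* the pass starts at start_call_scratch instead */
   for (unsigned i = 0; i < count; i++) {
      items[i].offset = ALIGN(offset, items[i].bit_size / 8);
      offset = items[i].offset +
               (items[i].bit_size / 8) * items[i].num_components;
      printf("value %u: %u x %u bits at byte %u\n",
             items[i].value_id, items[i].num_components,
             items[i].bit_size, items[i].offset);
   }
   printf("stack size before final alignment: %u bytes\n", offset);
   return 0;
}

With these inputs the vec3 (value 3) lands at byte 0, value 5 at byte 12
and the 16-bit value 7 at byte 16, i.e. 18 bytes with no internal padding;
the pass then rounds the total up to stack_alignment, just as
shader->scratch_size is aligned above.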