diff --git a/src/compiler/nir/nir_lower_shader_calls.c b/src/compiler/nir/nir_lower_shader_calls.c
index 8bf51c56acb..49871179908 100644
--- a/src/compiler/nir/nir_lower_shader_calls.c
+++ b/src/compiler/nir/nir_lower_shader_calls.c
@@ -1550,6 +1550,182 @@ nir_opt_trim_stack_values(nir_shader *shader)
    return progress;
 }
 
+/** One value spilled to the call stack, as referenced by load_stack. */
+struct scratch_item {
+   /* Base offset of the load_stack before repacking (recorded only). */
+   unsigned old_offset;
+   /* Base offset assigned by nir_opt_sort_and_pack_stack. */
+   unsigned new_offset;
+   /* Bit size of one component. */
+   unsigned bit_size;
+   unsigned num_components;
+   /* value_id of the load_stack/store_stack intrinsics. */
+   unsigned value;
+   /* Left zero-initialized; not read by the packing pass. */
+   unsigned call_idx;
+};
+
+/**
+ * qsort() comparator: largest bit size first, ties broken by ascending
+ * value id so the resulting order is fully determined.
+ */
+static int
+sort_scratch_item_by_size_and_value_id(const void *_item1, const void *_item2)
+{
+   const struct scratch_item *item1 = _item1;
+   const struct scratch_item *item2 = _item2;
+
+   /* By descending size */
+   if (item1->bit_size != item2->bit_size)
+      return item1->bit_size < item2->bit_size ? 1 : -1;
+
+   /* By ascending value_id. Compare rather than subtract so that large
+    * unsigned ids cannot overflow the int result.
+    */
+   return (item1->value > item2->value) - (item1->value < item2->value);
+}
+
+/**
+ * Repack the values spilled around each shader call.
+ *
+ * For every call index, the values read back by load_stack are sorted by
+ * decreasing bit size and laid out back to back starting at
+ * start_call_scratch, each aligned to its component size. The base of the
+ * matching load_stack/store_stack intrinsics and the stack_size of the
+ * rt_trace_ray/rt_execute_callable/rt_resume intrinsics of that call are
+ * then rewritten, and shader->scratch_size is set to the packed size
+ * rounded up to stack_alignment.
+ *
+ * Always returns true.
+ */
+static bool
+nir_opt_sort_and_pack_stack(nir_shader *shader,
+                            unsigned start_call_scratch,
+                            unsigned stack_alignment,
+                            unsigned num_calls)
+{
+   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
+
+   void *mem_ctx = ralloc_context(NULL);
+
+   struct hash_table_u64 *value_id_to_item =
+      _mesa_hash_table_u64_create(mem_ctx);
+   struct util_dynarray ops;
+   util_dynarray_init(&ops, mem_ctx);
+
+   for (unsigned call_idx = 0; call_idx < num_calls; call_idx++) {
+      _mesa_hash_table_u64_clear(value_id_to_item);
+      util_dynarray_clear(&ops);
+
+      /* Find all the stack loads and their offsets. */
+      nir_foreach_block_safe(block, impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_intrinsic)
+               continue;
+
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            if (intrin->intrinsic != nir_intrinsic_load_stack)
+               continue;
+
+            if (nir_intrinsic_call_idx(intrin) != call_idx)
+               continue;
+
+            const unsigned value_id = nir_intrinsic_value_id(intrin);
+            nir_ssa_def *def = nir_instr_ssa_def(instr);
+
+            /* Each value is expected to be loaded once per call. The table
+             * only serves as a "seen" set here, hence the dummy payload.
+             */
+            assert(_mesa_hash_table_u64_search(value_id_to_item,
+                                               value_id) == NULL);
+
+            struct scratch_item item = {
+               .old_offset = nir_intrinsic_base(intrin),
+               .bit_size = def->bit_size,
+               .num_components = def->num_components,
+               .value = value_id,
+            };
+
+            util_dynarray_append(&ops, struct scratch_item, item);
+            _mesa_hash_table_u64_insert(value_id_to_item, value_id, (void *)(uintptr_t)true);
+         }
+      }
+
+      /* Sort scratch items by component size. qsort() must not be given a
+       * NULL base pointer, which an empty dynarray may hold, so skip the
+       * call when there is nothing to sort.
+       */
+      const unsigned num_items =
+         util_dynarray_num_elements(&ops, struct scratch_item);
+      if (num_items > 0) {
+         qsort(util_dynarray_begin(&ops), num_items,
+               sizeof(struct scratch_item),
+               sort_scratch_item_by_size_and_value_id);
+      }
+
+      /* Reorder things on the stack. The table now maps each value id to
+       * its item; ops is not resized past this point, so the item pointers
+       * stay valid.
+       */
+      _mesa_hash_table_u64_clear(value_id_to_item);
+
+      unsigned scratch_size = start_call_scratch;
+      util_dynarray_foreach(&ops, struct scratch_item, item) {
+         item->new_offset = ALIGN(scratch_size, item->bit_size / 8);
+         scratch_size = item->new_offset + (item->bit_size * item->num_components) / 8;
+         _mesa_hash_table_u64_insert(value_id_to_item, item->value, item);
+      }
+
+      /* NOTE(review): this is overwritten on every iteration, so the shader
+       * ends up with the packed size of the last call, not the largest
+       * one. Confirm that is what users of scratch_size expect.
+       */
+      shader->scratch_size = ALIGN(scratch_size, stack_alignment);
+
+      /* Update offsets in the instructions */
+      nir_foreach_block_safe(block, impl) {
+         nir_foreach_instr_safe(instr, block) {
+            if (instr->type != nir_instr_type_intrinsic)
+               continue;
+
+            nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
+            switch (intrin->intrinsic) {
+            case nir_intrinsic_load_stack:
+            case nir_intrinsic_store_stack: {
+               if (nir_intrinsic_call_idx(intrin) != call_idx)
+                  continue;
+
+               struct scratch_item *item =
+                  _mesa_hash_table_u64_search(value_id_to_item,
+                                              nir_intrinsic_value_id(intrin));
+               assert(item);
+
+               nir_intrinsic_set_base(intrin, item->new_offset);
+               break;
+            }
+
+            case nir_intrinsic_rt_trace_ray:
+            case nir_intrinsic_rt_execute_callable:
+            case nir_intrinsic_rt_resume:
+               if (nir_intrinsic_call_idx(intrin) != call_idx)
+                  continue;
+               nir_intrinsic_set_stack_size(intrin, shader->scratch_size);
+               break;
+
+            default:
+               break;
+            }
+         }
+      }
+   }
+
+   ralloc_free(mem_ctx);
+
+   nir_shader_preserve_all_metadata(shader);
+
+   return true;
+}
+
 /** Lower shader call instructions to split shaders.
  *
  * Shader calls can be split into an initial shader and a series of "resume"
@@ -1609,12 +1785,17 @@ nir_lower_shader_calls(nir_shader *shader,
       NIR_PASS(progress, shader, nir_opt_cse);
    }
 
+   /* Save the start point of the call stack in scratch */
+   unsigned start_call_scratch = shader->scratch_size;
+
    NIR_PASS_V(shader, spill_ssa_defs_and_lower_shader_calls,
               num_calls, options->stack_alignment);
 
    NIR_PASS_V(shader, nir_opt_remove_phis);
 
    NIR_PASS_V(shader, nir_opt_trim_stack_values);
+   NIR_PASS_V(shader, nir_opt_sort_and_pack_stack,
+              start_call_scratch, options->stack_alignment, num_calls);
 
    /* Make N copies of our shader */
    nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);