diff --git a/src/compiler/nir/nir_lower_shader_calls.c b/src/compiler/nir/nir_lower_shader_calls.c index 582a6d4061c..591308ba790 100644 --- a/src/compiler/nir/nir_lower_shader_calls.c +++ b/src/compiler/nir/nir_lower_shader_calls.c @@ -24,6 +24,7 @@ #include "nir.h" #include "nir_builder.h" #include "nir_phi_builder.h" +#include "util/u_dynarray.h" #include "util/u_math.h" static bool @@ -216,14 +217,143 @@ can_remat_ssa_def(nir_ssa_def *def, struct brw_bitset *remat) return can_remat_instr(def->parent_instr, remat); } -static nir_ssa_def * -remat_ssa_def(nir_builder *b, nir_ssa_def *def) +struct add_instr_data { + struct util_dynarray *buf; + struct brw_bitset *remat; +}; + +static bool +add_src_instr(nir_src *src, void *state) { - nir_instr *clone = nir_instr_clone(b->shader, def->parent_instr); + if (!src->is_ssa) + return false; + + struct add_instr_data *data = state; + if (BITSET_TEST(data->remat->set, src->ssa->index)) + return true; + + util_dynarray_foreach(data->buf, nir_instr *, instr_ptr) { + if (*instr_ptr == src->ssa->parent_instr) + return true; + } + + util_dynarray_append(data->buf, nir_instr *, src->ssa->parent_instr); + return true; +} + +static int +compare_instr_indexes(const void *_inst1, const void *_inst2) +{ + const nir_instr * const *inst1 = _inst1; + const nir_instr * const *inst2 = _inst2; + + return (*inst1)->index - (*inst2)->index; +} + +static bool +can_remat_chain_ssa_def(nir_ssa_def *def, struct brw_bitset *remat, struct util_dynarray *buf) +{ + assert(util_dynarray_num_elements(buf, nir_instr *) == 0); + + void *mem_ctx = ralloc_context(NULL); + + /* Add all the instructions involved in build this ssa_def */ + util_dynarray_append(buf, nir_instr *, def->parent_instr); + + unsigned idx = 0; + struct add_instr_data data = { + .buf = buf, + .remat = remat, + }; + while (idx < util_dynarray_num_elements(buf, nir_instr *)) { + nir_instr *instr = *util_dynarray_element(buf, nir_instr *, idx++); + if (!nir_foreach_src(instr, add_src_instr, &data)) + goto fail; + } + + /* Sort instructions by index */ + qsort(util_dynarray_begin(buf), + util_dynarray_num_elements(buf, nir_instr *), + sizeof(nir_instr *), + compare_instr_indexes); + + /* Create a temporary bitset with all values already + * rematerialized/rematerializable. We'll add to this bit set as we go + * through values that might not be in that set but that we can + * rematerialize. + */ + struct brw_bitset potential_remat = bitset_create(mem_ctx, remat->size); + memcpy(potential_remat.set, remat->set, BITSET_WORDS(remat->size) * sizeof(BITSET_WORD)); + + util_dynarray_foreach(buf, nir_instr *, instr_ptr) { + nir_ssa_def *instr_ssa_def = nir_instr_ssa_def(*instr_ptr); + + /* If already in the potential rematerializable, nothing to do. */ + if (BITSET_TEST(potential_remat.set, instr_ssa_def->index)) + continue; + + if (!can_remat_instr(*instr_ptr, &potential_remat)) + goto fail; + + /* All the sources are rematerializable and the instruction is also + * rematerializable, mark it as rematerializable too. + */ + BITSET_SET(potential_remat.set, instr_ssa_def->index); + } + + ralloc_free(mem_ctx); + + return true; + + fail: + util_dynarray_clear(buf); + ralloc_free(mem_ctx); + return false; +} + +static nir_ssa_def * +remat_ssa_def(nir_builder *b, nir_ssa_def *def, struct hash_table *remap_table) +{ + nir_instr *clone = nir_instr_clone_deep(b->shader, def->parent_instr, remap_table); nir_builder_instr_insert(b, clone); return nir_instr_ssa_def(clone); } +static nir_ssa_def * +remat_chain_ssa_def(nir_builder *b, struct util_dynarray *buf, + struct brw_bitset *remat, nir_ssa_def ***fill_defs, + unsigned call_idx, struct hash_table *remap_table) +{ + nir_ssa_def *last_def = NULL; + + util_dynarray_foreach(buf, nir_instr *, instr_ptr) { + nir_ssa_def *instr_ssa_def = nir_instr_ssa_def(*instr_ptr); + unsigned ssa_index = instr_ssa_def->index; + + if (fill_defs[ssa_index] != NULL && + fill_defs[ssa_index][call_idx] != NULL) + continue; + + /* Clone the instruction we want to rematerialize */ + nir_ssa_def *clone_ssa_def = remat_ssa_def(b, instr_ssa_def, remap_table); + + if (fill_defs[ssa_index] == NULL) { + fill_defs[ssa_index] = + rzalloc_array(fill_defs, nir_ssa_def *, remat->size); + } + + /* Add the new ssa_def to the list fill_defs and flag it as + * rematerialized + */ + fill_defs[ssa_index][call_idx] = last_def = clone_ssa_def; + BITSET_SET(remat->set, ssa_index); + + _mesa_hash_table_insert(remap_table, instr_ssa_def, last_def); + } + + return last_def; +} + struct pbv_array { struct nir_phi_builder_value **arr; unsigned len; @@ -317,7 +447,8 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, nir_metadata_require(impl, nir_metadata_live_ssa_defs | nir_metadata_dominance | - nir_metadata_block_index); + nir_metadata_block_index | + nir_metadata_instr_index); void *mem_ctx = ralloc_context(shader); @@ -342,6 +473,10 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, /* For each call instruction, the block index of the block it lives in */ uint32_t *call_block_indices = rzalloc_array(mem_ctx, uint32_t, num_calls); + /* Remap table when rebuilding instructions out of fill operations */ + struct hash_table *trivial_remap_table = + _mesa_pointer_hash_table_create(mem_ctx); + /* Walk the call instructions and fetch the liveness set and block index * for each one. We need to do this before we start modifying the shader * so that liveness doesn't complain that it's been invalidated. Don't @@ -382,6 +517,7 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, if (def != NULL) { if (can_remat_ssa_def(def, &trivial_remat)) { add_ssa_def_to_bitset(def, &trivial_remat); + _mesa_hash_table_insert(trivial_remap_table, def, def); } else { spill_defs[def->index] = def; } @@ -392,6 +528,9 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, const BITSET_WORD *live = call_live[call_idx]; + struct hash_table *remap_table = + _mesa_hash_table_clone(trivial_remap_table, mem_ctx); + /* Make a copy of trivial_remat that we'll update as we crawl through * the live SSA defs and unspill them. */ @@ -404,6 +543,12 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, before.cursor = nir_before_instr(instr); after.cursor = nir_after_instr(instr); + /* Array used to hold all the values needed to rematerialize a live + * value. + */ + struct util_dynarray remat_chain; + util_dynarray_init(&remat_chain, mem_ctx); + unsigned offset = shader->scratch_size; for (unsigned w = 0; w < live_words; w++) { BITSET_WORD spill_mask = live[w] & ~trivial_remat.set[w]; @@ -413,7 +558,8 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, unsigned index = w * BITSET_WORDBITS + i; assert(index < num_ssa_defs); - nir_ssa_def *def = spill_defs[index]; + def = spill_defs[index]; + nir_ssa_def *original_def = def, *new_def; if (can_remat_ssa_def(def, &remat)) { /* If this SSA def is re-materializable or based on other * things we've already spilled, re-materialize it rather @@ -421,7 +567,12 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, * re-materializable won't even get here because we take * those into account in spill_mask above. */ - def = remat_ssa_def(&after, def); + new_def = remat_ssa_def(&after, def, remap_table); + } else if (can_remat_chain_ssa_def(def, &remat, &remat_chain)) { + new_def = remat_chain_ssa_def(&after, &remat_chain, &remat, + fill_defs, call_idx, + remap_table); + util_dynarray_clear(&remat_chain); } else { bool is_bool = def->bit_size == 1; if (is_bool) @@ -430,12 +581,12 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, const unsigned comp_size = def->bit_size / 8; offset = ALIGN(offset, comp_size); - def = spill_fill(&before, &after, def, - index, call_idx, - offset, stack_alignment); + new_def = spill_fill(&before, &after, def, + index, call_idx, + offset, stack_alignment); if (is_bool) - def = nir_b2b1(&after, def); + new_def = nir_b2b1(&after, new_def); offset += def->num_components * comp_size; } @@ -451,9 +602,10 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, */ if (fill_defs[index] == NULL) { fill_defs[index] = - rzalloc_array(mem_ctx, nir_ssa_def *, num_calls); + rzalloc_array(fill_defs, nir_ssa_def *, num_calls); } - fill_defs[index][call_idx] = def; + fill_defs[index][call_idx] = new_def; + _mesa_hash_table_insert(remap_table, original_def, new_def); } }