nir/lower_shader_calls: add a pass to trim scratch values

For example, if we store to scratch a vec4 but only a subset of
components are used after the load operation.

v2: Use nir_intrinsic_write_mask (Konstantin)
    Use u_foreach_bit() instead of u_bit_scan() (Konstantin)
    Fix mask building loop (Konstantin)

v3: Fix reswizzle (Konstantin)

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16556>
This commit is contained in:
Lionel Landwerlin 2022-08-17 16:23:51 +03:00 committed by Marge Bot
parent 1d10d17817
commit 4cd90ed7bc

View file

@ -1400,6 +1400,156 @@ nir_opt_remove_respills(nir_shader *shader)
NULL);
}
static void
add_use_mask(struct hash_table_u64 *offset_to_mask,
unsigned offset, unsigned mask)
{
uintptr_t old_mask = (uintptr_t)
_mesa_hash_table_u64_search(offset_to_mask, offset);
_mesa_hash_table_u64_insert(offset_to_mask, offset,
(void *)(uintptr_t)(old_mask | mask));
}
/* When splitting the shaders, we might have inserted store & loads of vec4s,
* because a live value is a 4 components. But sometimes, only some components
* of that vec4 will be used by after the scratch load. This pass removes the
* unused components of scratch load/stores.
*/
static bool
nir_opt_trim_stack_values(nir_shader *shader)
{
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
struct hash_table_u64 *value_id_to_mask = _mesa_hash_table_u64_create(NULL);
bool progress = false;
/* Find all the loads and how their value is being used */
nir_foreach_block_safe(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_load_stack)
continue;
const unsigned value_id = nir_intrinsic_value_id(intrin);
const unsigned mask =
nir_ssa_def_components_read(nir_instr_ssa_def(instr));
add_use_mask(value_id_to_mask, value_id, mask);
}
}
/* For each store, if it stores more than is being used, trim it.
* Otherwise, remove it from the hash table.
*/
nir_foreach_block_safe(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_store_stack)
continue;
const unsigned value_id = nir_intrinsic_value_id(intrin);
const unsigned write_mask = nir_intrinsic_write_mask(intrin);
const unsigned read_mask = (uintptr_t)
_mesa_hash_table_u64_search(value_id_to_mask, value_id);
/* Already removed from the table, nothing to do */
if (read_mask == 0)
continue;
/* Matching read/write mask, nothing to do, remove from the table. */
if (write_mask == read_mask) {
_mesa_hash_table_u64_remove(value_id_to_mask, value_id);
continue;
}
nir_builder b;
nir_builder_init(&b, impl);
b.cursor = nir_before_instr(instr);
nir_ssa_def *value = nir_channels(&b, intrin->src[0].ssa, read_mask);
nir_instr_rewrite_src_ssa(instr, &intrin->src[0], value);
intrin->num_components = util_bitcount(read_mask);
nir_intrinsic_set_write_mask(intrin, (1u << intrin->num_components) - 1);
progress = true;
}
}
/* For each load remaining in the hash table (only the ones we changed the
* number of components of), apply triming/reswizzle.
*/
nir_foreach_block_safe(block, impl) {
nir_foreach_instr_safe(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_load_stack)
continue;
const unsigned value_id = nir_intrinsic_value_id(intrin);
unsigned read_mask = (uintptr_t)
_mesa_hash_table_u64_search(value_id_to_mask, value_id);
if (read_mask == 0)
continue;
unsigned swiz_map[NIR_MAX_VEC_COMPONENTS] = { 0, };
unsigned swiz_count = 0;
u_foreach_bit(idx, read_mask)
swiz_map[idx] = swiz_count++;
nir_ssa_def *def = nir_instr_ssa_def(instr);
nir_foreach_use_safe(use_src, def) {
if (use_src->parent_instr->type == nir_instr_type_alu) {
nir_alu_instr *alu = nir_instr_as_alu(use_src->parent_instr);
nir_alu_src *alu_src = exec_node_data(nir_alu_src, use_src, src);
unsigned write_mask = alu->dest.write_mask;
u_foreach_bit(idx, write_mask)
alu_src->swizzle[idx] = swiz_map[alu_src->swizzle[idx]];
} else if (use_src->parent_instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *use_intrin =
nir_instr_as_intrinsic(use_src->parent_instr);
assert(nir_intrinsic_has_write_mask(use_intrin));
unsigned write_mask = nir_intrinsic_write_mask(use_intrin);
unsigned new_write_mask = 0;
u_foreach_bit(idx, write_mask)
new_write_mask |= 1 << swiz_map[idx];
nir_intrinsic_set_write_mask(use_intrin, new_write_mask);
} else {
unreachable("invalid instruction type");
}
}
intrin->dest.ssa.num_components = intrin->num_components = swiz_count;
progress = true;
}
}
nir_metadata_preserve(impl,
progress ?
(nir_metadata_dominance |
nir_metadata_block_index |
nir_metadata_loop_analysis) :
nir_metadata_all);
_mesa_hash_table_u64_destroy(value_id_to_mask);
return progress;
}
/** Lower shader call instructions to split shaders.
*
* Shader calls can be split into an initial shader and a series of "resume"
@ -1464,6 +1614,8 @@ nir_lower_shader_calls(nir_shader *shader,
NIR_PASS_V(shader, nir_opt_remove_phis);
NIR_PASS_V(shader, nir_opt_trim_stack_values);
/* Make N copies of our shader */
nir_shader **resume_shaders = ralloc_array(mem_ctx, nir_shader *, num_calls);
for (unsigned i = 0; i < num_calls; i++) {