freedreno/ir3: Lower GS builtins before lowering IO

We mostly got away with replacing a store_output with a store_var, but
for complex types like structs, that doesn't work. Once the IO has
been lowered from vars to intrinsic, we've lost the deref chains and
can't properly shadow the outputs.

This commits moves the GS lowering up so we do it before the output
variables get lowered to store_output.  This way the pass works much
like nir_lower_io_to_temporaries() and cleanly shadows the outputs.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/4562>
This commit is contained in:
Kristian H. Kristensen 2020-04-28 12:52:42 -07:00 committed by Marge Bot
parent 79355fd901
commit dd8d257a30
4 changed files with 64 additions and 61 deletions

View file

@ -234,7 +234,6 @@ ir3_optimize_nir(struct ir3_shader *shader, nir_shader *s,
NIR_PASS_V(s, ir3_nir_lower_to_explicit_output, shader, key->tessellation);
break;
case MESA_SHADER_GEOMETRY:
NIR_PASS_V(s, ir3_nir_lower_gs, shader);
NIR_PASS_V(s, ir3_nir_lower_to_explicit_input);
break;
default:

View file

@ -49,7 +49,7 @@ void ir3_nir_lower_to_explicit_output(nir_shader *shader,
void ir3_nir_lower_to_explicit_input(nir_shader *shader);
void ir3_nir_lower_tess_ctrl(nir_shader *shader, struct ir3_shader *s, unsigned topology);
void ir3_nir_lower_tess_eval(nir_shader *shader, unsigned topology);
void ir3_nir_lower_gs(nir_shader *shader, struct ir3_shader *s);
void ir3_nir_lower_gs(nir_shader *shader);
const nir_shader_compiler_options * ir3_get_compiler_options(struct ir3_compiler *compiler);
bool ir3_key_lowers_nir(const struct ir3_shader_key *key);

View file

@ -38,10 +38,10 @@ struct state {
nir_variable *vertex_count_var;
nir_variable *emitted_vertex_var;
nir_variable *vertex_flags_var;
nir_variable *vertex_flags_out;
nir_variable *output_vars[32];
struct exec_list old_outputs;
struct exec_list emit_outputs;
nir_ssa_def *outer_levels[4];
nir_ssa_def *inner_levels[2];
@ -782,8 +782,6 @@ ir3_nir_lower_tess_eval(nir_shader *shader, unsigned topology)
static void
lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
{
nir_intrinsic_instr *outputs[32] = {};
nir_foreach_instr_safe (instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
@ -791,38 +789,24 @@ lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
switch (intr->intrinsic) {
case nir_intrinsic_store_output: {
// src[] = { value, offset }.
uint32_t loc = nir_intrinsic_base(intr);
outputs[loc] = intr;
break;
}
case nir_intrinsic_end_primitive: {
b->cursor = nir_before_instr(&intr->instr);
nir_store_var(b, state->vertex_flags_var, nir_imm_int(b, 4), 0x1);
nir_store_var(b, state->vertex_flags_out, nir_imm_int(b, 4), 0x1);
nir_instr_remove(&intr->instr);
break;
}
case nir_intrinsic_emit_vertex: {
/* Load the vertex count */
b->cursor = nir_before_instr(&intr->instr);
nir_ssa_def *count = nir_load_var(b, state->vertex_count_var);
nir_push_if(b, nir_ieq(b, count, local_thread_id(b)));
for (uint32_t i = 0; i < ARRAY_SIZE(outputs); i++) {
if (outputs[i]) {
nir_store_var(b, state->output_vars[i],
outputs[i]->src[0].ssa,
(1 << outputs[i]->num_components) - 1);
nir_instr_remove(&outputs[i]->instr);
}
outputs[i] = NULL;
foreach_two_lists(dest_node, &state->emit_outputs, src_node, &state->old_outputs) {
nir_variable *dest = exec_node_data(nir_variable, dest_node, node);
nir_variable *src = exec_node_data(nir_variable, src_node, node);
nir_copy_var(b, dest, src);
}
nir_instr_remove(&intr->instr);
@ -830,15 +814,12 @@ lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
nir_store_var(b, state->emitted_vertex_var,
nir_iadd(b, nir_load_var(b, state->emitted_vertex_var), nir_imm_int(b, 1)), 0x1);
nir_store_var(b, state->vertex_flags_out,
nir_load_var(b, state->vertex_flags_var), 0x1);
nir_pop_if(b, NULL);
/* Increment the vertex count by 1 */
nir_store_var(b, state->vertex_count_var,
nir_iadd(b, count, nir_imm_int(b, 1)), 0x1); /* .x */
nir_store_var(b, state->vertex_flags_var, nir_imm_int(b, 0), 0x1);
nir_store_var(b, state->vertex_flags_out, nir_imm_int(b, 0), 0x1);
break;
}
@ -849,27 +830,6 @@ lower_gs_block(nir_block *block, nir_builder *b, struct state *state)
}
}
static void
emit_store_outputs(nir_builder *b, struct state *state)
{
/* This also stores the internally added vertex_flags output. */
for (uint32_t i = 0; i < ARRAY_SIZE(state->output_vars); i++) {
if (!state->output_vars[i])
continue;
nir_intrinsic_instr *store =
nir_intrinsic_instr_create(b->shader, nir_intrinsic_store_output);
nir_intrinsic_set_base(store, i);
store->src[0] = nir_src_for_ssa(nir_load_var(b, state->output_vars[i]));
store->src[1] = nir_src_for_ssa(nir_imm_int(b, 0));
store->num_components = store->src[0].ssa->num_components;
nir_builder_instr_insert(b, &store->instr);
}
}
static void
clean_up_split_vars(nir_shader *shader, struct exec_list *list)
{
@ -892,7 +852,7 @@ clean_up_split_vars(nir_shader *shader, struct exec_list *list)
}
void
ir3_nir_lower_gs(nir_shader *shader, struct ir3_shader *s)
ir3_nir_lower_gs(nir_shader *shader)
{
struct state state = { };
@ -906,10 +866,15 @@ ir3_nir_lower_gs(nir_shader *shader, struct ir3_shader *s)
build_primitive_map(shader, &state.map, &shader->inputs);
/* Create an output var for vertex_flags. This will be shadowed below,
* same way regular outputs get shadowed, and this variable will become a
* temporary.
*/
state.vertex_flags_out = nir_variable_create(shader, nir_var_shader_out,
glsl_uint_type(), "vertex_flags");
state.vertex_flags_out->data.driver_location = shader->num_outputs++;
state.vertex_flags_out->data.location = VARYING_SLOT_GS_VERTEX_FLAGS_IR3;
state.vertex_flags_out->data.interpolation = INTERP_MODE_NONE;
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
@ -920,25 +885,48 @@ ir3_nir_lower_gs(nir_shader *shader, struct ir3_shader *s)
state.header = nir_load_gs_header_ir3(&b);
nir_foreach_variable (var, &shader->outputs) {
state.output_vars[var->data.driver_location] =
nir_local_variable_create(impl, var->type,
ralloc_asprintf(var, "%s:gs-temp", var->name));
/* Generate two set of shadow vars for the output variables. The first
* set replaces the real outputs and the second set (emit_outputs) we'll
* assign in the emit_vertex conditionals. Then at the end of the shader
* we copy the emit_outputs to the real outputs, so that we get
* store_output in uniform control flow.
*/
exec_list_move_nodes_to(&shader->outputs, &state.old_outputs);
exec_list_make_empty(&state.emit_outputs);
nir_foreach_variable(var, &state.old_outputs) {
/* Create a new output var by cloning the original output var and
* stealing the name.
*/
nir_variable *output = nir_variable_clone(var, shader);
exec_list_push_tail(&shader->outputs, &output->node);
/* Rewrite the original output to be a shadow variable. */
var->name = ralloc_asprintf(var, "%s@gs-temp", output->name);
var->data.mode = nir_var_shader_temp;
/* Clone the shadow variable to create the emit shadow variable that
* we'll assign in the emit conditionals.
*/
nir_variable *emit_output = nir_variable_clone(var, shader);
emit_output->name = ralloc_asprintf(var, "%s@emit-temp", output->name);
exec_list_push_tail(&state.emit_outputs, &emit_output->node);
}
/* During the shader we'll keep track of which vertex we're currently
* emitting for the EmitVertex test and how many vertices we emitted so we
* know to discard if didn't emit any. In most simple shaders, this can
* all be statically determined and gets optimized away.
*/
state.vertex_count_var =
nir_local_variable_create(impl, glsl_uint_type(), "vertex_count");
state.emitted_vertex_var =
nir_local_variable_create(impl, glsl_uint_type(), "emitted_vertex");
state.vertex_flags_var =
nir_local_variable_create(impl, glsl_uint_type(), "vertex_flags");
state.vertex_flags_out = state.output_vars[state.vertex_flags_out->data.driver_location];
/* initialize to 0 */
/* Initialize to 0. */
b.cursor = nir_before_cf_list(&impl->body);
nir_store_var(&b, state.vertex_count_var, nir_imm_int(&b, 0), 0x1);
nir_store_var(&b, state.emitted_vertex_var, nir_imm_int(&b, 0), 0x1);
nir_store_var(&b, state.vertex_flags_var, nir_imm_int(&b, 4), 0x1);
nir_store_var(&b, state.vertex_flags_out, nir_imm_int(&b, 4), 0x1);
nir_foreach_block_safe (block, impl)
lower_gs_block(block, &b, &state);
@ -956,11 +944,24 @@ ir3_nir_lower_gs(nir_shader *shader, struct ir3_shader *s)
nir_builder_instr_insert(&b, &discard_if->instr);
emit_store_outputs(&b, &state);
foreach_two_lists(dest_node, &shader->outputs, src_node, &state.emit_outputs) {
nir_variable *dest = exec_node_data(nir_variable, dest_node, node);
nir_variable *src = exec_node_data(nir_variable, src_node, node);
nir_copy_var(&b, dest, src);
}
}
exec_list_append(&shader->globals, &state.old_outputs);
exec_list_append(&shader->globals, &state.emit_outputs);
nir_metadata_preserve(impl, 0);
nir_lower_global_vars_to_local(shader);
nir_split_var_copies(shader);
nir_lower_var_copies(shader);
nir_fixup_deref_modes(shader);
if (shader_debug_enabled(shader->info.stage)) {
fprintf(stderr, "NIR (after gs lowering):\n");
nir_print_shader(shader, stderr);

View file

@ -302,6 +302,9 @@ ir3_shader_from_nir(struct ir3_compiler *compiler, nir_shader *nir,
if (stream_output)
memcpy(&shader->stream_output, stream_output, sizeof(shader->stream_output));
if (nir->info.stage == MESA_SHADER_GEOMETRY)
NIR_PASS_V(nir, ir3_nir_lower_gs);
NIR_PASS_V(nir, nir_lower_io, nir_var_all, ir3_glsl_type_size,
(nir_lower_io_options)0);