radv, ac/nir: Fix multiview layer export for mesh shaders.

Unfortunately, radv_lower_multiview is not suitable for mesh shaders
because it can't know the mapping between API mesh shader
invocations and output primitives.

Additionally, when lowering view id to layer for a mesh shading pipeline,
the layer input must be created as a per-primitive PS input.

Fixes: d32656bc65
Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/16504>
This commit is contained in:
Timur Kristóf 2022-05-13 21:32:12 +02:00 committed by Marge Bot
parent c636660585
commit c69b771e35
6 changed files with 39 additions and 11 deletions

View file

@ -125,7 +125,8 @@ ac_nir_lower_ngg_gs(nir_shader *shader,
void
ac_nir_lower_ngg_ms(nir_shader *shader,
unsigned wave_size);
unsigned wave_size,
bool multiview);
void
ac_nir_apply_first_task_to_task_shader(nir_shader *shader);

View file

@ -112,6 +112,9 @@ typedef struct
nir_ssa_def *workgroup_index;
/* True if the lowering needs to insert the layer output. */
bool insert_layer_output;
struct {
/* Bitmask of components used: 4 bits per slot, 1 bit per component. */
uint32_t components_mask;
@ -2535,6 +2538,15 @@ emit_ms_finale(nir_builder *b, lower_ngg_ms_state *s)
ms_emit_arrayed_outputs(b, invocation_index, s->per_primitive_outputs,
s->num_per_primitive_outputs, s->prim_attr_lds_addr, s);
/* Insert layer output store if the pipeline uses multiview but the API shader doesn't write it. */
if (s->insert_layer_output) {
nir_ssa_def *layer = nir_load_view_index(b);
const nir_io_semantics io_sem = { .location = VARYING_SLOT_LAYER, .num_slots = 1 };
nir_store_output(b, layer, nir_imm_int(b, 0), .base = VARYING_SLOT_LAYER, .component = 0, .io_semantics = io_sem);
b->shader->info.outputs_written |= VARYING_BIT_LAYER;
b->shader->info.per_primitive_outputs |= VARYING_BIT_LAYER;
}
/* Primitive connectivity data: describes which vertices the primitive uses. */
nir_ssa_def *prim_idx_addr = nir_imul_imm(b, invocation_index, s->vertices_per_prim);
nir_ssa_def *indices_loaded = nir_load_shared(b, s->vertices_per_prim, 8, prim_idx_addr, .base = s->prim_vtx_indices_addr);
@ -2697,7 +2709,8 @@ handle_smaller_ms_api_workgroup(nir_builder *b,
void
ac_nir_lower_ngg_ms(nir_shader *shader,
unsigned wave_size)
unsigned wave_size,
bool multiview)
{
unsigned vertices_per_prim =
num_mesh_vertices_per_primitive(shader->info.mesh.primitive_type);
@ -2761,6 +2774,7 @@ ac_nir_lower_ngg_ms(nir_shader *shader,
.numprims_lds_addr = numprims_lds_addr,
.api_workgroup_size = api_workgroup_size,
.hw_workgroup_size = hw_workgroup_size,
.insert_layer_output = multiview && !(shader->info.outputs_written & VARYING_BIT_LAYER),
};
nir_function_impl *impl = nir_shader_get_entrypoint(shader);

View file

@ -2873,7 +2873,6 @@ find_layer_out_var(nir_shader *nir)
return var;
var = nir_variable_create(nir, nir_var_shader_out, glsl_int_type(), "layer id");
var->data.per_primitive = nir->info.stage == MESA_SHADER_MESH;
var->data.location = VARYING_SLOT_LAYER;
var->data.interpolation = INTERP_MODE_NONE;
@ -2883,6 +2882,13 @@ find_layer_out_var(nir_shader *nir)
static bool
radv_lower_multiview(nir_shader *nir)
{
/* This pass is not suitable for mesh shaders, because it can't know
* the mapping between API mesh shader invocations and output primitives.
* Needs to be handled in ac_nir_lower_ngg.
*/
if (nir->info.stage == MESA_SHADER_MESH)
return false;
nir_function_impl *impl = nir_shader_get_entrypoint(nir);
bool progress = false;
@ -2924,8 +2930,6 @@ radv_lower_multiview(nir_shader *nir)
/* Update outputs_written to reflect that the pass added a new output. */
nir->info.outputs_written |= BITFIELD64_BIT(VARYING_SLOT_LAYER);
if (nir->info.stage == MESA_SHADER_MESH)
nir->info.per_primitive_outputs |= BITFIELD64_BIT(VARYING_SLOT_LAYER);
progress = true;
if (nir->info.stage == MESA_SHADER_VERTEX)
@ -4764,7 +4768,7 @@ radv_create_shaders(struct radv_pipeline *pipeline, struct radv_pipeline_layout
/* Gather info again, information such as outputs_read can be out-of-date. */
nir_shader_gather_info(stages[i].nir, nir_shader_get_entrypoint(stages[i].nir));
radv_lower_io(device, stages[i].nir);
radv_lower_io(device, stages[i].nir, stages[MESA_SHADER_MESH].nir);
stages[i].feedback.duration += os_time_get_nano() - stage_start;
}

View file

@ -991,7 +991,7 @@ find_layer_in_var(nir_shader *nir)
*/
static bool
lower_view_index(nir_shader *nir)
lower_view_index(nir_shader *nir, bool per_primitive)
{
bool progress = false;
nir_function_impl *entry = nir_shader_get_entrypoint(nir);
@ -1011,12 +1011,15 @@ lower_view_index(nir_shader *nir)
if (!layer)
layer = find_layer_in_var(nir);
layer->data.per_primitive = per_primitive;
b.cursor = nir_before_instr(instr);
nir_ssa_def *def = nir_load_var(&b, layer);
nir_ssa_def_rewrite_uses(&load->dest.ssa, def);
/* Update inputs_read to reflect that the pass added a new input. */
nir->info.inputs_read |= VARYING_BIT_LAYER;
if (per_primitive)
nir->info.per_primitive_inputs |= VARYING_BIT_LAYER;
nir_instr_remove(instr);
progress = true;
@ -1032,13 +1035,13 @@ lower_view_index(nir_shader *nir)
}
void
radv_lower_io(struct radv_device *device, nir_shader *nir)
radv_lower_io(struct radv_device *device, nir_shader *nir, bool is_mesh_shading)
{
if (nir->info.stage == MESA_SHADER_COMPUTE)
return;
if (nir->info.stage == MESA_SHADER_FRAGMENT) {
NIR_PASS(_, nir, lower_view_index);
NIR_PASS(_, nir, lower_view_index, is_mesh_shading);
nir_assign_io_var_locations(nir, nir_var_shader_in, &nir->num_inputs, MESA_SHADER_FRAGMENT);
}
@ -1239,7 +1242,7 @@ void radv_lower_ngg(struct radv_device *device, struct radv_pipeline_stage *ngg_
info->ngg_info.esgs_ring_size, info->gs.gsvs_vertex_size,
info->ngg_info.ngg_emit_size * 4u, pl_key->vs.provoking_vtx_last);
} else if (nir->info.stage == MESA_SHADER_MESH) {
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, info->wave_size);
NIR_PASS_V(nir, ac_nir_lower_ngg_ms, info->wave_size, pl_key->has_multiview_view_index);
} else {
unreachable("invalid SW stage passed to radv_lower_ngg");
}

View file

@ -678,7 +678,7 @@ get_tcs_num_patches(unsigned tcs_num_input_vertices, unsigned tcs_num_output_ver
return num_patches;
}
void radv_lower_io(struct radv_device *device, nir_shader *nir);
void radv_lower_io(struct radv_device *device, nir_shader *nir, bool is_mesh_shading);
bool radv_lower_io_to_mem(struct radv_device *device, struct radv_pipeline_stage *stage,
const struct radv_pipeline_key *pl_key);

View file

@ -509,6 +509,12 @@ radv_nir_shader_info_pass(struct radv_device *device, const struct nir_shader *n
uint64_t per_vtx_mask =
nir->info.outputs_written & ~nir->info.per_primitive_outputs & ~special_mask;
/* Mesh multiview is only lowered in ac_nir_lower_ngg, so we have to fake it here. */
if (nir->info.stage == MESA_SHADER_MESH && pipeline_key->has_multiview_view_index) {
per_prim_mask |= VARYING_BIT_LAYER;
info->uses_view_index = true;
}
/* Per vertex outputs. */
outinfo->writes_pointsize = per_vtx_mask & VARYING_BIT_PSIZ;
outinfo->writes_viewport_index = per_vtx_mask & VARYING_BIT_VIEWPORT;