zink: add provoking vertex mode lowering

Can be used as a fallback when VK_EXT_provoking_vertex is missing.

Acked-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Reviewed-by: Erik Faye-Lund <erik.faye-lund@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22162>
This commit is contained in:
antonino 2023-03-24 16:03:02 +01:00 committed by Marge Bot
parent 9466a6e2f8
commit 5a4083349f
2 changed files with 233 additions and 1 deletions

View file

@ -385,6 +385,233 @@ lower_gl_point_gs(nir_shader *shader)
nir_metadata_dominance, &state);
}
/* State shared by the callbacks of the geometry-shader provoking-vertex-mode
 * lowering pass (lower_pv_mode_gs and its helpers).
 */
struct lower_pv_mode_state {
/* Per-output-location function-local array that buffers the values the
 * shader stores to the corresponding output; NULL for locations that have
 * no output variable. Indexed by var->data.location.
 */
nir_variable *varyings[VARYING_SLOT_MAX];
/* Incremented once for every vertex emission of the original shader and
 * used as the array index for buffered output stores; reset to 0 when a
 * primitive is ended.
 */
nir_variable *pos_counter;
/* Index of the first buffered vertex of the next primitive to re-emit;
 * advanced by one per re-emitted primitive and reset to 0 afterwards.
 */
nir_variable *out_pos_counter;
/* Vertices per output primitive: 1, 2 or 3, derived from the shader's
 * gs.output_primitive.
 */
unsigned primitive_vert_count;
/* Emulation mode passed in by the caller; compared against
 * ZINK_PVE_PRIMITIVE_TRISTRIP and ZINK_PVE_PRIMITIVE_FAN when rotating.
 */
unsigned prim;
};
/* Redirects a store to a shader output into the per-location temporary
 * array, at the element selected by the current value of pos_counter, and
 * removes the original store.
 *
 * Returns true when the store was rewritten, false when the destination is
 * not a shader output (the instruction is then left untouched).
 *
 * NOTE(review): the temporary is looked up by location only and the store is
 * rebuilt as a whole-variable array store, so this presumably expects the
 * destination deref to be the output variable itself -- confirm.
 */
static bool
lower_pv_mode_gs_store(nir_builder *b,
                       nir_intrinsic_instr *intrin,
                       struct lower_pv_mode_state *state)
{
   b->cursor = nir_before_instr(&intrin->instr);

   nir_deref_instr *dst = nir_src_as_deref(intrin->src[0]);
   if (!nir_deref_mode_is(dst, nir_var_shader_out))
      return false;

   nir_variable *out_var = nir_deref_instr_get_variable(dst);
   gl_varying_slot slot = out_var->data.location;

   /* every shader output is expected to have a temporary created for it */
   assert(state->varyings[slot]);
   assert(intrin->src[1].is_ssa);

   nir_ssa_def *vertex_index = nir_load_var(b, state->pos_counter);
   nir_store_array_var(b, state->varyings[slot],
                       vertex_index, intrin->src[1].ssa,
                       nir_intrinsic_write_mask(intrin));

   nir_instr_remove(&intrin->instr);
   return true;
}
/* Emits one primitive of the strip buffered in state->varyings with its
 * vertices reordered so that the last vertex of the primitive is emitted
 * first.
 *
 * b:              builder positioned where the vertices must be emitted
 * state:          lowering state (buffered outputs, primitive size, mode)
 * current_vertex: index, within the buffered strip, of the first vertex of
 *                 the primitive to emit
 *
 * The caller is responsible for ending the primitive afterwards.
 *
 * NOTE(review): for single-vertex (point) output the "lines" row of the
 * table is used, which yields offset 1 -- confirm that point output never
 * reaches this pass.
 */
static void
lower_pv_mode_emit_rotated_prim(nir_builder *b,
                                struct lower_pv_mode_state *state,
                                nir_ssa_def *current_vertex)
{
   nir_ssa_def *two = nir_imm_int(b, 2);
   nir_ssa_def *three = nir_imm_int(b, 3);
   bool is_triangle = state->primitive_vert_count == 3;
   /* The following table is used to rotate primitives within a strip
    * generated by the user gs such that the last vertex becomes the first.
    * Entries are offsets relative to the first vertex of the primitive.
    *
    * Lines have no winding, so both parities simply swap the two vertices.
    * Odd triangles of a strip have reversed winding, hence the different
    * order for the odd row.
    *
    * [lines, tris][even/odd index][vertex mod 3]
    */
   static const unsigned vert_maps[2][2][3] = {
      {{1, 0, 0}, {1, 0, 0}},
      {{2, 0, 1}, {2, 1, 0}}
   };
   /* When the primitive supplied to the gs comes from a strip, the last
    * provoking vertex is either the last or the second, depending on whether
    * the triangle is at an odd or even position within the strip.
    *
    * odd or even primitive within draw
    */
   nir_ssa_def *odd_prim = nir_imod(b, nir_load_primitive_id(b), two);
   /* odd or even primitive within the strip emitted by the user GS;
    * this is handled using the table and does not depend on the vertex
    * being emitted, so compute it once
    */
   nir_ssa_def *odd_user_prim = nir_imod(b, current_vertex, two);
   nir_ssa_def *is_odd_user_prim = nir_b2b1(b, odd_user_prim);

   for (unsigned i = 0; i < state->primitive_vert_count; i++) {
      nir_ssa_def *offset_even_value =
         nir_imm_int(b, vert_maps[is_triangle][0][i]);
      nir_ssa_def *offset_odd_value =
         nir_imm_int(b, vert_maps[is_triangle][1][i]);
      nir_ssa_def *rotated_i = nir_bcsel(b, is_odd_user_prim,
                                         offset_odd_value, offset_even_value);
      /* Here we account for how triangles are provided to the gs from a strip.
       * For even primitives we rotate by 3, meaning we do nothing.
       * For odd primitives we rotate by 2, combined with the previous rotation
       * this means the second vertex becomes the last.
       */
      if (state->prim == ZINK_PVE_PRIMITIVE_TRISTRIP) {
         rotated_i = nir_imod(b, nir_iadd(b, rotated_i,
                                          nir_isub(b, three, odd_prim)),
                              three);
      /* Triangles that come from fans are provided to the gs the same way as
       * odd triangles from a strip so always rotate by 2.
       */
      } else if (state->prim == ZINK_PVE_PRIMITIVE_FAN) {
         rotated_i = nir_imod(b, nir_iadd_imm(b, rotated_i, 2), three);
      }

      /* rotated_i is only an offset inside the primitive; the primitive
       * itself starts at current_vertex within the buffered strip
       */
      nir_ssa_def *vertex_index = nir_iadd(b, current_vertex, rotated_i);

      nir_foreach_variable_with_modes(var, b->shader, nir_var_shader_out) {
         gl_varying_slot location = var->data.location;
         if (state->varyings[location]) {
            nir_ssa_def *value =
               nir_load_array_var(b, state->varyings[location], vertex_index);
            nir_store_var(b, var, value, (1u << value->num_components) - 1);
         }
      }
      nir_emit_vertex(b);
   }
}
/* Replaces a vertex emission of the original shader with an increment of
 * pos_counter; the vertex data itself has already been redirected into the
 * temporary arrays by the store lowering. Always reports progress.
 */
static bool
lower_pv_mode_gs_emit_vertex(nir_builder *b,
                             nir_intrinsic_instr *intrin,
                             struct lower_pv_mode_state *state)
{
   nir_variable *counter = state->pos_counter;

   b->cursor = nir_before_instr(&intrin->instr);

   /* pos_counter = pos_counter + 1 */
   nir_ssa_def *old_count = nir_load_var(b, counter);
   nir_ssa_def *new_count = nir_iadd_imm(b, old_count, 1);
   nir_store_var(b, counter, new_count, 1);

   nir_instr_remove(&intrin->instr);
   return true;
}
/* Replaces an end-of-primitive of the original shader with a loop that
 * re-emits every complete primitive of the buffered strip as a separate,
 * rotated primitive, then resets both counters so the next strip starts
 * from index 0. Always reports progress.
 */
static bool
lower_pv_mode_gs_end_primitive(nir_builder *b,
                               nir_intrinsic_instr *intrin,
                               struct lower_pv_mode_state *state)
{
   b->cursor = nir_before_instr(&intrin->instr);

   /* number of vertices buffered for this strip; loaded once, it does not
    * change while the loop runs
    */
   nir_ssa_def *pos_counter = nir_load_var(b, state->pos_counter);
   nir_push_loop(b);
   {
      nir_ssa_def *out_pos_counter = nir_load_var(b, state->out_pos_counter);
      /* A primitive starting at out_pos_counter needs primitive_vert_count
       * buffered vertices from that index on. Stop as soon as fewer remain,
       * otherwise an extra primitive would be emitted that reads past the
       * vertices written by the shader.
       */
      nir_push_if(b, nir_ilt(b, nir_isub(b, pos_counter, out_pos_counter),
                             nir_imm_int(b, state->primitive_vert_count)));
      nir_jump(b, nir_jump_break);
      nir_pop_if(b, NULL);

      lower_pv_mode_emit_rotated_prim(b, state, out_pos_counter);
      nir_end_primitive(b);

      nir_store_var(b, state->out_pos_counter,
                    nir_iadd_imm(b, out_pos_counter, 1), 1);
   }
   nir_pop_loop(b, NULL);

   /* start buffering the next strip from scratch */
   nir_store_var(b, state->pos_counter, nir_imm_int(b, 0), 1);
   nir_store_var(b, state->out_pos_counter, nir_imm_int(b, 0), 1);

   nir_instr_remove(&intrin->instr);
   return true;
}
/* Per-instruction callback of the pass: dispatches output stores, vertex
 * emissions and primitive ends to their lowering helpers. `data` is the
 * struct lower_pv_mode_state of the pass. Returns true when the instruction
 * was rewritten.
 */
static bool
lower_pv_mode_gs_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct lower_pv_mode_state *state = data;
   bool progress = false;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
   switch (intrin->intrinsic) {
   case nir_intrinsic_store_deref:
      progress = lower_pv_mode_gs_store(b, intrin, state);
      break;
   case nir_intrinsic_emit_vertex:
   case nir_intrinsic_emit_vertex_with_counter:
      progress = lower_pv_mode_gs_emit_vertex(b, intrin, state);
      break;
   case nir_intrinsic_end_primitive:
   case nir_intrinsic_end_primitive_with_counter:
      progress = lower_pv_mode_gs_end_primitive(b, intrin, state);
      break;
   case nir_intrinsic_copy_deref:
      /* variable copies must have been split into load/store before this */
      unreachable("should be lowered");
   default:
      break;
   }

   return progress;
}
/* Returns the number of vertices that make up one primitive of the given
 * geometry-shader output primitive type (points, line strip or triangle
 * strip). Any other type is a programming error.
 */
static unsigned int
lower_pv_mode_vertices_for_prim(enum shader_prim prim)
{
   if (prim == SHADER_PRIM_POINTS)
      return 1;
   if (prim == SHADER_PRIM_LINE_STRIP)
      return 2;
   if (prim == SHADER_PRIM_TRIANGLE_STRIP)
      return 3;

   unreachable("unsupported primitive for gs output");
}
/* Lowers the provoking vertex mode of a geometry shader: output stores are
 * buffered in per-location temporary arrays and every primitive of each
 * emitted strip is re-emitted on its own with its vertices rotated.
 *
 * shader: geometry shader to rewrite; its gs.vertices_out is updated
 * prim:   emulation mode stored in the pass state (see the
 *         ZINK_PVE_PRIMITIVE_* checks in the rotation helper)
 *
 * Returns the progress value of nir_shader_instructions_pass().
 *
 * NOTE(review): temporaries are keyed by location only, so two outputs that
 * share a location would overwrite each other's entry -- confirm whether
 * that can happen here.
 */
static bool
lower_pv_mode_gs(nir_shader *shader, unsigned prim)
{
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   /* varyings[] and both counter pointers start out zeroed */
   struct lower_pv_mode_state state = {
      .prim = prim,
      .primitive_vert_count =
         lower_pv_mode_vertices_for_prim(shader->info.gs.output_primitive),
   };
   nir_builder b;

   nir_builder_init(&b, impl);
   b.cursor = nir_before_cf_list(&impl->body);

   /* one temporary array per output, large enough for every vertex the
    * original shader may emit
    */
   nir_foreach_variable_with_modes(var, shader, nir_var_shader_out) {
      gl_varying_slot slot = var->data.location;
      char tmp_name[100];

      snprintf(tmp_name, sizeof(tmp_name), "__tmp_primverts_%d", slot);
      const struct glsl_type *array_type =
         glsl_array_type(var->type, shader->info.gs.vertices_out, false);
      state.varyings[slot] =
         nir_local_variable_create(impl, array_type, tmp_name);
   }

   state.pos_counter =
      nir_local_variable_create(impl, glsl_uint_type(), "__pos_counter");
   state.out_pos_counter =
      nir_local_variable_create(impl, glsl_uint_type(), "__out_pos_counter");

   /* both counters start at zero at the top of the shader */
   nir_store_var(&b, state.pos_counter, nir_imm_int(&b, 0), 1);
   nir_store_var(&b, state.out_pos_counter, nir_imm_int(&b, 0), 1);

   /* a strip of N vertices holds N - (k - 1) primitives of k vertices each,
    * and every one of them is now emitted in full
    */
   const unsigned prim_verts = state.primitive_vert_count;
   shader->info.gs.vertices_out =
      (shader->info.gs.vertices_out - (prim_verts - 1)) * prim_verts;

   return nir_shader_instructions_pass(shader, lower_pv_mode_gs_instr,
                                       nir_metadata_dominance, &state);
}
struct lower_line_stipple_state {
nir_variable *pos_out;
nir_variable *stipple_out;
@ -3251,6 +3478,11 @@ zink_shader_compile(struct zink_screen *screen, struct zink_shader *zs,
NIR_PASS_V(nir, lower_gl_point_gs);
need_optimize = true;
}
if (zink_gs_key(key)->lower_pv_mode) {
NIR_PASS_V(nir, lower_pv_mode_gs, zink_gs_key(key)->lower_pv_mode);
need_optimize = true; //TODO verify that this is required
}
break;
default:

View file

@ -2310,7 +2310,7 @@ zink_set_primitive_emulation_keys(struct zink_context *ctx)
if (lower_line_stipple || lower_line_smooth ||
lower_edge_flags || lower_quad_prim ||
zink_get_gs_key(ctx)->lower_gl_point) {
lower_pv_mode || zink_get_gs_key(ctx)->lower_gl_point) {
enum pipe_shader_type prev_vertex_stage =
ctx->gfx_stages[MESA_SHADER_TESS_EVAL] ?
MESA_SHADER_TESS_EVAL : MESA_SHADER_VERTEX;