ac/nir: use s_sendmsg(HS_TESSFACTOR) to optimize writing tess factors for gfx11

This uses the new shader message. It eliminates memory stores and latency
for simple cases of tess level values.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31673>
This commit is contained in:
Marek Olšák 2024-09-23 19:31:20 -04:00 committed by Marge Bot
parent f4eebb373c
commit b49eab68a8
2 changed files with 232 additions and 27 deletions

View file

@ -769,6 +769,213 @@ hs_resize_tess_factor(nir_builder *b, nir_def *tf, unsigned comps)
return tf;
}
static nir_if *
hs_if_invocation_id_zero(nir_builder *b)
{
nir_def *invocation_id = nir_load_invocation_id(b);
/* Only the 1st invocation of each patch needs to do this. */
nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0));
/* When the output patch size is <= 32 then we can flatten the branch here
* because we know for sure that at least 1 invocation in all waves will
* take the branch.
*/
if (b->shader->info.tess.tcs_vertices_out <= 32)
invocation_id_zero->control = nir_selection_control_divergent_always_taken;
return invocation_id_zero;
}
static nir_def *
tess_level_has_effect(nir_builder *b, nir_def *prim_mode, unsigned comp, bool outer)
{
if (outer && comp <= 1)
return nir_imm_true(b);
else if ((outer && comp == 2) || (!outer && comp == 0))
return nir_ine_imm(b, prim_mode, TESS_PRIMITIVE_ISOLINES);
else if ((outer && comp == 3) || (!outer && comp == 1))
return nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_QUADS);
else
unreachable("invalid comp");
}
/* Return true if memory should be used. If false is returned, the shader message has been used. */
static nir_def *
hs_msg_group_vote_use_memory(nir_builder *b, lower_tess_io_state *st,
tess_levels *tessfactors, nir_def *prim_mode)
{
/* Don't do the group vote and send the message directly if tess level values were determined
* by nir_gather_tcs_info at compile time.
*
* Disable the shader cache if you set the environment variable.
*/
if (debug_get_bool_option("AMD_FAST_HS_MSG", true) &&
(st->tcs_info.all_tess_levels_are_effectively_zero ||
st->tcs_info.all_tess_levels_are_effectively_one)) {
nir_if *if_subgroup0 = nir_push_if(b, nir_ieq_imm(b, nir_load_subgroup_id(b), 0));
{
/* m0[0] == 0 means all TF are 0 in the workgroup.
* m0[0] == 1 means all TF are 1 in the workgroup.
*/
nir_def *m0 = nir_imm_int(b, st->tcs_info.all_tess_levels_are_effectively_zero ? 0 : 1);
nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_HS_TESSFACTOR);
}
nir_pop_if(b, if_subgroup0);
return nir_imm_false(b);
}
/* Initialize the first LDS dword for the tf0/1 group vote at the beginning of TCS. */
nir_block *start_block = nir_start_block(nir_shader_get_entrypoint(b->shader));
nir_builder top_b = nir_builder_at(nir_before_block(start_block));
nir_if *thread0 = nir_push_if(&top_b,
nir_iand(&top_b, nir_ieq_imm(&top_b, nir_load_subgroup_id(&top_b), 0),
nir_inverse_ballot(&top_b, 1, nir_imm_ivec4(&top_b, 0x1, 0, 0, 0))));
{
/* 0x3 is the initial bitmask (tf0 | tf1). Each subgroup will do atomic iand on it for the vote. */
nir_store_shared(&top_b, nir_imm_int(&top_b, 0x3), nir_imm_int(&top_b, 0),
.write_mask = 0x1, .align_mul = 4);
}
nir_pop_if(&top_b, thread0);
/* Insert a barrier to wait for initialization above if there hasn't been any other barrier
* in the shader.
*/
if (!st->tcs_info.always_executes_barrier) {
nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
}
/* Use s_sendmsg to tell the hw whether the whole workgroup has either of these cases:
*
* tf0: All patches in the workgroup have at least one outer tess level component either
* in the [-inf, 0] range or equal to NaN, causing them to be discarded. Inner tess levels
* have no effect.
*
* tf1: All patches in the workgroup have the values of tess levels set to 1 or equivalent numbers,
* which doesn't discard any patches. Each spacing interprets different tess level ranges as 1:
*
* 1) equal_spacing, fractional_odd_spacing, and unknown spacing
* For undiscarded patches, the tessellator clamps all tess levels to 1. If all tess levels
* are in the (0, 1] range, which is effectively 1, untessellated patches are
* drawn.
*
* 2) fractional_even_spacing
* For undiscarded patches, the tessellator clamps all tess levels to 2 (both outer and inner)
* except isolines, which clamp the first outer tess level component to 1. If all outer tess
* levels are in the (0, 2] or (0, 1] range (for outer[0] of isolines) and all inner tess levels
* are in the [-inf, 2] range, the tf1 message can be used. The tessellator will receive 1 via
* the message, but will clamp them to 2 or keep 1 (for outer[0] of isolines).
*
* If we make this mutually exclusive with tf0, we only have to compare against the upper bound.
*/
/* Determine tf0/tf1 for the subgroup at the end of TCS. */
nir_if *if_invocation_id_zero = hs_if_invocation_id_zero(b);
{
*tessfactors = hs_load_tess_levels(b, st);
nir_def *lane_tf_effectively_0 = nir_imm_false(b);
for (unsigned i = 0; i < tessfactors->outer->num_components; i++) {
nir_def *valid = tess_level_has_effect(b, prim_mode, i, true);
/* fgeu returns true for NaN */
nir_def *le0 = nir_fgeu(b, nir_imm_float(b, 0), nir_channel(b, tessfactors->outer, i));
lane_tf_effectively_0 = nir_ior(b, lane_tf_effectively_0, nir_iand(b, le0, valid));
}
/* Use case 1: unknown spacing */
nir_def *lane_tf_effectively_1 = nir_imm_true(b);
for (unsigned i = 0; i < tessfactors->outer->num_components; i++) {
nir_def *valid = tess_level_has_effect(b, prim_mode, i, true);
nir_def *le1 = nir_fle_imm(b, nir_channel(b, tessfactors->outer, i), 1);
lane_tf_effectively_1 = nir_iand(b, lane_tf_effectively_1, nir_ior(b, le1, nir_inot(b, valid)));
}
if (tessfactors->inner) {
for (unsigned i = 0; i < tessfactors->inner->num_components; i++) {
nir_def *valid = tess_level_has_effect(b, prim_mode, i, false);
nir_def *le1 = nir_fle_imm(b, nir_channel(b, tessfactors->inner, i), 1);
lane_tf_effectively_1 = nir_iand(b, lane_tf_effectively_1, nir_ior(b, le1, nir_inot(b, valid)));
}
}
/* Make them mutually exclusive. */
lane_tf_effectively_1 = nir_iand(b, lane_tf_effectively_1, nir_inot(b, lane_tf_effectively_0));
nir_def *subgroup_uses_tf0 = nir_b2i32(b, nir_vote_all(b, 1, lane_tf_effectively_0));
nir_def *subgroup_uses_tf1 = nir_b2i32(b, nir_vote_all(b, 1, lane_tf_effectively_1));
/* Pack the value for LDS. Encoding:
* 0 = none of the below
* 1 = all tess factors are effectively 0
* 2 = all tess factors are effectively 1
* 3 = invalid
*
* Since we will do bitwise AND reduction across all waves, 3 can never occur.
*/
nir_def *packed_tf01_mask = nir_ior(b, subgroup_uses_tf0,
nir_ishl_imm(b, subgroup_uses_tf1, 1));
/* This function is only called within a block that only executes for patch invocation 0, so we
* only need to mask out invocation 0 of other patches in the subgroup to execute on only 1 lane.
*
* Since patch invocations are placed sequentially in the subgroup, we know that invocation 0
* of the lowest patch must be somewhere in BITFIELD_MASK(tcs_vertices_out) lanes.
*/
const unsigned tcs_vertices_out = b->shader->info.tess.tcs_vertices_out;
assert(tcs_vertices_out <= 32);
nir_def *is_first_active_lane =
nir_inverse_ballot(b, 1, nir_imm_ivec4(b, BITFIELD_MASK(tcs_vertices_out), 0, 0, 0));
/* Only the first active invocation in each subgroup performs the AND reduction through LDS. */
nir_if *if_first_active_lane = nir_push_if(b, is_first_active_lane);
if_first_active_lane->control = nir_selection_control_divergent_always_taken;
{
/* Use atomic iand to combine results from all subgroups. */
nir_shared_atomic(b, 32, nir_imm_int(b, 0), packed_tf01_mask,
.atomic_op = nir_atomic_op_iand);
}
nir_pop_if(b, if_first_active_lane);
}
nir_pop_if(b, if_invocation_id_zero);
/* The caller will reuse these. */
tessfactors->outer = nir_if_phi(b, tessfactors->outer, nir_undef(b, tessfactors->outer->num_components, 32));
if (tessfactors->inner) /* Isolines don't have inner tess levels. */
tessfactors->inner = nir_if_phi(b, tessfactors->inner, nir_undef(b, tessfactors->inner->num_components, 32));
/* Wait for all waves to execute the LDS atomic. */
nir_barrier(b, .execution_scope = SCOPE_WORKGROUP, .memory_scope = SCOPE_WORKGROUP,
.memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
/* Read the result from LDS. Only 1 lane should load it to prevent LDS bank conflicts. */
nir_def *lds_result;
nir_if *if_lane0 = nir_push_if(b, nir_inverse_ballot(b, 1, nir_imm_ivec4(b, 0x1, 0, 0, 0)));
if_lane0->control = nir_selection_control_divergent_always_taken;
{
lds_result = nir_load_shared(b, 1, 32, nir_imm_int(b, 0), .align_mul = 4);
}
nir_pop_if(b, if_lane0);
lds_result = nir_if_phi(b, lds_result, nir_undef(b, 1, 32));
lds_result = nir_read_invocation(b, lds_result, nir_imm_int(b, 0));
/* Determine the vote value and send the message. */
nir_def *use_memory = nir_ieq_imm(b, lds_result, 0);
nir_if *if_subgroup0_sendmsg = nir_push_if(b, nir_iand(b, nir_inot(b, use_memory),
nir_ieq_imm(b, nir_load_subgroup_id(b), 0)));
{
/* m0[0] == 0 means all TF are 0 in the workgroup.
* m0[0] == 1 means all TF are 1 in the workgroup.
*/
nir_def *m0 = nir_iadd_imm(b, lds_result, -1);
nir_sendmsg_amd(b, m0, .base = AC_SENDMSG_HS_TESSFACTOR);
}
nir_pop_if(b, if_subgroup0_sendmsg);
return use_memory;
}
static void
hs_store_tess_factors_for_tessellator(nir_builder *b, enum amd_gfx_level gfx_level,
enum tess_primitive_mode prim_mode,
@ -844,27 +1051,8 @@ hs_store_tess_factors_for_tes(nir_builder *b, tess_levels tessfactors, lower_tes
}
}
static nir_if *
hs_if_invocation_id_zero(nir_builder *b)
{
nir_def *invocation_id = nir_load_invocation_id(b);
/* Only the 1st invocation of each patch needs to do this. */
nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0));
/* When the output patch size is <= 32 then we can flatten the branch here
* because we know for sure that at least 1 invocation in all waves will
* take the branch.
*/
if (b->shader->info.tess.tcs_vertices_out <= 32)
invocation_id_zero->control = nir_selection_control_divergent_always_taken;
return invocation_id_zero;
}
static void
hs_finale(nir_shader *shader,
lower_tess_io_state *st)
hs_finale(nir_shader *shader, lower_tess_io_state *st)
{
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
assert(impl);
@ -874,22 +1062,35 @@ hs_finale(nir_shader *shader,
nir_builder builder = nir_builder_at(nir_after_block(last_block));
nir_builder *b = &builder; /* This is to avoid the & */
/* If tess factors are load from LDS, wait previous LDS stores done. */
/* If tess factors are loaded from LDS, wait for their LDS stores. */
if (!st->tcs_info.all_invocations_define_tess_levels) {
mesa_scope scope = st->tcs_out_patch_fits_subgroup ? SCOPE_SUBGROUP : SCOPE_WORKGROUP;
nir_barrier(b, .execution_scope = scope, .memory_scope = scope,
.memory_semantics = NIR_MEMORY_ACQ_REL, .memory_modes = nir_var_mem_shared);
st->tcs_info.always_executes_barrier = true;
}
nir_def *prim_mode = nir_load_tcs_primitive_mode_amd(b);
nir_def *use_memory = NULL;
tess_levels tessfactors = {0};
/* This also loads tess levels for patch invocation 0. */
if (st->gfx_level >= GFX11)
use_memory = hs_msg_group_vote_use_memory(b, st, &tessfactors, prim_mode);
/* Only the 1st invocation of each patch needs to access VRAM and/or LDS. */
nir_if *if_invocation_id_zero = hs_if_invocation_id_zero(b);
{
tess_levels tessfactors = hs_load_tess_levels(b, st);
if (!tessfactors.outer)
tessfactors = hs_load_tess_levels(b, st);
nir_if *if_use_memory = NULL;
if (use_memory != NULL)
if_use_memory = nir_push_if(b, use_memory);
if (st->gfx_level <= GFX8)
hs_store_dynamic_control_word_gfx6(b);
nir_def *prim_mode = nir_load_tcs_primitive_mode_amd(b);
nir_if *if_triangles = nir_push_if(b, nir_ieq_imm(b, prim_mode, TESS_PRIMITIVE_TRIANGLES));
{
hs_store_tess_factors_for_tessellator(b, st->gfx_level, TESS_PRIMITIVE_TRIANGLES, tessfactors);
@ -908,13 +1109,15 @@ hs_finale(nir_shader *shader,
}
nir_pop_if(b, if_triangles);
if (use_memory != NULL)
nir_pop_if(b, if_use_memory);
nir_if *if_tes_reads_tf = nir_push_if(b, nir_load_tcs_tess_levels_to_tes_amd(b));
{
hs_store_tess_factors_for_tes(b, tessfactors, st);
}
nir_pop_if(b, if_tes_reads_tf);
}
nir_pop_if(b, if_invocation_id_zero);
nir_metadata_preserve(impl, nir_metadata_none);

View file

@ -20,9 +20,11 @@
extern "C" {
#endif
#define AC_SENDMSG_GS 2
#define AC_SENDMSG_GS_DONE 3
#define AC_SENDMSG_GS_ALLOC_REQ 9
#define AC_SENDMSG_HS_TESSFACTOR 2
#define AC_SENDMSG_GS 2
#define AC_SENDMSG_GS_DONE 3
#define AC_SENDMSG_GS_ALLOC_REQ 9
#define AC_SENDMSG_GS_OP_NOP (0 << 4)
#define AC_SENDMSG_GS_OP_CUT (1 << 4)