diff --git a/src/amd/common/ac_nir_lower_tess_io_to_mem.c b/src/amd/common/ac_nir_lower_tess_io_to_mem.c index e14ddc3b06e..d3a88fa67b2 100644 --- a/src/amd/common/ac_nir_lower_tess_io_to_mem.c +++ b/src/amd/common/ac_nir_lower_tess_io_to_mem.c @@ -564,6 +564,13 @@ hs_emit_write_tess_factors(nir_shader *shader, /* Only the 1st invocation of each patch needs to do this. */ nir_if *invocation_id_zero = nir_push_if(b, nir_ieq_imm(b, invocation_id, 0)); + /* When the output patch size is <= 32 then we can flatten the branch here + * because we know for sure that at least 1 invocation in all waves will + * take the branch. + */ + if (shader->info.tess.tcs_vertices_out <= 32) + invocation_id_zero->control = nir_selection_control_divergent_always_taken; + /* The descriptor where tess factors have to be stored by the shader. */ nir_ssa_def *tessfactor_ring = nir_load_ring_tess_factors_amd(b);