ac/nir/tess: add if/endif for HS threads in NIR instead of ACO/LLVM

This just removes the if/endif wrapping for LLVM, and hopefully the ACO
change does the same thing.

ACO had redundant code in endif_merged_wave_info, which is removed here.

Acked-by: Pierre-Eric Pelloux-Prayer <pierre-eric.pelloux-prayer@amd.com>
Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34780>
This commit is contained in:
Marek Olšák 2025-04-19 08:07:12 -04:00 committed by Marge Bot
parent cd366b57d9
commit 80236f2367
5 changed files with 37 additions and 53 deletions

View file

@ -1112,6 +1112,23 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st)
} }
nir_pop_if(b, if_invocation_id_zero); nir_pop_if(b, if_invocation_id_zero);
if (st->gfx_level >= GFX9) {
/* Wrap the whole shader in a conditional block, allowing only TCS (HS) invocations to execute
* in the LS-HS workgroup.
*/
nir_cf_list *extracted = rzalloc(shader, nir_cf_list);
nir_cf_extract(extracted, nir_before_impl(impl), nir_after_impl(impl));
builder = nir_builder_at(nir_before_impl(impl));
nir_if *if_tcs =
nir_push_if(b, nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b),
.base = 8));
{
nir_cf_reinsert(extracted, b->cursor);
}
nir_pop_if(b, if_tcs);
}
nir_progress(true, impl, nir_metadata_none); nir_progress(true, impl, nir_metadata_none);
} }

View file

@ -1358,6 +1358,7 @@ select_program_merged(isel_context& ctx, const unsigned shader_count, nir_shader
{ {
if_context ic_merged_wave_info; if_context ic_merged_wave_info;
const bool ngg_gs = ctx.stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER && ctx.stage.has(SWStage::GS); const bool ngg_gs = ctx.stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER && ctx.stage.has(SWStage::GS);
const bool hs = ctx.stage.hw == AC_HW_HULL_SHADER;
for (unsigned i = 0; i < shader_count; i++) { for (unsigned i = 0; i < shader_count; i++) {
nir_shader* nir = shaders[i]; nir_shader* nir = shaders[i];
@ -1378,9 +1379,9 @@ select_program_merged(isel_context& ctx, const unsigned shader_count, nir_shader
/* See if we need to emit a check of the merged wave info SGPR. */ /* See if we need to emit a check of the merged wave info SGPR. */
const bool check_merged_wave_info = const bool check_merged_wave_info =
ctx.tcs_in_out_eq ? i == 0 : (shader_count >= 2 && !empty_shader && !(ngg_gs && i == 1)); ctx.tcs_in_out_eq ? i == 0
const bool endif_merged_wave_info = : (shader_count >= 2 && !empty_shader && ((!ngg_gs && !hs) || i != 1));
ctx.tcs_in_out_eq ? i == 1 : (check_merged_wave_info && !(ngg_gs && i == 1)); const bool endif_merged_wave_info = ctx.tcs_in_out_eq ? i == 1 : check_merged_wave_info;
/* Skip s_barrier from TCS when VS outputs are not stored in the LDS. */ /* Skip s_barrier from TCS when VS outputs are not stored in the LDS. */
const bool tcs_skip_barrier = const bool tcs_skip_barrier =

View file

@ -307,16 +307,6 @@ static void si_llvm_declare_lds_esgs_ring(struct si_shader_context *ctx)
ctx->ac.lds.pointee_type = ctx->ac.i32; ctx->ac.lds.pointee_type = ctx->ac.i32;
} }
static void si_init_exec_from_input(struct si_shader_context *ctx, struct ac_arg param,
unsigned bitoffset)
{
LLVMValueRef args[] = {
ac_get_arg(&ctx->ac, param),
LLVMConstInt(ctx->ac.i32, bitoffset, 0),
};
ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.init.exec.from.input", ctx->ac.voidt, args, 2, 0);
}
/** /**
* Get the value of a shader input parameter and extract a bitfield. * Get the value of a shader input parameter and extract a bitfield.
*/ */
@ -382,18 +372,13 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
} }
si_llvm_create_func(ctx, "wrapper", NULL, 0, si_get_max_workgroup_size(ctx->shader)); si_llvm_create_func(ctx, "wrapper", NULL, 0, si_get_max_workgroup_size(ctx->shader));
ac_init_exec_full_mask(&ctx->ac);
if (same_thread_count) { LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info);
si_init_exec_from_input(ctx, ctx->args->ac.merged_wave_info, 0); count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
} else {
ac_init_exec_full_mask(&ctx->ac);
LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info); LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), ""); ac_build_ifcc(&ctx->ac, ena, 6506);
LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
ac_build_ifcc(&ctx->ac, ena, 6506);
}
LLVMValueRef params[AC_MAX_ARGS]; LLVMValueRef params[AC_MAX_ARGS];
unsigned num_params = LLVMCountParams(ctx->main_fn.value); unsigned num_params = LLVMCountParams(ctx->main_fn.value);
@ -403,6 +388,16 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
LLVMValueRef ret = LLVMValueRef ret =
ac_build_call(&ctx->ac, parts[0].pointee_type, parts[0].value, params, num_params); ac_build_call(&ctx->ac, parts[0].pointee_type, parts[0].value, params, num_params);
if (LLVMGetTypeKind(LLVMTypeOf(ret)) != LLVMVoidTypeKind) {
LLVMValueRef ret_var = ac_build_alloca_undef(&ctx->ac, LLVMTypeOf(ret), "");
LLVMBuildStore(builder, ret, ret_var);
ac_build_endif(&ctx->ac, 6506);
ret = LLVMBuildLoad2(builder, LLVMTypeOf(ret), ret_var, "");
} else {
ac_build_endif(&ctx->ac, 6506);
}
if (same_thread_count) { if (same_thread_count) {
LLVMTypeRef type = LLVMTypeOf(ret); LLVMTypeRef type = LLVMTypeOf(ret);
assert(LLVMGetTypeKind(type) == LLVMStructTypeKind); assert(LLVMGetTypeKind(type) == LLVMStructTypeKind);
@ -432,16 +427,6 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
} }
} }
} else { } else {
ac_build_endif(&ctx->ac, 6506);
if (ctx->stage == MESA_SHADER_TESS_CTRL) {
LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info);
count = LLVMBuildLShr(builder, count, LLVMConstInt(ctx->ac.i32, 8, 0), "");
count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), "");
LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, "");
ac_build_ifcc(&ctx->ac, ena, 6507);
}
/* The second half of the merged shader should use /* The second half of the merged shader should use
* the inputs from the toplevel (wrapper) function, * the inputs from the toplevel (wrapper) function,
@ -457,11 +442,6 @@ static void si_build_wrapper_function(struct si_shader_context *ctx,
} }
ac_build_call(&ctx->ac, parts[1].pointee_type, parts[1].value, params, num_params); ac_build_call(&ctx->ac, parts[1].pointee_type, parts[1].value, params, num_params);
/* Close the conditional wrapping the second shader. */
if (ctx->stage == MESA_SHADER_TESS_CTRL && !same_thread_count)
ac_build_endif(&ctx->ac, 6507);
LLVMBuildRetVoid(builder); LLVMBuildRetVoid(builder);
} }
@ -645,8 +625,7 @@ static bool si_llvm_translate_nir(struct si_shader_context *ctx, struct si_shade
LLVMValueRef thread_enabled = NULL; LLVMValueRef thread_enabled = NULL;
if ((ctx->stage == MESA_SHADER_GEOMETRY && !shader->key.ge.as_ngg) || if (ctx->stage == MESA_SHADER_GEOMETRY && !shader->key.ge.as_ngg) {
(ctx->stage == MESA_SHADER_TESS_CTRL && !shader->is_monolithic)) {
/* Wrap both shaders in an if statement according to the number of enabled threads /* Wrap both shaders in an if statement according to the number of enabled threads
* there. For monolithic TCS, the if statement is inserted by the wrapper function, * there. For monolithic TCS, the if statement is inserted by the wrapper function,
* not here. For NGG GS, the if statement is inserted by nir lowering. * not here. For NGG GS, the if statement is inserted by nir lowering.
@ -738,11 +717,6 @@ static bool si_llvm_translate_nir(struct si_shader_context *ctx, struct si_shade
si_llvm_es_build_end(ctx); si_llvm_es_build_end(ctx);
break; break;
case MESA_SHADER_TESS_CTRL:
if (!shader->is_monolithic)
si_llvm_tcs_build_end(ctx);
break;
case MESA_SHADER_TESS_EVAL: case MESA_SHADER_TESS_EVAL:
if (ctx->shader->key.ge.as_es) if (ctx->shader->key.ge.as_es)
si_llvm_es_build_end(ctx); si_llvm_es_build_end(ctx);

View file

@ -64,7 +64,6 @@ void si_llvm_gs_build_end(struct si_shader_context *ctx);
/* si_shader_llvm_tess.c */ /* si_shader_llvm_tess.c */
void si_llvm_ls_build_end(struct si_shader_context *ctx); void si_llvm_ls_build_end(struct si_shader_context *ctx);
void si_llvm_tcs_build_end(struct si_shader_context *ctx);
void si_llvm_init_tcs_callbacks(struct si_shader_context *ctx); void si_llvm_init_tcs_callbacks(struct si_shader_context *ctx);
/* si_shader_llvm_ps.c */ /* si_shader_llvm_ps.c */

View file

@ -32,13 +32,6 @@ static LLVMValueRef si_nir_load_tcs_varyings(struct ac_shader_abi *abi, LLVMType
return ac_build_varying_gather_values(&ctx->ac, value, num_components, component); return ac_build_varying_gather_values(&ctx->ac, value, num_components, component);
} }
void si_llvm_tcs_build_end(struct si_shader_context *ctx)
{
if (ctx->screen->info.gfx_level >= GFX9) {
ac_build_endif(&ctx->ac, SI_MERGED_WRAP_IF_LABEL);
}
}
void si_llvm_ls_build_end(struct si_shader_context *ctx) void si_llvm_ls_build_end(struct si_shader_context *ctx)
{ {
struct si_shader *shader = ctx->shader; struct si_shader *shader = ctx->shader;