diff --git a/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c b/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c index 2bfc8b1f306..d220fb35687 100644 --- a/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c +++ b/src/amd/common/nir/ac_nir_lower_tess_io_to_mem.c @@ -1112,6 +1112,23 @@ hs_finale(nir_shader *shader, lower_tess_io_state *st) } nir_pop_if(b, if_invocation_id_zero); + if (st->gfx_level >= GFX9) { + /* Wrap the whole shader in a conditional block, allowing only TCS (HS) invocations to execute + * in the LS-HS workgroup. + */ + nir_cf_list *extracted = rzalloc(shader, nir_cf_list); + nir_cf_extract(extracted, nir_before_impl(impl), nir_after_impl(impl)); + + builder = nir_builder_at(nir_before_impl(impl)); + nir_if *if_tcs = + nir_push_if(b, nir_is_subgroup_invocation_lt_amd(b, nir_load_merged_wave_info_amd(b), + .base = 8)); + { + nir_cf_reinsert(extracted, b->cursor); + } + nir_pop_if(b, if_tcs); + } + nir_progress(true, impl, nir_metadata_none); } diff --git a/src/amd/compiler/instruction_selection/aco_select_nir.cpp b/src/amd/compiler/instruction_selection/aco_select_nir.cpp index ab622a4107a..b017e361af0 100644 --- a/src/amd/compiler/instruction_selection/aco_select_nir.cpp +++ b/src/amd/compiler/instruction_selection/aco_select_nir.cpp @@ -1358,6 +1358,7 @@ select_program_merged(isel_context& ctx, const unsigned shader_count, nir_shader { if_context ic_merged_wave_info; const bool ngg_gs = ctx.stage.hw == AC_HW_NEXT_GEN_GEOMETRY_SHADER && ctx.stage.has(SWStage::GS); + const bool hs = ctx.stage.hw == AC_HW_HULL_SHADER; for (unsigned i = 0; i < shader_count; i++) { nir_shader* nir = shaders[i]; @@ -1378,9 +1379,9 @@ select_program_merged(isel_context& ctx, const unsigned shader_count, nir_shader /* See if we need to emit a check of the merged wave info SGPR. */ const bool check_merged_wave_info = - ctx.tcs_in_out_eq ? i == 0 : (shader_count >= 2 && !empty_shader && !(ngg_gs && i == 1)); - const bool endif_merged_wave_info = - ctx.tcs_in_out_eq ? i == 1 : (check_merged_wave_info && !(ngg_gs && i == 1)); + ctx.tcs_in_out_eq ? i == 0 + : (shader_count >= 2 && !empty_shader && ((!ngg_gs && !hs) || i != 1)); + const bool endif_merged_wave_info = ctx.tcs_in_out_eq ? i == 1 : check_merged_wave_info; /* Skip s_barrier from TCS when VS outputs are not stored in the LDS. */ const bool tcs_skip_barrier = diff --git a/src/gallium/drivers/radeonsi/si_shader_llvm.c b/src/gallium/drivers/radeonsi/si_shader_llvm.c index 3270f793f8d..224e8036662 100644 --- a/src/gallium/drivers/radeonsi/si_shader_llvm.c +++ b/src/gallium/drivers/radeonsi/si_shader_llvm.c @@ -307,16 +307,6 @@ static void si_llvm_declare_lds_esgs_ring(struct si_shader_context *ctx) ctx->ac.lds.pointee_type = ctx->ac.i32; } -static void si_init_exec_from_input(struct si_shader_context *ctx, struct ac_arg param, - unsigned bitoffset) -{ - LLVMValueRef args[] = { - ac_get_arg(&ctx->ac, param), - LLVMConstInt(ctx->ac.i32, bitoffset, 0), - }; - ac_build_intrinsic(&ctx->ac, "llvm.amdgcn.init.exec.from.input", ctx->ac.voidt, args, 2, 0); -} - /** * Get the value of a shader input parameter and extract a bitfield. */ @@ -382,18 +372,13 @@ static void si_build_wrapper_function(struct si_shader_context *ctx, } si_llvm_create_func(ctx, "wrapper", NULL, 0, si_get_max_workgroup_size(ctx->shader)); + ac_init_exec_full_mask(&ctx->ac); - if (same_thread_count) { - si_init_exec_from_input(ctx, ctx->args->ac.merged_wave_info, 0); - } else { - ac_init_exec_full_mask(&ctx->ac); + LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info); + count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), ""); - LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info); - count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), ""); - - LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, ""); - ac_build_ifcc(&ctx->ac, ena, 6506); - } + LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, ""); + ac_build_ifcc(&ctx->ac, ena, 6506); LLVMValueRef params[AC_MAX_ARGS]; unsigned num_params = LLVMCountParams(ctx->main_fn.value); @@ -403,6 +388,16 @@ static void si_build_wrapper_function(struct si_shader_context *ctx, LLVMValueRef ret = ac_build_call(&ctx->ac, parts[0].pointee_type, parts[0].value, params, num_params); + if (LLVMGetTypeKind(LLVMTypeOf(ret)) != LLVMVoidTypeKind) { + LLVMValueRef ret_var = ac_build_alloca_undef(&ctx->ac, LLVMTypeOf(ret), ""); + LLVMBuildStore(builder, ret, ret_var); + ac_build_endif(&ctx->ac, 6506); + + ret = LLVMBuildLoad2(builder, LLVMTypeOf(ret), ret_var, ""); + } else { + ac_build_endif(&ctx->ac, 6506); + } + if (same_thread_count) { LLVMTypeRef type = LLVMTypeOf(ret); assert(LLVMGetTypeKind(type) == LLVMStructTypeKind); @@ -432,16 +427,6 @@ static void si_build_wrapper_function(struct si_shader_context *ctx, } } } else { - ac_build_endif(&ctx->ac, 6506); - - if (ctx->stage == MESA_SHADER_TESS_CTRL) { - LLVMValueRef count = ac_get_arg(&ctx->ac, ctx->args->ac.merged_wave_info); - count = LLVMBuildLShr(builder, count, LLVMConstInt(ctx->ac.i32, 8, 0), ""); - count = LLVMBuildAnd(builder, count, LLVMConstInt(ctx->ac.i32, 0x7f, 0), ""); - - LLVMValueRef ena = LLVMBuildICmp(builder, LLVMIntULT, ac_get_thread_id(&ctx->ac), count, ""); - ac_build_ifcc(&ctx->ac, ena, 6507); - } /* The second half of the merged shader should use * the inputs from the toplevel (wrapper) function, @@ -457,11 +442,6 @@ static void si_build_wrapper_function(struct si_shader_context *ctx, } ac_build_call(&ctx->ac, parts[1].pointee_type, parts[1].value, params, num_params); - - /* Close the conditional wrapping the second shader. */ - if (ctx->stage == MESA_SHADER_TESS_CTRL && !same_thread_count) - ac_build_endif(&ctx->ac, 6507); - LLVMBuildRetVoid(builder); } @@ -645,8 +625,7 @@ static bool si_llvm_translate_nir(struct si_shader_context *ctx, struct si_shade LLVMValueRef thread_enabled = NULL; - if ((ctx->stage == MESA_SHADER_GEOMETRY && !shader->key.ge.as_ngg) || - (ctx->stage == MESA_SHADER_TESS_CTRL && !shader->is_monolithic)) { + if (ctx->stage == MESA_SHADER_GEOMETRY && !shader->key.ge.as_ngg) { /* Wrap both shaders in an if statement according to the number of enabled threads * there. For monolithic TCS, the if statement is inserted by the wrapper function, * not here. For NGG GS, the if statement is inserted by nir lowering. @@ -738,11 +717,6 @@ static bool si_llvm_translate_nir(struct si_shader_context *ctx, struct si_shade si_llvm_es_build_end(ctx); break; - case MESA_SHADER_TESS_CTRL: - if (!shader->is_monolithic) - si_llvm_tcs_build_end(ctx); - break; - case MESA_SHADER_TESS_EVAL: if (ctx->shader->key.ge.as_es) si_llvm_es_build_end(ctx); diff --git a/src/gallium/drivers/radeonsi/si_shader_llvm.h b/src/gallium/drivers/radeonsi/si_shader_llvm.h index e71627d432f..9daf3ae728f 100644 --- a/src/gallium/drivers/radeonsi/si_shader_llvm.h +++ b/src/gallium/drivers/radeonsi/si_shader_llvm.h @@ -64,7 +64,6 @@ void si_llvm_gs_build_end(struct si_shader_context *ctx); /* si_shader_llvm_tess.c */ void si_llvm_ls_build_end(struct si_shader_context *ctx); -void si_llvm_tcs_build_end(struct si_shader_context *ctx); void si_llvm_init_tcs_callbacks(struct si_shader_context *ctx); /* si_shader_llvm_ps.c */ diff --git a/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c b/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c index 8af63be60a8..17846bcb0ab 100644 --- a/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c +++ b/src/gallium/drivers/radeonsi/si_shader_llvm_tess.c @@ -32,13 +32,6 @@ static LLVMValueRef si_nir_load_tcs_varyings(struct ac_shader_abi *abi, LLVMType return ac_build_varying_gather_values(&ctx->ac, value, num_components, component); } -void si_llvm_tcs_build_end(struct si_shader_context *ctx) -{ - if (ctx->screen->info.gfx_level >= GFX9) { - ac_build_endif(&ctx->ac, SI_MERGED_WRAP_IF_LABEL); - } -} - void si_llvm_ls_build_end(struct si_shader_context *ctx) { struct si_shader *shader = ctx->shader;