radv: unconditionally enable scratch for RT shaders

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21159>
This commit is contained in:
Daniel Schürmann 2023-02-08 10:35:30 +01:00 committed by Marge Bot
parent aa362b4b6f
commit b338d59047
7 changed files with 9 additions and 5 deletions

View file

@ -9037,6 +9037,7 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
case nir_intrinsic_load_rt_dynamic_callable_stack_base_amd:
bld.copy(Definition(get_ssa_temp(ctx, &instr->dest.ssa)),
get_arg(ctx, ctx->args->ac.rt_dynamic_callable_stack_base));
ctx->program->rt_stack = true;
break;
case nir_intrinsic_overwrite_vs_arguments_amd: {
ctx->arg_temps[ctx->args->ac.vertex_id.arg_index] = get_ssa_temp(ctx, instr->src[0].ssa);

View file

@ -2168,6 +2168,7 @@ public:
uint16_t min_waves = 0;
unsigned workgroup_size; /* if known; otherwise UINT_MAX */
bool wgp_mode;
bool rt_stack = false;
bool needs_vcc = false;

View file

@ -310,7 +310,8 @@ uint16_t
get_extra_sgprs(Program* program)
{
/* We don't use this register on GFX6-8 and it's removed on GFX10+. */
bool needs_flat_scr = program->config->scratch_bytes_per_wave && program->gfx_level == GFX9;
bool needs_flat_scr =
(program->config->scratch_bytes_per_wave || program->rt_stack) && program->gfx_level == GFX9;
if (program->gfx_level >= GFX10) {
assert(!program->dev.xnack_enabled);

View file

@ -2461,7 +2461,7 @@ lower_to_hw_instr(Program* program)
}
case aco_opcode::p_init_scratch: {
assert(program->gfx_level >= GFX8 && program->gfx_level <= GFX10_3);
if (!program->config->scratch_bytes_per_wave)
if (!program->config->scratch_bytes_per_wave && !program->rt_stack)
break;
Operand scratch_addr = instr->operands[0];

View file

@ -225,9 +225,10 @@ radv_pipeline_init_scratch(const struct radv_device *device, struct radv_pipelin
{
unsigned scratch_bytes_per_wave = 0;
unsigned max_waves = 0;
bool is_rt = pipeline->type == RADV_PIPELINE_RAY_TRACING;
for (int i = 0; i < MESA_VULKAN_SHADER_STAGES; ++i) {
if (pipeline->shaders[i] && pipeline->shaders[i]->config.scratch_bytes_per_wave) {
if (pipeline->shaders[i] && (pipeline->shaders[i]->config.scratch_bytes_per_wave || is_rt)) {
unsigned max_stage_waves = device->scratch_waves;
scratch_bytes_per_wave =

View file

@ -1659,7 +1659,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
nir_pop_loop(&b, loop);
if (radv_rt_pipeline_has_dynamic_stack_size(pCreateInfo))
b.shader->scratch_size = 4; /* To enable scratch. */
b.shader->scratch_size = 0; /* Stack size is set by the application. */
else
b.shader->scratch_size += compute_rt_stack_size(pCreateInfo, stack_sizes);

View file

@ -1721,7 +1721,7 @@ radv_postprocess_config(const struct radv_device *device, const struct ac_shader
struct ac_shader_config *config_out)
{
const struct radv_physical_device *pdevice = device->physical_device;
bool scratch_enabled = config_in->scratch_bytes_per_wave > 0;
bool scratch_enabled = config_in->scratch_bytes_per_wave > 0 || info->cs.is_rt_shader;
bool trap_enabled = !!device->trap_handler_shader;
unsigned vgpr_comp_cnt = 0;
unsigned num_input_vgprs = args->ac.num_vgprs_used;