radv/rt: create traversal shader independent from main shader

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19188>
This commit is contained in:
Daniel Schürmann 2022-10-14 12:09:12 +02:00 committed by Marge Bot
parent 22534e0d1a
commit f4270b7659

View file

@@ -677,6 +677,45 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca
nir_pop_if(&b_shader, NULL);
break;
}
case nir_intrinsic_load_sbt_offset_amd: {
ret = nir_load_var(&b_shader, vars->sbt_offset);
break;
}
case nir_intrinsic_load_sbt_stride_amd: {
ret = nir_load_var(&b_shader, vars->sbt_stride);
break;
}
case nir_intrinsic_load_accel_struct_amd: {
ret = nir_load_var(&b_shader, vars->accel_struct);
break;
}
case nir_intrinsic_execute_closest_hit_amd: {
nir_store_var(&b_shader, vars->tmax, intr->src[1].ssa, 0x1);
nir_store_var(&b_shader, vars->primitive_id, intr->src[2].ssa, 0x1);
nir_store_var(&b_shader, vars->instance_addr, intr->src[3].ssa, 0x1);
nir_store_var(&b_shader, vars->geometry_id_and_flags, intr->src[4].ssa, 0x1);
nir_store_var(&b_shader, vars->hit_kind, intr->src[5].ssa, 0x1);
load_sbt_entry(&b_shader, vars, intr->src[0].ssa, SBT_HIT, 0);
nir_ssa_def *should_return =
nir_ior(&b_shader,
nir_test_mask(&b_shader, nir_load_var(&b_shader, vars->flags),
SpvRayFlagsSkipClosestHitShaderKHRMask),
nir_ieq_imm(&b_shader, nir_load_var(&b_shader, vars->idx), 0));
/* should_return is set if we had a hit but we won't be calling the closest hit
* shader and hence need to return immediately to the calling shader. */
nir_push_if(&b_shader, should_return);
insert_rt_return(&b_shader, vars);
nir_pop_if(&b_shader, NULL);
break;
}
case nir_intrinsic_execute_miss_amd: {
nir_store_var(&b_shader, vars->tmax, intr->src[0].ssa, 0x1);
nir_ssa_def *miss_index = nir_load_var(&b_shader, vars->miss_index);
load_sbt_entry(&b_shader, vars, miss_index, SBT_MISS, 0);
break;
}
default:
continue;
}
@@ -1314,7 +1353,7 @@ load_stack_entry(nir_builder *b, nir_ssa_def *index, const struct radv_ray_trave
static nir_shader *
build_traversal_shader(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct rt_variables *dst_vars, struct hash_table *var_remap)
struct radv_pipeline_shader_stack_size *stack_sizes)
{
nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
b.shader->info.internal = false;
@@ -1322,10 +1361,19 @@ build_traversal_shader(struct radv_device *device,
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
b.shader->info.shared_size =
device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, dst_vars->stack_sizes);
map_rt_variables(var_remap, &vars, dst_vars);
struct rt_variables vars = create_rt_variables(b.shader, pCreateInfo, stack_sizes);
nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);
/* initialize trace_ray arguments */
nir_ssa_def *accel_struct = nir_load_accel_struct_amd(&b);
nir_store_var(&b, vars.flags, nir_load_ray_flags(&b), 0x1);
nir_store_var(&b, vars.cull_mask, nir_load_cull_mask(&b), 0x1);
nir_store_var(&b, vars.sbt_offset, nir_load_sbt_offset_amd(&b), 0x1);
nir_store_var(&b, vars.sbt_stride, nir_load_sbt_stride_amd(&b), 0x1);
nir_store_var(&b, vars.origin, nir_load_ray_world_origin(&b), 0x7);
nir_store_var(&b, vars.tmin, nir_load_ray_t_min(&b), 0x1);
nir_store_var(&b, vars.direction, nir_load_ray_world_direction(&b), 0x7);
nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1);
nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1);
struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
@@ -1412,54 +1460,30 @@ build_traversal_shader(struct radv_device *device,
/* Initialize follow-up shader. */
nir_push_if(&b, nir_load_var(&b, trav_vars.hit));
{
/* vars.idx contains the SBT index at this point. */
load_sbt_entry(&b, &vars, nir_load_var(&b, vars.idx), SBT_HIT, 0);
nir_ssa_def *should_return = nir_ior(
&b,
nir_test_mask(&b, nir_load_var(&b, vars.flags), SpvRayFlagsSkipClosestHitShaderKHRMask),
nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
/* should_return is set if we had a hit but we won't be calling the closest hit shader and
* hence need to return immediately to the calling shader. */
nir_push_if(&b, should_return);
{
insert_rt_return(&b, &vars);
}
nir_pop_if(&b, NULL);
nir_execute_closest_hit_amd(
&b, nir_load_var(&b, vars.idx), nir_load_var(&b, vars.tmax),
nir_load_var(&b, vars.primitive_id), nir_load_var(&b, vars.instance_addr),
nir_load_var(&b, vars.geometry_id_and_flags), nir_load_var(&b, vars.hit_kind));
}
nir_push_else(&b, NULL);
{
/* Only load the miss shader if we actually miss. It is valid to not specify an SBT pointer
* for miss shaders if none of the rays miss. */
load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0);
nir_execute_miss_amd(&b, nir_load_var(&b, vars.tmax));
}
nir_pop_if(&b, NULL);
/* Deal with all the inline functions. */
nir_index_ssa_defs(nir_shader_get_entrypoint(b.shader));
nir_metadata_preserve(nir_shader_get_entrypoint(b.shader), nir_metadata_none);
/* Lower and cleanup variables */
NIR_PASS_V(b.shader, nir_lower_global_vars_to_local);
NIR_PASS_V(b.shader, nir_lower_vars_to_ssa);
return b.shader;
}
/* Builds the ray-traversal shader and inlines it into the main RT shader
 * being assembled in `b`, guarded so it only runs when the current shader
 * index selects traversal.
 *
 * NOTE(review): this is the pre-change version of the function shown by the
 * diff rendering; the commit removes it in favor of insert_rt_case() on a
 * separately built traversal shader (see create_rt_shader hunk below).
 */
static void
insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
nir_builder *b, const struct rt_variables *vars)
{
/* var_remap maps the traversal shader's rt_variables onto the caller's
 * (`vars`), filled in by build_traversal_shader via map_rt_variables. */
struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);
/* The traversal shader is the only user of shared memory (its traversal
 * stack); propagate its shared_size to the combined shader.  The caller
 * must not have allocated any shared memory of its own. */
assert(b->shader->info.shared_size == 0);
b->shader->info.shared_size = shader->info.shared_size;
/* 32 KiB is the hardware LDS limit per workgroup assumed here —
 * TODO(review): confirm against the target GPU's LDS size. */
assert(b->shader->info.shared_size <= 32768);
/* For now, just inline the traversal shader */
/* idx == 1 is the shader index reserved for traversal (see create_rt_shader,
 * which dispatches on vars->idx). */
nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
nir_pop_if(b, NULL);
/* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
ralloc_free(var_remap);
}
static unsigned
compute_rt_stack_size(const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct radv_pipeline_shader_stack_size *stack_sizes)
@@ -1598,10 +1622,15 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
nir_jump(&b, nir_jump_break);
nir_pop_if(&b, NULL);
insert_traversal(device, pCreateInfo, &b, &vars);
nir_ssa_def *idx = nir_load_var(&b, vars.idx);
/* Insert traversal shader */
nir_shader *traversal = build_traversal_shader(device, pCreateInfo, stack_sizes);
assert(b.shader->info.shared_size == 0);
b.shader->info.shared_size = traversal->info.shared_size;
assert(b.shader->info.shared_size <= 32768);
insert_rt_case(&b, traversal, &vars, idx, 0, 1);
/* We do a trick with the indexing of the resume shaders so that the first
* shader of stage x always gets id x and the resume shader ids then come after
* stageCount. This makes the shadergroup handles independent of compilation. */