radv/rt: implement radv_nir_lower_rt_abi to lower RT shaders for separate compilation

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>
This commit is contained in:
Daniel Schürmann 2023-03-06 20:03:49 +01:00 committed by Marge Bot
parent d4409769c7
commit 99466ca185
2 changed files with 119 additions and 0 deletions

View file

@ -26,6 +26,7 @@
#include "bvh/bvh.h"
#include "meta/radv_meta.h"
#include "ac_nir.h"
#include "radv_private.h"
#include "radv_rt_common.h"
#include "radv_shader.h"
@ -88,6 +89,8 @@ struct rt_variables {
* the correct resume index upon returning.
*/
nir_variable *idx;
nir_variable *shader_va;
nir_variable *traversal_addr;
/* scratch offset of the argument area relative to stack_ptr */
nir_variable *arg;
@ -129,6 +132,10 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR
.create_info = create_info,
};
vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
vars.shader_va =
nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_va");
vars.traversal_addr =
nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
vars.shader_record_ptr =
@ -177,6 +184,8 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
src->create_info = dst->create_info;
_mesa_hash_table_insert(var_remap, src->idx, dst->idx);
_mesa_hash_table_insert(var_remap, src->shader_va, dst->shader_va);
_mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
_mesa_hash_table_insert(var_remap, src->arg, dst->arg);
_mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
_mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
@ -1702,3 +1711,109 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
return b.shader;
}
/* Lower the RT-specific ABI of a separately-compiled ray tracing shader.
 *
 * After this pass the shader is self-contained with respect to the RT calling
 * convention: all launch/traversal state is loaded from the shader arguments
 * (args->ac.rt.*) into shader temporaries on entry, the original shader body
 * is guarded so that only invocations whose current shader PC matches this
 * shader's VA actually execute it, and on exit the (possibly updated) state is
 * written back to the argument registers together with a wave-uniform
 * next-shader PC.
 *
 * shader      - the RT stage to lower (modified in place)
 * pCreateInfo - pipeline create info forwarded to create_rt_variables()
 * args        - argument layout; provides the args->ac.rt.* slots read/written
 * key         - pipeline key; supplies the subgroup size for hit-attrib lowering
 * stack_size  - in/out: if non-NULL, updated to the max of its current value
 *               and this shader's required scratch/stack size
 */
void
radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct radv_shader_args *args, const struct radv_pipeline_key *key,
uint32_t *stack_size)
{
nir_builder b;
nir_function_impl *impl = nir_shader_get_entrypoint(shader);
nir_builder_init(&b, impl);
struct rt_variables vars = create_rt_variables(shader, pCreateInfo);
lower_rt_instructions(shader, &vars, 0);
/* Accumulate the worst-case stack requirement: the larger of what
 * lower_rt_instructions() recorded and NIR's own scratch usage.
 */
if (stack_size) {
vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
*stack_size = MAX2(*stack_size, vars.stack_size);
}
/* Scratch is accounted for via stack_size above; reset NIR's counter so it
 * is not counted again. */
shader->scratch_size = 0;
NIR_PASS(_, shader, nir_lower_returns);
/* Pull the whole original body out so it can be re-inserted under the
 * shader-PC guard built below. */
nir_cf_list list;
nir_cf_extract(&list, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));
/* initialize variables: load every piece of RT state from the argument
 * registers into the temporaries created by create_rt_variables(). 64-bit
 * values arrive as 2x32 and are re-packed. */
b.cursor = nir_before_cf_list(&impl->body);
nir_ssa_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader);
nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
nir_ssa_def *shader_va = ac_nir_load_arg(&b, &args->ac, args->ac.rt.next_shader);
shader_va = nir_pack_64_2x32(&b, shader_va);
nir_store_var(&b, vars.shader_va, shader_va, 1);
nir_store_var(&b, vars.stack_ptr,
ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
nir_ssa_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);
nir_ssa_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
nir_store_var(&b, vars.cull_mask_and_flags,
ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 1);
/* Ray origin/direction are vec3, hence the 0x7 write mask. */
nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction),
0x7);
nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);
nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id),
1);
nir_ssa_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
nir_store_var(&b, vars.geometry_id_and_flags,
ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);
/* guard the shader, so that only the correct invocations execute it:
 * the body runs only where the launched shader PC equals this shader's VA.
 */
nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc);
shader_pc = nir_pack_64_2x32(&b, shader_pc);
nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va);
nir_if *shader_guard = nir_push_if(&b, cond);
/* The condition is divergent but every active invocation that reaches this
 * shader takes the branch. */
shader_guard->control = nir_selection_control_divergent_always_taken;
nir_cf_reinsert(&list, b.cursor);
nir_pop_if(&b, shader_guard);
/* select next shader: broadcast the first active invocation's shader_va so
 * the next jump target is wave-uniform. */
// TODO: use a priority-based selection
b.cursor = nir_after_cf_list(&impl->body);
shader_va = nir_load_var(&b, vars.shader_va);
nir_ssa_def *next = nir_read_first_invocation(&b, shader_va);
ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next);
/* store back all variables to registers, mirroring the loads above */
ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base,
nir_load_var(&b, vars.stack_ptr));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.next_shader, nir_load_var(&b, vars.shader_va));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_record,
nir_load_var(&b, vars.shader_record_ptr));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags,
nir_load_var(&b, vars.cull_mask_and_flags));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags,
nir_load_var(&b, vars.geometry_id_and_flags));
ac_nir_store_arg(&b, &args->ac, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));
/* cleanup passes: turn the shader_temp variables introduced above into SSA */
NIR_PASS_V(shader, nir_lower_global_vars_to_local);
NIR_PASS_V(shader, nir_lower_vars_to_ssa);
/* Only CHIT and intersection stages exchange hit attributes. */
if (shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
shader->info.stage == MESA_SHADER_INTERSECTION)
NIR_PASS_V(shader, lower_hit_attribs, NULL, key->cs.compute_subgroup_size);
}

View file

@ -583,6 +583,10 @@ nir_shader *radv_parse_rt_stage(struct radv_device *device,
const VkPipelineShaderStageCreateInfo *sinfo,
const struct radv_pipeline_key *key);
void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
const struct radv_shader_args *args, const struct radv_pipeline_key *key,
uint32_t *stack_size);
struct radv_pipeline_stage;
nir_shader *radv_shader_spirv_to_nir(struct radv_device *device,