mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-06 05:08:08 +02:00
radv/rt: implement radv_nir_lower_rt_abi to lower RT shaders for separate compilation
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22096>
This commit is contained in:
parent
d4409769c7
commit
99466ca185
2 changed files with 119 additions and 0 deletions
|
|
@ -26,6 +26,7 @@
|
|||
|
||||
#include "bvh/bvh.h"
|
||||
#include "meta/radv_meta.h"
|
||||
#include "ac_nir.h"
|
||||
#include "radv_private.h"
|
||||
#include "radv_rt_common.h"
|
||||
#include "radv_shader.h"
|
||||
|
|
@ -88,6 +89,8 @@ struct rt_variables {
|
|||
* the correct resume index upon returning.
|
||||
*/
|
||||
nir_variable *idx;
|
||||
nir_variable *shader_va;
|
||||
nir_variable *traversal_addr;
|
||||
|
||||
/* scratch offset of the argument area relative to stack_ptr */
|
||||
nir_variable *arg;
|
||||
|
|
@ -129,6 +132,10 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR
|
|||
.create_info = create_info,
|
||||
};
|
||||
vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
|
||||
vars.shader_va =
|
||||
nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_va");
|
||||
vars.traversal_addr =
|
||||
nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "traversal_addr");
|
||||
vars.arg = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "arg");
|
||||
vars.stack_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "stack_ptr");
|
||||
vars.shader_record_ptr =
|
||||
|
|
@ -177,6 +184,8 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src,
|
|||
src->create_info = dst->create_info;
|
||||
|
||||
_mesa_hash_table_insert(var_remap, src->idx, dst->idx);
|
||||
_mesa_hash_table_insert(var_remap, src->shader_va, dst->shader_va);
|
||||
_mesa_hash_table_insert(var_remap, src->traversal_addr, dst->traversal_addr);
|
||||
_mesa_hash_table_insert(var_remap, src->arg, dst->arg);
|
||||
_mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr);
|
||||
_mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr);
|
||||
|
|
@ -1702,3 +1711,109 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
|
|||
|
||||
return b.shader;
|
||||
}
|
||||
|
||||
/* Lower the RT ABI of a separately-compiled ray-tracing shader.
 *
 * Rewrites the shader so that all RT state lives in the launch-argument
 * registers described by `args`: the rt_variables created here are
 * initialized from the incoming args at the top of the shader, the original
 * shader body is wrapped in a guard so only invocations whose shader_pc
 * matches this shader's VA execute it, and at the end the (possibly updated)
 * variables are stored back to the args and the next shader PC is selected.
 *
 * shader      - the RT stage to lower (modified in place).
 * pCreateInfo - pipeline create info used to set up the rt_variables.
 * args        - register/argument layout shared by all shaders of the pipeline.
 * key         - pipeline key; used for the hit-attribute subgroup lowering.
 * stack_size  - in/out: if non-NULL, raised to the scratch requirement of
 *               this shader (callers accumulate the max over all stages).
 */
void
radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
                      const struct radv_shader_args *args, const struct radv_pipeline_key *key,
                      uint32_t *stack_size)
{
   nir_builder b;
   nir_function_impl *impl = nir_shader_get_entrypoint(shader);
   nir_builder_init(&b, impl);

   struct rt_variables vars = create_rt_variables(shader, pCreateInfo);
   lower_rt_instructions(shader, &vars, 0);

   /* Accumulate this shader's scratch demand into the caller's stack size;
    * the scratch is accounted for through the RT stack instead, so reset
    * shader->scratch_size afterwards. */
   if (stack_size) {
      vars.stack_size = MAX2(vars.stack_size, shader->scratch_size);
      *stack_size = MAX2(*stack_size, vars.stack_size);
   }
   shader->scratch_size = 0;

   NIR_PASS(_, shader, nir_lower_returns);

   /* Pull the whole original body out so we can emit the variable
    * initialization first and re-insert the body under the guard below. */
   nir_cf_list list;
   nir_cf_extract(&list, nir_before_cf_list(&impl->body), nir_after_cf_list(&impl->body));

   /* initialize variables */
   b.cursor = nir_before_cf_list(&impl->body);

   /* 64-bit values arrive as 2x32 argument pairs and are re-packed here. */
   nir_ssa_def *traversal_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.traversal_shader);
   nir_store_var(&b, vars.traversal_addr, nir_pack_64_2x32(&b, traversal_addr), 1);
   nir_ssa_def *shader_va = ac_nir_load_arg(&b, &args->ac, args->ac.rt.next_shader);
   shader_va = nir_pack_64_2x32(&b, shader_va);
   nir_store_var(&b, vars.shader_va, shader_va, 1);
   nir_store_var(&b, vars.stack_ptr,
                 ac_nir_load_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base), 1);
   nir_ssa_def *record_ptr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_record);
   nir_store_var(&b, vars.shader_record_ptr, nir_pack_64_2x32(&b, record_ptr), 1);
   nir_store_var(&b, vars.arg, ac_nir_load_arg(&b, &args->ac, args->ac.rt.payload_offset), 1);

   /* Ray/traversal state. */
   nir_ssa_def *accel_struct = ac_nir_load_arg(&b, &args->ac, args->ac.rt.accel_struct);
   nir_store_var(&b, vars.accel_struct, nir_pack_64_2x32(&b, accel_struct), 1);
   nir_store_var(&b, vars.cull_mask_and_flags,
                 ac_nir_load_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags), 1);
   nir_store_var(&b, vars.sbt_offset, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_offset), 1);
   nir_store_var(&b, vars.sbt_stride, ac_nir_load_arg(&b, &args->ac, args->ac.rt.sbt_stride), 1);
   nir_store_var(&b, vars.miss_index, ac_nir_load_arg(&b, &args->ac, args->ac.rt.miss_index), 1);
   nir_store_var(&b, vars.origin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_origin), 0x7);
   nir_store_var(&b, vars.tmin, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmin), 1);
   nir_store_var(&b, vars.direction, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_direction),
                 0x7);
   nir_store_var(&b, vars.tmax, ac_nir_load_arg(&b, &args->ac, args->ac.rt.ray_tmax), 1);

   /* Hit/intersection state. */
   nir_store_var(&b, vars.primitive_id, ac_nir_load_arg(&b, &args->ac, args->ac.rt.primitive_id),
                 1);
   nir_ssa_def *instance_addr = ac_nir_load_arg(&b, &args->ac, args->ac.rt.instance_addr);
   nir_store_var(&b, vars.instance_addr, nir_pack_64_2x32(&b, instance_addr), 1);
   nir_store_var(&b, vars.geometry_id_and_flags,
                 ac_nir_load_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags), 1);
   nir_store_var(&b, vars.hit_kind, ac_nir_load_arg(&b, &args->ac, args->ac.rt.hit_kind), 1);

   /* guard the shader, so that only the correct invocations execute it */
   nir_ssa_def *shader_pc = ac_nir_load_arg(&b, &args->ac, args->ac.rt.shader_pc);
   shader_pc = nir_pack_64_2x32(&b, shader_pc);
   nir_ssa_def *cond = nir_ieq(&b, shader_pc, shader_va);
   nir_if *shader_guard = nir_push_if(&b, cond);
   /* The guard condition is divergent by construction (different invocations
    * may be waiting on different shaders); mark it so the backend does not
    * try to treat it as uniform control flow. */
   shader_guard->control = nir_selection_control_divergent_always_taken;
   nir_cf_reinsert(&list, b.cursor);
   nir_pop_if(&b, shader_guard);

   /* select next shader */
   // TODO: use a priority-based selection
   b.cursor = nir_after_cf_list(&impl->body);
   shader_va = nir_load_var(&b, vars.shader_va);
   /* Uniformize the next-shader PC by picking the first active invocation's
    * value; invocations wanting a different shader are handled by the guard
    * above on a later iteration of the caller's launch loop. */
   nir_ssa_def *next = nir_read_first_invocation(&b, shader_va);
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_pc, next);

   /* store back all variables to registers */
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.dynamic_callable_stack_base,
                    nir_load_var(&b, vars.stack_ptr));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.next_shader, nir_load_var(&b, vars.shader_va));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.shader_record,
                    nir_load_var(&b, vars.shader_record_ptr));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.payload_offset, nir_load_var(&b, vars.arg));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.accel_struct, nir_load_var(&b, vars.accel_struct));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.cull_mask_and_flags,
                    nir_load_var(&b, vars.cull_mask_and_flags));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_offset, nir_load_var(&b, vars.sbt_offset));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.sbt_stride, nir_load_var(&b, vars.sbt_stride));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.miss_index, nir_load_var(&b, vars.miss_index));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_origin, nir_load_var(&b, vars.origin));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmin, nir_load_var(&b, vars.tmin));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_direction, nir_load_var(&b, vars.direction));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.ray_tmax, nir_load_var(&b, vars.tmax));

   ac_nir_store_arg(&b, &args->ac, args->ac.rt.primitive_id, nir_load_var(&b, vars.primitive_id));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.instance_addr, nir_load_var(&b, vars.instance_addr));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.geometry_id_and_flags,
                    nir_load_var(&b, vars.geometry_id_and_flags));
   ac_nir_store_arg(&b, &args->ac, args->ac.rt.hit_kind, nir_load_var(&b, vars.hit_kind));

   /* cleanup passes */
   NIR_PASS_V(shader, nir_lower_global_vars_to_local);
   NIR_PASS_V(shader, nir_lower_vars_to_ssa);
   /* NOTE(review): hit attributes only exist for closest-hit/intersection
    * stages; other stages presumably have none to lower — confirm against
    * lower_hit_attribs (defined earlier in this file, outside this view). */
   if (shader->info.stage == MESA_SHADER_CLOSEST_HIT ||
       shader->info.stage == MESA_SHADER_INTERSECTION)
      NIR_PASS_V(shader, lower_hit_attribs, NULL, key->cs.compute_subgroup_size);
}
|
||||
|
|
|
|||
|
|
@ -583,6 +583,10 @@ nir_shader *radv_parse_rt_stage(struct radv_device *device,
|
|||
const VkPipelineShaderStageCreateInfo *sinfo,
|
||||
const struct radv_pipeline_key *key);
|
||||
|
||||
void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
|
||||
const struct radv_shader_args *args, const struct radv_pipeline_key *key,
|
||||
uint32_t *stack_size);
|
||||
|
||||
struct radv_pipeline_stage;
|
||||
|
||||
nir_shader *radv_shader_spirv_to_nir(struct radv_device *device,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue