radv/rt: Call ahit/isec shaders

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39314>
This commit is contained in:
Natalie Vock 2025-11-11 18:11:00 +01:00 committed by Marge Bot
parent a03e9287c3
commit 30f6eacfad
5 changed files with 221 additions and 52 deletions

View file

@ -498,6 +498,22 @@ lower_rt_instruction(nir_builder *b, nir_instr *instr, void *_vars)
ret = nir_ushr_imm(b, nir_load_param(b, vars->cull_mask_and_flags_param), 24);
break;
}
case nir_intrinsic_load_ray_payload_ptr_amd: {
ret = nir_load_param(b, vars->in_payload_base_param + nir_intrinsic_base(intr));
break;
}
case nir_intrinsic_load_rt_descriptors_amd: {
ret = nir_load_param(b, RT_ARG_DESCRIPTORS);
break;
}
case nir_intrinsic_load_rt_dynamic_descriptors_amd: {
ret = nir_load_param(b, RT_ARG_DYNAMIC_DESCRIPTORS);
break;
}
case nir_intrinsic_load_rt_push_constants_amd: {
ret = nir_load_param(b, RT_ARG_PUSH_CONSTANTS);
break;
}
case nir_intrinsic_load_sbt_base_amd: {
ret = nir_load_param(b, RT_ARG_SBT_DESCRIPTORS);
break;

View file

@ -10,9 +10,12 @@
#include "nir/radv_nir_rt_stage_common.h"
#include "nir/radv_nir_rt_stage_cps.h"
#include "nir/radv_nir_rt_traversal_shader.h"
#include "aco_nir_call_attribs.h"
#include "nir_builder.h"
#include "radv_device.h"
#include "radv_meta_nir.h"
#include "radv_nir_rt_stage_functions.h"
#include "radv_physical_device.h"
#include "radv_rra.h"
@ -77,6 +80,7 @@ struct traversal_data {
struct radv_nir_rt_traversal_params *params;
struct traversal_vars trav_vars;
nir_function *ahit_isec_func;
struct radv_ray_tracing_pipeline *pipeline;
};
@ -870,33 +874,99 @@ handle_candidate_triangle(nir_builder *b, struct radv_triangle_intersection *int
nir_push_if(b, nir_inot(b, intersection->base.opaque));
{
struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
if (data->params->preprocess_ahit_isec) {
struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_ANY_HIT_IDX);
struct traversal_inlining_params inlining_params = {
.device = data->device,
.trav_vars = &data->trav_vars,
.candidate = &candidate_result,
.anyhit_vars = &ahit_vars,
.preprocess = data->params->preprocess_ahit_isec,
.preprocess_data = data->params->cb_data,
};
nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
struct radv_rt_case_data case_data = {
.device = data->device,
.pipeline = data->pipeline,
.param_data = &inlining_params,
};
struct traversal_inlining_params inlining_params = {
.device = data->device,
.trav_vars = &data->trav_vars,
.candidate = &candidate_result,
.anyhit_vars = &ahit_vars,
.preprocess = data->params->preprocess_ahit_isec,
.preprocess_data = data->params->cb_data,
};
if (data->trav_vars.ahit_isec_count)
nir_store_var(b, data->trav_vars.ahit_isec_count,
nir_iadd_imm(b, nir_load_var(b, data->trav_vars.ahit_isec_count), 1), 0x1);
struct radv_rt_case_data case_data = {
.device = data->device,
.pipeline = data->pipeline,
.param_data = &inlining_params,
};
radv_visit_inlined_shaders(
b, sbt_data.shader_addr,
!(data->pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR),
&case_data, radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
if (data->trav_vars.ahit_isec_count)
nir_store_var(b, data->trav_vars.ahit_isec_count,
nir_iadd_imm(b, nir_load_var(b, data->trav_vars.ahit_isec_count), 1), 0x1);
radv_visit_inlined_shaders(b, sbt_data.shader_addr,
!(data->pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR),
&case_data, radv_ray_tracing_group_ahit_info, radv_build_ahit_case);
} else {
struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_AHIT_ISEC_PTR);
if (!(data->pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR)) {
nir_push_if(b, nir_ine_imm(b, sbt_data.shader_addr, 0));
}
unsigned hit_attrib_param_count = DIV_ROUND_UP(data->params->hit_attrib_size, 4);
unsigned payload_param_count = DIV_ROUND_UP(data->params->payload_size, 4);
nir_variable **hit_attrib_vars = rzalloc_array_size(b->shader, sizeof(nir_variable *), hit_attrib_param_count);
for (unsigned i = 0; i < hit_attrib_param_count; i++) {
hit_attrib_vars[i] =
nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_attrib");
if (i < 2)
nir_store_var(b, hit_attrib_vars[i], nir_channel(b, intersection->barycentrics, i), 0x1);
}
unsigned param_count = AHIT_ISEC_ARG_HIT_ATTRIB_PAYLOAD_BASE + hit_attrib_param_count + payload_param_count;
nir_def **params = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count);
params[RT_ARG_LAUNCH_ID] = nir_load_ray_launch_id(b);
params[RT_ARG_LAUNCH_SIZE] = nir_load_ray_launch_size(b);
params[RT_ARG_DESCRIPTORS] = nir_load_rt_descriptors_amd(b);
params[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_rt_dynamic_descriptors_amd(b);
params[RT_ARG_PUSH_CONSTANTS] = nir_load_rt_push_constants_amd(b);
params[RT_ARG_SBT_DESCRIPTORS] = nir_load_sbt_base_amd(b);
params[AHIT_ISEC_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr;
params[AHIT_ISEC_ARG_CULL_MASK_AND_FLAGS] = data->params->cull_mask_and_flags;
params[AHIT_ISEC_ARG_SBT_INDEX] = sbt_idx;
params[AHIT_ISEC_ARG_RAY_ORIGIN] = data->params->origin;
params[AHIT_ISEC_ARG_RAY_TMIN] = data->params->tmin;
params[AHIT_ISEC_ARG_RAY_DIRECTION] = data->params->direction;
params[AHIT_ISEC_ARG_CANDIDATE_RAY_TMAX] = intersection->t;
params[AHIT_ISEC_ARG_PRIMITIVE_ADDR] = intersection->base.node_addr;
params[AHIT_ISEC_ARG_PRIMITIVE_ID] = intersection->base.primitive_id;
params[AHIT_ISEC_ARG_INSTANCE_ADDR] = nir_load_var(b, data->trav_vars.instance_addr);
params[AHIT_ISEC_ARG_GEOMETRY_ID_AND_FLAGS] = intersection->base.geometry_id_and_flags;
params[AHIT_ISEC_ARG_OPAQUE] = intersection->base.opaque;
params[AHIT_ISEC_ARG_HIT_KIND] = &nir_build_deref_var(b, candidate_result.hit_kind)->def;
params[AHIT_ISEC_ARG_ACCEPT] = &nir_build_deref_var(b, ahit_vars.ahit_accept)->def;
params[AHIT_ISEC_ARG_TERMINATE] = &nir_build_deref_var(b, ahit_vars.ahit_terminate)->def;
params[AHIT_ISEC_ARG_COMMITTED_RAY_TMAX] = &nir_build_deref_var(b, candidate_result.tmax)->def;
for (unsigned i = 0; i < hit_attrib_param_count; ++i) {
params[AHIT_ISEC_ARG_HIT_ATTRIB_PAYLOAD_BASE + i] =
nir_instr_def(&nir_build_deref_var(b, hit_attrib_vars[i])->instr);
}
for (unsigned i = 0; i < payload_param_count; ++i) {
unsigned param_idx = AHIT_ISEC_ARG_HIT_ATTRIB_PAYLOAD_BASE + hit_attrib_param_count + i;
params[param_idx] = nir_instr_def(&nir_build_deref_cast(b, nir_load_ray_payload_ptr_amd(b, 32, .base = i),
nir_var_shader_call_data, glsl_uint_type(), 4)
->instr);
}
nir_build_indirect_call(b, data->ahit_isec_func, sbt_data.shader_addr, param_count, params);
if (!(data->pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_ANY_HIT_SHADERS_BIT_KHR)) {
nir_push_else(b, NULL);
nir_store_var(b, ahit_vars.ahit_accept, nir_imm_true(b), 0x1);
nir_store_var(b, ahit_vars.ahit_terminate,
nir_test_mask(b, data->params->cull_mask_and_flags, SpvRayFlagsTerminateOnFirstHitKHRMask),
0x1);
nir_pop_if(b, NULL);
}
}
}
nir_pop_if(b, NULL);
@ -955,36 +1025,100 @@ handle_candidate_aabb(nir_builder *b, struct radv_leaf_intersection *intersectio
nir_store_var(b, data->trav_vars.ahit_isec_count,
nir_iadd_imm(b, nir_load_var(b, data->trav_vars.ahit_isec_count), 1 << 16), 0x1);
struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX);
nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
if (data->params->preprocess_ahit_isec) {
struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_INTERSECTION_IDX);
nir_store_var(b, ahit_vars.shader_record_ptr, sbt_data.shader_record_ptr, 0x1);
struct traversal_inlining_params inlining_params = {
.device = data->device,
.trav_vars = &data->trav_vars,
.candidate = &candidate_result,
.anyhit_vars = &ahit_vars,
.preprocess = data->params->preprocess_ahit_isec,
.preprocess_data = data->params->cb_data,
};
struct traversal_inlining_params inlining_params = {
.device = data->device,
.trav_vars = &data->trav_vars,
.candidate = &candidate_result,
.anyhit_vars = &ahit_vars,
.preprocess = data->params->preprocess_ahit_isec,
.preprocess_data = data->params->cb_data,
};
struct radv_rt_case_data case_data = {
.device = data->device,
.pipeline = data->pipeline,
.param_data = &inlining_params,
};
struct radv_rt_case_data case_data = {
.device = data->device,
.pipeline = data->pipeline,
.param_data = &inlining_params,
};
radv_visit_inlined_shaders(
b, sbt_data.shader_addr,
!(data->pipeline->base.base.create_flags & VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR),
&case_data, radv_ray_tracing_group_isec_info, radv_build_isec_case);
radv_visit_inlined_shaders(b, sbt_data.shader_addr,
!(data->pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR),
&case_data, radv_ray_tracing_group_isec_info, radv_build_isec_case);
nir_push_if(b, nir_load_var(b, ahit_vars.ahit_accept));
{
copy_traversal_result(b, &data->trav_vars.result, &candidate_result);
nir_break_if(b, nir_load_var(b, ahit_vars.ahit_terminate));
nir_push_if(b, nir_load_var(b, ahit_vars.ahit_accept));
{
copy_traversal_result(b, &data->trav_vars.result, &candidate_result);
nir_break_if(b, nir_load_var(b, ahit_vars.ahit_terminate));
}
nir_pop_if(b, NULL);
} else {
struct radv_nir_sbt_data sbt_data =
radv_nir_load_sbt_entry(b, nir_load_sbt_base_amd(b), sbt_idx, SBT_HIT, SBT_AHIT_ISEC_PTR);
unsigned hit_attrib_param_count = DIV_ROUND_UP(data->params->hit_attrib_size, 4);
unsigned payload_param_count = DIV_ROUND_UP(data->params->payload_size, 4);
nir_variable **hit_attrib_vars = rzalloc_array_size(b->shader, sizeof(nir_variable *), hit_attrib_param_count);
for (unsigned i = 0; i < hit_attrib_param_count; i++) {
hit_attrib_vars[i] = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "inner_hit_attrib");
}
if (!(data->pipeline->base.base.create_flags &
VK_PIPELINE_CREATE_2_RAY_TRACING_NO_NULL_INTERSECTION_SHADERS_BIT_KHR)) {
nir_push_if(b, nir_ine_imm(b, sbt_data.shader_addr, 0));
}
unsigned param_count = AHIT_ISEC_ARG_HIT_ATTRIB_PAYLOAD_BASE + hit_attrib_param_count + payload_param_count;
nir_def **params = rzalloc_array_size(b->shader, sizeof(nir_def *), param_count);
params[RT_ARG_LAUNCH_ID] = nir_load_ray_launch_id(b);
params[RT_ARG_LAUNCH_SIZE] = nir_load_ray_launch_size(b);
params[RT_ARG_DESCRIPTORS] = nir_load_rt_descriptors_amd(b);
params[RT_ARG_DYNAMIC_DESCRIPTORS] = nir_load_rt_dynamic_descriptors_amd(b);
params[RT_ARG_PUSH_CONSTANTS] = nir_load_rt_push_constants_amd(b);
params[RT_ARG_SBT_DESCRIPTORS] = nir_load_sbt_base_amd(b);
params[AHIT_ISEC_ARG_SHADER_RECORD_PTR] = sbt_data.shader_record_ptr;
params[AHIT_ISEC_ARG_CULL_MASK_AND_FLAGS] = data->params->cull_mask_and_flags;
params[AHIT_ISEC_ARG_SBT_INDEX] = sbt_idx;
params[AHIT_ISEC_ARG_RAY_ORIGIN] = data->params->origin;
params[AHIT_ISEC_ARG_RAY_TMIN] = data->params->tmin;
params[AHIT_ISEC_ARG_RAY_DIRECTION] = data->params->direction;
params[AHIT_ISEC_ARG_CANDIDATE_RAY_TMAX] = nir_load_var(b, candidate_result.tmax);
params[AHIT_ISEC_ARG_PRIMITIVE_ADDR] = intersection->node_addr;
params[AHIT_ISEC_ARG_PRIMITIVE_ID] = intersection->primitive_id;
params[AHIT_ISEC_ARG_INSTANCE_ADDR] = nir_load_var(b, data->trav_vars.instance_addr);
params[AHIT_ISEC_ARG_GEOMETRY_ID_AND_FLAGS] = intersection->geometry_id_and_flags;
params[AHIT_ISEC_ARG_OPAQUE] = intersection->opaque;
params[AHIT_ISEC_ARG_HIT_KIND] = &nir_build_deref_var(b, candidate_result.hit_kind)->def;
params[AHIT_ISEC_ARG_ACCEPT] = &nir_build_deref_var(b, ahit_vars.ahit_accept)->def;
params[AHIT_ISEC_ARG_TERMINATE] = &nir_build_deref_var(b, ahit_vars.ahit_terminate)->def;
params[AHIT_ISEC_ARG_COMMITTED_RAY_TMAX] = &nir_build_deref_var(b, candidate_result.tmax)->def;
for (unsigned i = 0; i < hit_attrib_param_count; ++i) {
params[AHIT_ISEC_ARG_HIT_ATTRIB_PAYLOAD_BASE + i] =
nir_instr_def(&nir_build_deref_var(b, hit_attrib_vars[i])->instr);
}
for (unsigned i = 0; i < payload_param_count; ++i) {
unsigned param_idx = AHIT_ISEC_ARG_HIT_ATTRIB_PAYLOAD_BASE + hit_attrib_param_count + i;
params[param_idx] = nir_instr_def(&nir_build_deref_cast(b, nir_load_ray_payload_ptr_amd(b, 32, .base = i),
nir_var_shader_call_data, glsl_uint_type(), 4)
->instr);
}
nir_build_indirect_call(b, data->ahit_isec_func, sbt_data.shader_addr, param_count, params);
nir_pop_if(b, NULL);
nir_push_if(b, nir_load_var(b, ahit_vars.ahit_accept));
{
for (unsigned i = 0; i < hit_attrib_param_count; ++i) {
nir_store_hit_attrib_amd(b, nir_load_var(b, hit_attrib_vars[i]), .base = i);
}
copy_traversal_result(b, &data->trav_vars.result, &candidate_result);
nir_break_if(b, nir_load_var(b, ahit_vars.ahit_terminate));
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
static void
@ -1016,6 +1150,13 @@ radv_build_traversal(struct radv_device *device, struct radv_ray_tracing_pipelin
init_traversal_vars(b->shader, &data.trav_vars);
data.trav_vars.result.barycentrics = barycentrics;
if (!params->preprocess_ahit_isec) {
nir_function *ahit_isec_func = nir_function_create(b->shader, "ahit_isec_func");
radv_nir_init_rt_function_params(ahit_isec_func, MESA_SHADER_ANY_HIT, params->payload_size,
params->hit_attrib_size);
data.ahit_isec_func = ahit_isec_func;
}
struct radv_ray_traversal_vars trav_vars_args = {
.tmax = nir_build_deref_var(b, data.trav_vars.result.tmax),
.origin = nir_build_deref_var(b, data.trav_vars.origin),
@ -1149,7 +1290,8 @@ preprocess_traversal_shader_ahit_isec(nir_shader *nir, void *cb)
nir_shader *
radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
struct radv_ray_tracing_stage_info *info, radv_nir_traversal_preprocess_cb preprocess)
struct radv_ray_tracing_stage_info *info, radv_nir_traversal_preprocess_cb preprocess,
uint32_t payload_size, uint32_t hit_attrib_size)
{
const struct radv_physical_device *pdev = radv_device_physical(device);
@ -1190,8 +1332,10 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
params.origin = nir_load_ray_world_origin(&b);
params.direction = nir_load_ray_world_direction(&b);
params.preprocess_ahit_isec = preprocess_traversal_shader_ahit_isec;
params.preprocess_ahit_isec = preprocess ? preprocess_traversal_shader_ahit_isec : NULL;
params.cb_data = preprocess;
params.payload_size = payload_size;
params.hit_attrib_size = hit_attrib_size;
params.ignore_cull_mask = false;
struct radv_nir_rt_traversal_result result = radv_build_traversal(device, pipeline, &b, &params, info);

View file

@ -15,6 +15,7 @@ void radv_nir_lower_intersection_shader(nir_shader *intersection, nir_shader *an
nir_shader *radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_pipeline *pipeline,
struct radv_ray_tracing_stage_info *info,
radv_nir_traversal_preprocess_cb preprocess);
radv_nir_traversal_preprocess_cb preprocess, uint32_t payload_size,
uint32_t hit_attrib_size);
#endif // RADV_NIR_RT_TRAVERSAL_SHADER_H

View file

@ -909,8 +909,12 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
radv_nir_traversal_preprocess_cb preprocess =
recursive_lowering_mode == RADV_RT_LOWERING_MODE_CPS ? radv_nir_lower_rt_io_cps : radv_nir_lower_rt_io_functions;
if (!inline_any_hit_shaders)
preprocess = NULL;
/* create traversal shader */
nir_shader *traversal_nir = radv_build_traversal_shader(device, pipeline, &traversal_info, preprocess);
nir_shader *traversal_nir =
radv_build_traversal_shader(device, pipeline, &traversal_info, preprocess, payload_size, hit_attrib_size);
struct radv_shader_stage traversal_stage = {
.stage = MESA_SHADER_INTERSECTION,
.nir = traversal_nir,

View file

@ -2005,6 +2005,9 @@ system_value("intersection_opaque_amd", 1, bit_sizes=[1])
system_value("resume_shader_address_amd", 1, bit_sizes=[64], indices=[CALL_IDX])
# Ray Tracing Traversal inputs
# Descriptor, dynamic-descriptor and push-constant inputs for ray tracing
# shaders. lower_rt_instruction lowers each of these to a nir_load_param of
# RT_ARG_DESCRIPTORS, RT_ARG_DYNAMIC_DESCRIPTORS and RT_ARG_PUSH_CONSTANTS
# respectively, and the traversal shader loads them to fill the same
# parameter slots of the indirect any-hit/intersection call.
system_value("rt_descriptors_amd", 1)
system_value("rt_dynamic_descriptors_amd", 1)
system_value("rt_push_constants_amd", 1)
system_value("sbt_offset_amd", 1)
system_value("sbt_stride_amd", 1)
system_value("accel_struct_amd", 1, bit_sizes=[64])
@ -2030,6 +2033,7 @@ intrinsic("load_incoming_ray_payload_amd", dest_comp=1, bit_sizes=[32], indices=
intrinsic("store_incoming_ray_payload_amd", src_comp=[1], indices=[BASE])
intrinsic("load_outgoing_ray_payload_amd", dest_comp=1, bit_sizes=[32], indices=[BASE])
intrinsic("store_outgoing_ray_payload_amd", src_comp=[1], indices=[BASE])
# Pointer to one ray payload slot, selected by BASE. lower_rt_instruction
# lowers this to a nir_load_param of in_payload_base_param + BASE. The
# traversal shader emits one per 4 bytes of payload (32-bit result), casts it
# to a shader_call_data uint deref and forwards it as a parameter of the
# indirect any-hit/intersection call.
intrinsic("load_ray_payload_ptr_amd", dest_comp=1, indices=[BASE])
# Load forced VRS rates.
intrinsic("load_force_vrs_rates_amd", dest_comp=1, bit_sizes=[32], flags=[CAN_ELIMINATE, CAN_REORDER])