radv: Add ray traversal loop.

Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/12592>
Bas Nieuwenhuizen 2021-08-27 02:17:19 +02:00 committed by Marge Bot
parent c3d82a9622
commit 85580faa4b

@@ -1078,6 +1078,654 @@ nir_lower_intersection_shader(nir_shader *intersection, nir_shader *any_hit)
ralloc_free(dead_ctx);
}
/* Variables used only internally by the ray traversal. This is data that
 * describes the current state of the traversal, as opposed to what we would
 * hand to a shader, e.g. the instance we are currently visiting vs. the
 * instance of the closest hit. */
struct rt_traversal_vars {
nir_variable *origin;
nir_variable *dir;
nir_variable *inv_dir;
nir_variable *sbt_offset_and_flags;
nir_variable *instance_id;
nir_variable *custom_instance_and_mask;
nir_variable *instance_addr;
nir_variable *should_return;
nir_variable *bvh_base;
nir_variable *stack;
nir_variable *top_stack;
};
static struct rt_traversal_vars
init_traversal_vars(nir_builder *b)
{
const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
struct rt_traversal_vars ret;
ret.origin = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_origin");
ret.dir = nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_dir");
ret.inv_dir =
nir_variable_create(b->shader, nir_var_shader_temp, vec3_type, "traversal_inv_dir");
ret.sbt_offset_and_flags = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
"traversal_sbt_offset_and_flags");
ret.instance_id = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
"traversal_instance_id");
ret.custom_instance_and_mask = nir_variable_create(
b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_custom_instance_and_mask");
ret.instance_addr =
nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
ret.should_return = nir_variable_create(b->shader, nir_var_shader_temp, glsl_bool_type(),
"traversal_should_return");
ret.bvh_base = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint64_t_type(),
"traversal_bvh_base");
ret.stack =
nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(), "traversal_stack_ptr");
ret.top_stack = nir_variable_create(b->shader, nir_var_shader_temp, glsl_uint_type(),
"traversal_top_stack_ptr");
return ret;
}
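/* Node id <-> address helpers. A node id is the node's byte address divided
 * by 8, which leaves the low 3 bits free for the node type since nodes are at
 * least 64-byte aligned. addr_to_node also masks the id to the 2^42-byte
 * range covered by the BVH descriptor built in insert_traversal(), and
 * node_to_addr drops the type bits and restores the canonical high address
 * bits assumed on GFX9+. */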
static nir_ssa_def *
nir_build_addr_to_node(nir_builder *b, nir_ssa_def *addr)
{
const uint64_t bvh_size = 1ull << 42;
nir_ssa_def *node = nir_ushr(b, addr, nir_imm_int(b, 3));
return nir_iand(b, node, nir_imm_int64(b, (bvh_size - 1) << 3));
}
static nir_ssa_def *
nir_build_node_to_addr(nir_builder *b, nir_ssa_def *node)
{
nir_ssa_def *addr = nir_iand(b, node, nir_imm_int64(b, ~7ull));
addr = nir_ishl(b, addr, nir_imm_int(b, 3));
/* Assumes everything is in the top half of address space, which is true in
* GFX9+ for now. */
return nir_ior(b, addr, nir_imm_int64(b, 0xffffull << 48));
}
/* When a hit is opaque, the any_hit shader is skipped for this hit and the
 * hit is assumed to be an actual hit. */
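/* Precedence, lowest to highest: the per-geometry opaque flag, then the
 * instance force-opaque/force-no-opaque flags, then the ray's
 * OpaqueKHR/NoOpaqueKHR flags. The bcsel chain below applies them in that
 * order, so later overrides win. */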
static nir_ssa_def *
hit_is_opaque(nir_builder *b, const struct rt_variables *vars,
const struct rt_traversal_vars *trav_vars, nir_ssa_def *geometry_id_and_flags)
{
nir_ssa_def *geom_force_opaque = nir_ine(
b, nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 1u << 28 /* VK_GEOMETRY_OPAQUE_BIT */)),
nir_imm_int(b, 0));
nir_ssa_def *instance_force_opaque =
nir_ine(b,
nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
nir_imm_int(b, 4 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_OPAQUE_BIT */)),
nir_imm_int(b, 0));
nir_ssa_def *instance_force_non_opaque =
nir_ine(b,
nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
nir_imm_int(b, 8 << 24 /* VK_GEOMETRY_INSTANCE_FORCE_NO_OPAQUE_BIT */)),
nir_imm_int(b, 0));
nir_ssa_def *opaque = geom_force_opaque;
opaque = nir_bcsel(b, instance_force_opaque, nir_imm_bool(b, true), opaque);
opaque = nir_bcsel(b, instance_force_non_opaque, nir_imm_bool(b, false), opaque);
nir_ssa_def *ray_force_opaque =
nir_ine(b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 1 /* RayFlagsOpaque */)),
nir_imm_int(b, 0));
nir_ssa_def *ray_force_non_opaque = nir_ine(
b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 2 /* RayFlagsNoOpaque */)),
nir_imm_int(b, 0));
opaque = nir_bcsel(b, ray_force_opaque, nir_imm_bool(b, true), opaque);
opaque = nir_bcsel(b, ray_force_non_opaque, nir_imm_bool(b, false), opaque);
return opaque;
}
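/* Emits the any-hit shader of every triangle hit group, each guarded by
 * insert_rt_case() so that only the shader matching the SBT index loaded into
 * vars->idx actually runs; an index of 0 means no shader and is skipped up
 * front. */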
static void
visit_any_hit_shaders(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
struct rt_variables *vars)
{
RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
nir_ssa_def *sbt_idx = nir_load_var(b, vars->idx);
nir_push_if(b, nir_ine(b, sbt_idx, nir_imm_int(b, 0)));
for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
uint32_t shader_id = VK_SHADER_UNUSED_KHR;
switch (group_info->type) {
case VK_RAY_TRACING_SHADER_GROUP_TYPE_TRIANGLES_HIT_GROUP_KHR:
shader_id = group_info->anyHitShader;
break;
default:
break;
}
if (shader_id == VK_SHADER_UNUSED_KHR)
continue;
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
vars->group_idx = i;
insert_rt_case(b, nir_stage, vars, sbt_idx, 0, i + 2);
}
nir_pop_if(b, NULL);
}
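/* Handles an intersection result for a triangle node. The code below assumes
 * the result layout produced with the 'return IJ for triangles' bit set in
 * the BVH descriptor: component 0 = t numerator, component 1 = divisor,
 * components 2/3 = barycentric i/j numerators, all of which get divided by
 * the divisor first. */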
static void
insert_traversal_triangle_case(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
nir_ssa_def *result, const struct rt_variables *vars,
const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
{
nir_ssa_def *dist = nir_vector_extract(b, result, nir_imm_int(b, 0));
nir_ssa_def *div = nir_vector_extract(b, result, nir_imm_int(b, 1));
dist = nir_fdiv(b, dist, div);
nir_ssa_def *frontface = nir_flt(b, nir_imm_float(b, 0), div);
nir_ssa_def *switch_ccw = nir_ine(
b,
nir_iand(
b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
nir_imm_int(b, 2 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FRONT_COUNTERCLOCKWISE_BIT */)),
nir_imm_int(b, 0));
frontface = nir_ixor(b, frontface, switch_ccw);
nir_ssa_def *not_cull = nir_ieq(
b, nir_iand(b, nir_load_var(b, vars->flags), nir_imm_int(b, 256 /* RayFlagsSkipTriangles */)),
nir_imm_int(b, 0));
nir_ssa_def *not_facing_cull = nir_ieq(
b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_bcsel(b, frontface, nir_imm_int(b, 32 /* RayFlagsCullFrontFacingTriangles */),
nir_imm_int(b, 16 /* RayFlagsCullBackFacingTriangles */))),
nir_imm_int(b, 0));
not_cull = nir_iand(
b, not_cull,
nir_ior(
b, not_facing_cull,
nir_ine(
b,
nir_iand(
b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
nir_imm_int(b, 1 << 24 /* VK_GEOMETRY_INSTANCE_TRIANGLE_FACING_CULL_DISABLE_BIT */)),
nir_imm_int(b, 0))));
nir_push_if(b, nir_iand(b,
nir_iand(b, nir_flt(b, dist, nir_load_var(b, vars->tmax)),
nir_fge(b, dist, nir_load_var(b, vars->tmin))),
not_cull));
{
nir_ssa_def *triangle_info = nir_build_load_global(
b, 2, 32,
nir_iadd(b, nir_build_node_to_addr(b, bvh_node),
nir_imm_int64(b, offsetof(struct radv_bvh_triangle_node, triangle_id))),
.align_mul = 4, .align_offset = 0);
nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);
not_cull =
nir_ieq(b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
nir_imm_int(b, 0));
nir_push_if(b, not_cull);
{
nir_ssa_def *sbt_idx =
nir_iadd(b,
nir_iadd(b, nir_load_var(b, vars->sbt_offset),
nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
nir_imm_int(b, 0xffffff))),
nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
nir_ssa_def *divs[2] = {div, div};
nir_ssa_def *ij = nir_fdiv(b, nir_channels(b, result, 0xc), nir_vec(b, divs, 2));
nir_ssa_def *hit_kind =
nir_bcsel(b, frontface, nir_imm_int(b, 0xFE), nir_imm_int(b, 0xFF));
nir_store_scratch(
b, ij,
nir_iadd(b, nir_load_var(b, vars->stack_ptr), nir_imm_int(b, RADV_HIT_ATTRIB_OFFSET)),
.align_mul = 16, .write_mask = 3);
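/* ahit_status: 0 = hit accepted, 1 = hit ignored (or nothing reported),
 * 2 = hit accepted and the search terminated by the any-hit shader. */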
nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
nir_push_if(b, nir_ine(b, is_opaque, nir_imm_bool(b, true)));
{
struct rt_variables inner_vars = create_inner_vars(b, vars);
nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
nir_store_var(b, inner_vars.tmax, dist, 0x1);
nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr),
0x1);
nir_store_var(b, inner_vars.hit_kind, hit_kind, 0x1);
nir_store_var(b, inner_vars.custom_instance_and_mask,
nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
visit_any_hit_shaders(device, pCreateInfo, b, &inner_vars);
nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
{
nir_jump(b, nir_jump_continue);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
nir_store_var(b, vars->primitive_id, primitive_id, 1);
nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
nir_store_var(b, vars->tmax, dist, 0x1);
nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
nir_store_var(b, vars->hit_kind, hit_kind, 0x1);
nir_store_var(b, vars->custom_instance_and_mask,
nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);
nir_store_var(b, trav_vars->should_return,
nir_ior(b,
nir_ine(b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_imm_int(b, 8 /* SkipClosestHitShader */)),
nir_imm_int(b, 0)),
nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
1);
nir_ssa_def *terminate_on_first_hit =
nir_ine(b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
nir_imm_int(b, 0));
nir_ssa_def *ray_terminated =
nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
{
nir_jump(b, nir_jump_break);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
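/* Handles an AABB (procedural) leaf: runs the matching hit group's
 * intersection shader, with its any-hit shader inlined via
 * nir_lower_intersection_shader(), and commits whatever hit it reports. If no
 * intersection shader is bound, the node's bounds themselves are intersected
 * as a fallback. */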
static void
insert_traversal_aabb_case(struct radv_device *device,
const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, nir_builder *b,
nir_ssa_def *result, const struct rt_variables *vars,
const struct rt_traversal_vars *trav_vars, nir_ssa_def *bvh_node)
{
RADV_FROM_HANDLE(radv_pipeline_layout, layout, pCreateInfo->layout);
nir_ssa_def *node_addr = nir_build_node_to_addr(b, bvh_node);
nir_ssa_def *triangle_info = nir_build_load_global(
b, 2, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 24)), .align_mul = 4, .align_offset = 0);
nir_ssa_def *primitive_id = nir_channel(b, triangle_info, 0);
nir_ssa_def *geometry_id_and_flags = nir_channel(b, triangle_info, 1);
nir_ssa_def *geometry_id = nir_iand(b, geometry_id_and_flags, nir_imm_int(b, 0xfffffff));
nir_ssa_def *is_opaque = hit_is_opaque(b, vars, trav_vars, geometry_id_and_flags);
nir_ssa_def *not_cull =
nir_ieq(b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_bcsel(b, is_opaque, nir_imm_int(b, 0x40), nir_imm_int(b, 0x80))),
nir_imm_int(b, 0));
nir_push_if(b, not_cull);
{
nir_ssa_def *sbt_idx =
nir_iadd(b,
nir_iadd(b, nir_load_var(b, vars->sbt_offset),
nir_iand(b, nir_load_var(b, trav_vars->sbt_offset_and_flags),
nir_imm_int(b, 0xffffff))),
nir_imul(b, nir_load_var(b, vars->sbt_stride), geometry_id));
struct rt_variables inner_vars = create_inner_vars(b, vars);
/* For AABBs the intersection shader writes the hit kind, and only does so if the hit is the
 * next closest hit candidate. */
inner_vars.hit_kind = vars->hit_kind;
nir_store_var(b, inner_vars.primitive_id, primitive_id, 1);
nir_store_var(b, inner_vars.geometry_id_and_flags, geometry_id_and_flags, 1);
nir_store_var(b, inner_vars.tmax, nir_load_var(b, vars->tmax), 0x1);
nir_store_var(b, inner_vars.instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
nir_store_var(b, inner_vars.instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
nir_store_var(b, inner_vars.custom_instance_and_mask,
nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
nir_store_var(b, inner_vars.opaque, is_opaque, 1);
load_sbt_entry(b, &inner_vars, sbt_idx, SBT_HIT, 4);
nir_store_var(b, vars->ahit_status, nir_imm_int(b, 1), 1);
nir_push_if(b, nir_ine(b, nir_load_var(b, inner_vars.idx), nir_imm_int(b, 0)));
for (unsigned i = 0; i < pCreateInfo->groupCount; ++i) {
const VkRayTracingShaderGroupCreateInfoKHR *group_info = &pCreateInfo->pGroups[i];
uint32_t shader_id = VK_SHADER_UNUSED_KHR;
uint32_t any_hit_shader_id = VK_SHADER_UNUSED_KHR;
switch (group_info->type) {
case VK_RAY_TRACING_SHADER_GROUP_TYPE_PROCEDURAL_HIT_GROUP_KHR:
shader_id = group_info->intersectionShader;
any_hit_shader_id = group_info->anyHitShader;
break;
default:
break;
}
if (shader_id == VK_SHADER_UNUSED_KHR)
continue;
const VkPipelineShaderStageCreateInfo *stage = &pCreateInfo->pStages[shader_id];
nir_shader *nir_stage = parse_rt_stage(device, layout, stage);
nir_shader *any_hit_stage = NULL;
if (any_hit_shader_id != VK_SHADER_UNUSED_KHR) {
stage = &pCreateInfo->pStages[any_hit_shader_id];
any_hit_stage = parse_rt_stage(device, layout, stage);
nir_lower_intersection_shader(nir_stage, any_hit_stage);
ralloc_free(any_hit_stage);
}
inner_vars.group_idx = i;
insert_rt_case(b, nir_stage, &inner_vars, nir_load_var(b, inner_vars.idx), 0, i + 2);
}
nir_push_else(b, NULL);
{
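/* No intersection shader bound for this SBT entry: fall back to a plain
 * ray/box slab test against the AABB stored in the node and report the entry
 * distance as the hit. */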
nir_ssa_def *vec3_zero = nir_channels(b, nir_imm_vec4(b, 0, 0, 0, 0), 0x7);
nir_ssa_def *vec3_inf =
nir_channels(b, nir_imm_vec4(b, INFINITY, INFINITY, INFINITY, 0), 0x7);
nir_ssa_def *bvh_lo =
nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 0)),
.align_mul = 4, .align_offset = 0);
nir_ssa_def *bvh_hi =
nir_build_load_global(b, 3, 32, nir_iadd(b, node_addr, nir_imm_int64(b, 12)),
.align_mul = 4, .align_offset = 0);
bvh_lo = nir_fsub(b, bvh_lo, nir_load_var(b, trav_vars->origin));
bvh_hi = nir_fsub(b, bvh_hi, nir_load_var(b, trav_vars->origin));
nir_ssa_def *t_vec = nir_fmin(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
nir_ssa_def *t2_vec = nir_fmax(b, nir_fmul(b, bvh_lo, nir_load_var(b, trav_vars->inv_dir)),
nir_fmul(b, bvh_hi, nir_load_var(b, trav_vars->inv_dir)));
/* If we run parallel to one of the edges, the range should be [0, inf) and not [0, 0]. */
t2_vec =
nir_bcsel(b, nir_feq(b, nir_load_var(b, trav_vars->dir), vec3_zero), vec3_inf, t2_vec);
nir_ssa_def *t_min = nir_fmax(b, nir_channel(b, t_vec, 0), nir_channel(b, t_vec, 1));
t_min = nir_fmax(b, t_min, nir_channel(b, t_vec, 2));
nir_ssa_def *t_max = nir_fmin(b, nir_channel(b, t2_vec, 0), nir_channel(b, t2_vec, 1));
t_max = nir_fmin(b, t_max, nir_channel(b, t2_vec, 2));
nir_push_if(b, nir_iand(b, nir_flt(b, t_min, nir_load_var(b, vars->tmax)),
nir_fge(b, t_max, nir_load_var(b, vars->tmin))));
{
nir_store_var(b, vars->ahit_status, nir_imm_int(b, 0), 1);
nir_store_var(b, vars->tmax, nir_fmax(b, t_min, nir_load_var(b, vars->tmin)), 1);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
nir_push_if(b, nir_ine(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 1)));
{
nir_store_var(b, vars->primitive_id, primitive_id, 1);
nir_store_var(b, vars->geometry_id_and_flags, geometry_id_and_flags, 1);
nir_store_var(b, vars->tmax, nir_load_var(b, inner_vars.tmax), 0x1);
nir_store_var(b, vars->instance_id, nir_load_var(b, trav_vars->instance_id), 0x1);
nir_store_var(b, vars->instance_addr, nir_load_var(b, trav_vars->instance_addr), 0x1);
nir_store_var(b, vars->custom_instance_and_mask,
nir_load_var(b, trav_vars->custom_instance_and_mask), 0x1);
load_sbt_entry(b, vars, sbt_idx, SBT_HIT, 0);
nir_store_var(b, trav_vars->should_return,
nir_ior(b,
nir_ine(b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_imm_int(b, 8 /* SkipClosestHitShader */)),
nir_imm_int(b, 0)),
nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0))),
1);
nir_ssa_def *terminate_on_first_hit =
nir_ine(b,
nir_iand(b, nir_load_var(b, vars->flags),
nir_imm_int(b, 4 /* TerminateOnFirstHitKHR */)),
nir_imm_int(b, 0));
nir_ssa_def *ray_terminated =
nir_ieq(b, nir_load_var(b, vars->ahit_status), nir_imm_int(b, 2));
nir_push_if(b, nir_ior(b, terminate_on_first_hit, ray_terminated));
{
nir_jump(b, nir_jump_break);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
static void
insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
nir_builder *b, const struct rt_variables *vars)
{
unsigned stack_entry_size = 4;
unsigned lanes = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
b->shader->info.workgroup_size[2];
unsigned stack_entry_stride = stack_entry_size * lanes;
nir_ssa_def *stack_entry_stride_def = nir_imm_int(b, stack_entry_stride);
nir_ssa_def *stack_base =
nir_iadd(b, nir_imm_int(b, b->shader->info.shared_size),
nir_imul(b, nir_load_subgroup_invocation(b), nir_imm_int(b, stack_entry_size)));
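/* The traversal stack lives in shared memory, interleaved per lane: entries
 * of different lanes at the same depth are stack_entry_size bytes apart and
 * successive entries of one lane are stack_entry_stride bytes apart, which
 * keeps a wave's stack accesses contiguous. */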
/*
 * A top-level AS can contain 2^24 children and a bottom-level AS can contain 2^24 triangles. At
 * a branching factor of 4 that is up to 12 levels of box nodes per AS, i.e. 24 levels of box
 * nodes in total + 1 triangle node + 1 instance node. Furthermore, when processing a box node,
 * worst case we actually push all 4 children and remove one, so the DFS stack depth is
 * box nodes * 3 + 2.
 */
b->shader->info.shared_size += stack_entry_stride * 76;
assert(b->shader->info.shared_size <= 32768);
nir_ssa_def *accel_struct = nir_load_var(b, vars->accel_struct);
struct rt_traversal_vars trav_vars = init_traversal_vars(b);
/* Initialize the follow-up shader idx to 0, to be replaced by the miss shader
* if we actually miss. */
nir_store_var(b, vars->idx, nir_imm_int(b, 0), 1);
nir_store_var(b, trav_vars.should_return, nir_imm_bool(b, false), 1);
nir_push_if(b, nir_ine(b, accel_struct, nir_imm_int64(b, 0)));
{
nir_store_var(b, trav_vars.bvh_base, nir_build_addr_to_node(b, accel_struct), 1);
nir_ssa_def *bvh_root =
nir_build_load_global(b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE,
.align_mul = 64, .align_offset = 0);
/* We create a BVH descriptor that covers the entire memory range. That way we can always
* use the same descriptor, which avoids divergence when different rays hit different
* instances at the cost of having to use 64-bit node ids. */
const uint64_t bvh_size = 1ull << 42;
nir_ssa_def *desc = nir_imm_ivec4(
b, 0, 1u << 31 /* Enable box sorting */, (bvh_size - 1) & 0xFFFFFFFFu,
((bvh_size - 1) >> 32) | (1u << 24 /* Return IJ for triangles */) | (1u << 31));
nir_ssa_def *vec3ones = nir_channels(b, nir_imm_vec4(b, 1.0, 1.0, 1.0, 1.0), 0x7);
nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
nir_store_var(b, trav_vars.stack, nir_iadd(b, stack_base, stack_entry_stride_def), 1);
nir_store_shared(b, bvh_root, stack_base, .base = 0, .write_mask = 0x1,
.align_mul = stack_entry_size, .align_offset = 0);
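/* top_stack records the stack level at which we entered a bottom-level AS;
 * once the stack pointer pops back to that level, the check at the top of the
 * loop restores the top-level BVH base and the untransformed ray. A value of
 * 0 means we are in the top-level AS. */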
nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
nir_push_loop(b);
nir_push_if(b, nir_ieq(b, nir_load_var(b, trav_vars.stack), stack_base));
nir_jump(b, nir_jump_break);
nir_pop_if(b, NULL);
nir_push_if(
b, nir_uge(b, nir_load_var(b, trav_vars.top_stack), nir_load_var(b, trav_vars.stack)));
nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
nir_store_var(b, trav_vars.bvh_base,
nir_build_addr_to_node(b, nir_load_var(b, vars->accel_struct)), 1);
nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
nir_pop_if(b, NULL);
nir_store_var(b, trav_vars.stack,
nir_isub(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
nir_ssa_def *bvh_node = nir_load_shared(b, 1, 32, nir_load_var(b, trav_vars.stack), .base = 0,
.align_mul = stack_entry_size, .align_offset = 0);
nir_ssa_def *bvh_node_type = nir_iand(b, bvh_node, nir_imm_int(b, 7));
bvh_node = nir_iadd(b, nir_load_var(b, trav_vars.bvh_base), nir_u2u(b, bvh_node, 64));
nir_ssa_def *result = nir_bvh64_intersect_ray_amd(
b, 32, desc, nir_unpack_64_2x32(b, bvh_node), nir_load_var(b, vars->tmax),
nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir),
nir_load_var(b, trav_vars.inv_dir));
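/* The low 3 bits of a node id encode the node type: 0-3 = triangle, 4/5 =
 * box (fp16/fp32), 6 = instance, 7 = AABB. Bit 2 thus separates triangles
 * from everything else, bit 1 separates box nodes from instance/AABB leaves,
 * and bit 0 separates instances from AABBs. */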
nir_push_if(b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 4)), nir_imm_int(b, 0)));
{
nir_push_if(b,
nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 2)), nir_imm_int(b, 0)));
{
/* custom */
nir_push_if(
b, nir_ine(b, nir_iand(b, bvh_node_type, nir_imm_int(b, 1)), nir_imm_int(b, 0)));
{
insert_traversal_aabb_case(device, pCreateInfo, b, result, vars, &trav_vars,
bvh_node);
}
nir_push_else(b, NULL);
{
/* instance */
nir_ssa_def *instance_node_addr = nir_build_node_to_addr(b, bvh_node);
nir_ssa_def *instance_data = nir_build_load_global(
b, 4, 32, instance_node_addr, .align_mul = 64, .align_offset = 0);
nir_ssa_def *wto_matrix[] = {
nir_build_load_global(b, 4, 32,
nir_iadd(b, instance_node_addr, nir_imm_int64(b, 16)),
.align_mul = 64, .align_offset = 16),
nir_build_load_global(b, 4, 32,
nir_iadd(b, instance_node_addr, nir_imm_int64(b, 32)),
.align_mul = 64, .align_offset = 32),
nir_build_load_global(b, 4, 32,
nir_iadd(b, instance_node_addr, nir_imm_int64(b, 48)),
.align_mul = 64, .align_offset = 48)};
nir_ssa_def *instance_id = nir_build_load_global(
b, 1, 32, nir_iadd(b, instance_node_addr, nir_imm_int64(b, 88)), .align_mul = 4,
.align_offset = 0);
nir_ssa_def *instance_and_mask = nir_channel(b, instance_data, 2);
nir_ssa_def *instance_mask = nir_ushr(b, instance_and_mask, nir_imm_int(b, 24));
nir_push_if(b,
nir_ieq(b, nir_iand(b, instance_mask, nir_load_var(b, vars->cull_mask)),
nir_imm_int(b, 0)));
nir_jump(b, nir_jump_continue);
nir_pop_if(b, NULL);
nir_store_var(b, trav_vars.top_stack, nir_load_var(b, trav_vars.stack), 1);
nir_store_var(b, trav_vars.bvh_base,
nir_build_addr_to_node(
b, nir_pack_64_2x32(b, nir_channels(b, instance_data, 0x3))),
1);
nir_store_shared(b,
nir_iand(b, nir_channel(b, instance_data, 0), nir_imm_int(b, 63)),
nir_load_var(b, trav_vars.stack), .base = 0, .write_mask = 0x1,
.align_mul = stack_entry_size, .align_offset = 0);
nir_store_var(b, trav_vars.stack,
nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def),
1);
nir_store_var(
b, trav_vars.origin,
nir_build_vec3_mat_mult_pre(b, nir_load_var(b, vars->origin), wto_matrix), 7);
nir_store_var(
b, trav_vars.dir,
nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false),
7);
nir_store_var(b, trav_vars.inv_dir,
nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
nir_store_var(b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3),
1);
nir_store_var(b, trav_vars.instance_id, instance_id, 1);
nir_store_var(b, trav_vars.instance_addr, instance_node_addr, 1);
}
nir_pop_if(b, NULL);
}
nir_push_else(b, NULL);
{
/* box */
for (unsigned i = 0; i < 4; ++i) {
nir_ssa_def *new_node = nir_vector_extract(b, result, nir_imm_int(b, i));
nir_push_if(b, nir_ine(b, new_node, nir_imm_int(b, 0xffffffff)));
{
nir_store_shared(b, new_node, nir_load_var(b, trav_vars.stack), .base = 0,
.write_mask = 0x1, .align_mul = stack_entry_size,
.align_offset = 0);
nir_store_var(
b, trav_vars.stack,
nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
}
nir_pop_if(b, NULL);
}
}
nir_pop_if(b, NULL);
}
nir_push_else(b, NULL);
{
insert_traversal_triangle_case(device, pCreateInfo, b, result, vars, &trav_vars, bvh_node);
}
nir_pop_if(b, NULL);
nir_pop_loop(b, NULL);
}
nir_pop_if(b, NULL);
/* should_return is set if we had a hit but we won't be calling the closest hit shader and hence
* need to return immediately to the calling shader. */
nir_push_if(b, nir_load_var(b, trav_vars.should_return));
{
insert_rt_return(b, vars);
}
nir_push_else(b, NULL);
{
/* Only load the miss shader if we actually miss, which we determine by not having set
 * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none
 * of the rays miss. */
nir_push_if(b, nir_ieq(b, nir_load_var(b, vars->idx), nir_imm_int(b, 0)));
{
load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, 0);
}
nir_pop_if(b, NULL);
}
nir_pop_if(b, NULL);
}
static nir_shader *
create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
struct radv_pipeline_shader_stack_size *stack_sizes)