diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py index 63618d3d3ab..01c37750afb 100644 --- a/src/compiler/nir/nir_intrinsics.py +++ b/src/compiler/nir/nir_intrinsics.py @@ -1113,5 +1113,9 @@ system_value("ray_base_mem_addr_intel", 1, bit_sizes=[64]) system_value("ray_hw_stack_size_intel", 1) system_value("ray_sw_stack_size_intel", 1) system_value("ray_num_dss_rt_stacks_intel", 1) +system_value("ray_hit_sbt_addr_intel", 1, bit_sizes=[64]) +system_value("ray_hit_sbt_stride_intel", 1, bit_sizes=[16]) +system_value("ray_miss_sbt_addr_intel", 1, bit_sizes=[64]) +system_value("ray_miss_sbt_stride_intel", 1, bit_sizes=[16]) system_value("callable_sbt_addr_intel", 1, bit_sizes=[64]) system_value("callable_sbt_stride_intel", 1, bit_sizes=[16]) diff --git a/src/intel/compiler/brw_nir_lower_rt_intrinsics.c b/src/intel/compiler/brw_nir_lower_rt_intrinsics.c index ed51237dd7e..2a59891d842 100644 --- a/src/intel/compiler/brw_nir_lower_rt_intrinsics.c +++ b/src/intel/compiler/brw_nir_lower_rt_intrinsics.c @@ -106,6 +106,22 @@ lower_rt_intrinsics_impl(nir_function_impl *impl, sysval = globals.num_dss_rt_stacks; break; + case nir_intrinsic_load_ray_hit_sbt_addr_intel: + sysval = globals.hit_sbt_addr; + break; + + case nir_intrinsic_load_ray_hit_sbt_stride_intel: + sysval = globals.hit_sbt_stride; + break; + + case nir_intrinsic_load_ray_miss_sbt_addr_intel: + sysval = globals.miss_sbt_addr; + break; + + case nir_intrinsic_load_ray_miss_sbt_stride_intel: + sysval = globals.miss_sbt_stride; + break; + case nir_intrinsic_load_callable_sbt_addr_intel: sysval = globals.call_sbt_addr; break; diff --git a/src/intel/compiler/brw_nir_lower_shader_calls.c b/src/intel/compiler/brw_nir_lower_shader_calls.c index 8c6fd84ebb3..8bf5029c47e 100644 --- a/src/intel/compiler/brw_nir_lower_shader_calls.c +++ b/src/intel/compiler/brw_nir_lower_shader_calls.c @@ -245,6 +245,10 @@ can_remat_instr(nir_instr *instr, struct bitset *remat) case nir_intrinsic_load_ray_hw_stack_size_intel: case nir_intrinsic_load_ray_sw_stack_size_intel: case nir_intrinsic_load_ray_num_dss_rt_stacks_intel: + case nir_intrinsic_load_ray_hit_sbt_addr_intel: + case nir_intrinsic_load_ray_hit_sbt_stride_intel: + case nir_intrinsic_load_ray_miss_sbt_addr_intel: + case nir_intrinsic_load_ray_miss_sbt_stride_intel: case nir_intrinsic_load_callable_sbt_addr_intel: case nir_intrinsic_load_callable_sbt_stride_intel: /* Notably missing from the above list is btd_local_arg_addr_intel. @@ -529,8 +533,87 @@ spill_ssa_defs_and_lower_shader_calls(nir_shader *shader, uint32_t num_calls, /* Lower to the _intel intrinsic */ switch (call->intrinsic) { - case nir_intrinsic_trace_ray: - unreachable("TODO"); + case nir_intrinsic_trace_ray: { + nir_ssa_def *as_addr = call->src[0].ssa; + nir_ssa_def *ray_flags = call->src[1].ssa; + /* From the SPIR-V spec: + * + * "Only the 8 least-significant bits of Cull Mask are used by + * this instruction - other bits are ignored. + * + * Only the 4 least-significant bits of SBT Offset and SBT + * Stride are used by this instruction - other bits are + * ignored. + * + * Only the 16 least-significant bits of Miss Index are used by + * this instruction - other bits are ignored." + */ + nir_ssa_def *cull_mask = nir_iand_imm(b, call->src[2].ssa, 0xff); + nir_ssa_def *sbt_offset = nir_iand_imm(b, call->src[3].ssa, 0xf); + nir_ssa_def *sbt_stride = nir_iand_imm(b, call->src[4].ssa, 0xf); + nir_ssa_def *miss_index = nir_iand_imm(b, call->src[5].ssa, 0xffff); + nir_ssa_def *ray_orig = call->src[6].ssa; + nir_ssa_def *ray_t_min = call->src[7].ssa; + nir_ssa_def *ray_dir = call->src[8].ssa; + nir_ssa_def *ray_t_max = call->src[9].ssa; + + /* The hardware packet takes the address to the root node in the + * acceleration structure, not the acceleration structure itself. + * To find that, we have to read the root node offset from the + * acceleration structure which is the first QWord. + */ + nir_ssa_def *root_node_ptr = + nir_iadd(b, as_addr, nir_load_global(b, as_addr, 256, 1, 64)); + + /* The hardware packet requires an address to the first element of + * the hit SBT. + * + * In order to calculate this, we must multiply the "SBT Offset" + * provided to OpTraceRay by the SBT stride provided for the hit + * SBT in the call to vkCmdTraceRay() and add that to the base + * address of the hit SBT. This stride is not to be confused with + * the "SBT Stride" provided to OpTraceRay which is in units of + * this stride. It's a rather terrible overload of the word + * "stride". The hardware docs calls the SPIR-V stride value the + * "shader index multiplier" which is a much more sane name. + */ + nir_ssa_def *hit_sbt_stride_B = + nir_load_ray_hit_sbt_stride_intel(b); + nir_ssa_def *hit_sbt_offset_B = + nir_umul_32x16(b, sbt_offset, nir_u2u32(b, hit_sbt_stride_B)); + nir_ssa_def *hit_sbt_addr = + nir_iadd(b, nir_load_ray_hit_sbt_addr_intel(b), + nir_u2u64(b, hit_sbt_offset_B)); + + /* The hardware packet takes an address to the miss BSR. */ + nir_ssa_def *miss_sbt_stride_B = + nir_load_ray_miss_sbt_stride_intel(b); + nir_ssa_def *miss_sbt_offset_B = + nir_umul_32x16(b, miss_index, nir_u2u32(b, miss_sbt_stride_B)); + nir_ssa_def *miss_sbt_addr = + nir_iadd(b, nir_load_ray_miss_sbt_addr_intel(b), + nir_u2u64(b, miss_sbt_offset_B)); + + struct brw_nir_rt_mem_ray_defs ray_defs = { + .root_node_ptr = root_node_ptr, + .ray_flags = nir_u2u16(b, ray_flags), + .ray_mask = cull_mask, + .hit_group_sr_base_ptr = hit_sbt_addr, + .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B), + .miss_sr_ptr = miss_sbt_addr, + .orig = ray_orig, + .t_near = ray_t_min, + .dir = ray_dir, + .t_far = ray_t_max, + .shader_index_multiplier = sbt_stride, + }; + brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD); + nir_intrinsic_instr *ray_intel = + nir_intrinsic_instr_create(b->shader, + nir_intrinsic_trace_ray_initial_intel); + nir_builder_instr_insert(b, &ray_intel->instr); + break; + } case nir_intrinsic_report_ray_intersection: unreachable("Any-hit shaders must be inlined");