intel/compiler: Update MemRay data structure to 64-bit

Rework: (Kevin)
- Fix miss_shader_index offset
- Handle hit group index

Signed-off-by: Sagar Ghuge <sagar.ghuge@intel.com>
Reviewed-by: Kevin Chuang <kaiwenjon23@gmail.com>
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33047>
Sagar Ghuge authored 2023-10-20 17:15:12 -07:00, committed by Marge Bot
parent 7b526de18f, commit 5cd0f4ba2f
7 changed files with 231 additions and 87 deletions
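
For orientation before the diff: the substance of the change is the pointer
encoding inside the MemRay structure. Pre-Xe3 hardware packs each 48-bit
pointer together with a 16-bit field into a single qword, which the NIR
builder code below implements with nir_pack_64_2x32_split /
nir_unpack_64_4x16_split_z / nir_extract_i16; Xe3+ gives every pointer a full
64-bit slot and moves the 16-bit fields into dwords of their own. A minimal C
sketch of the two encodings (illustration only; the helper names are invented
and not part of the commit):

   #include <stdint.h>

   /* Pre-Xe3: a 48-bit pointer and a 16-bit field share one qword. */
   static inline uint64_t pack_ptr48_field16(uint64_t ptr, uint16_t field)
   {
      return (ptr & 0xffffffffffffull) | ((uint64_t)field << 48);
   }

   /* Recover the pointer by sign-extending bit 47, mirroring what the
    * builder code does with nir_extract_i16 on the high dword. Assumes
    * arithmetic right shift on signed types, as GCC/Clang provide. */
   static inline uint64_t unpack_ptr48(uint64_t qw)
   {
      return (uint64_t)((int64_t)(qw << 16) >> 16);
   }

   /* Xe3+: the pointer simply occupies the whole qword, no packing. */
   static inline uint64_t pack_ptr64(uint64_t ptr) { return ptr; }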


@@ -200,7 +200,8 @@ brw_nir_lower_intersection_shader(nir_shader *intersection,
    nir_def *min_t = nir_load_ray_t_min(b);
 
    struct brw_nir_rt_mem_ray_defs ray_def;
-   brw_nir_rt_load_mem_ray(b, &ray_def, BRW_RT_BVH_LEVEL_WORLD);
+   brw_nir_rt_load_mem_ray(b, &ray_def, BRW_RT_BVH_LEVEL_WORLD,
+                           devinfo);
 
    struct brw_nir_rt_mem_hit_defs hit_in = {};
    brw_nir_rt_load_mem_hit(b, &hit_in, false, devinfo);


@@ -271,7 +271,8 @@ lower_ray_query_intrinsic(nir_builder *b,
          brw_nir_rt_mem_ray_addr(b, stack_addr, BRW_RT_BVH_LEVEL_WORLD);
 
       brw_nir_rt_query_mark_init(b, stack_addr);
-      brw_nir_rt_store_mem_ray_query_at_addr(b, ray_addr, &ray_defs);
+      brw_nir_rt_store_mem_ray_query_at_addr(b, ray_addr, &ray_defs,
+                                             state->devinfo);
 
       update_trace_ctrl_level(b, ctrl_level_deref,
                               NULL, NULL,
@@ -363,9 +364,11 @@ lower_ray_query_intrinsic(nir_builder *b,
      struct brw_nir_rt_mem_ray_defs object_ray_in = {};
      struct brw_nir_rt_mem_hit_defs hit_in = {};
      brw_nir_rt_load_mem_ray_from_addr(b, &world_ray_in, stack_addr,
-                                       BRW_RT_BVH_LEVEL_WORLD);
+                                       BRW_RT_BVH_LEVEL_WORLD,
+                                       state->devinfo);
      brw_nir_rt_load_mem_ray_from_addr(b, &object_ray_in, stack_addr,
-                                       BRW_RT_BVH_LEVEL_OBJECT);
+                                       BRW_RT_BVH_LEVEL_OBJECT,
+                                       state->devinfo);
      brw_nir_rt_load_mem_hit_from_addr(b, &hit_in, stack_addr, committed,
                                        state->devinfo);


@@ -89,12 +89,12 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
          brw_nir_rt_load_mem_hit(b, &hit_in,
                                  stage == MESA_SHADER_CLOSEST_HIT, devinfo);
          brw_nir_rt_load_mem_ray(b, &object_ray_in,
-                                 BRW_RT_BVH_LEVEL_OBJECT);
+                                 BRW_RT_BVH_LEVEL_OBJECT, devinfo);
          FALLTHROUGH;
 
       case MESA_SHADER_MISS:
          brw_nir_rt_load_mem_ray(b, &world_ray_in,
-                                 BRW_RT_BVH_LEVEL_WORLD);
+                                 BRW_RT_BVH_LEVEL_WORLD, devinfo);
          break;
 
       default:


@@ -147,7 +147,9 @@ store_resume_addr(nir_builder *b, nir_intrinsic_instr *call)
 static bool
 lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data)
 {
-   struct brw_bs_prog_key *key = data;
+   const struct brw_nir_lower_shader_calls_state *state = data;
+   const struct intel_device_info *devinfo = state->devinfo;
+   struct brw_bs_prog_key *key = state->key;
 
    if (instr->type != nir_instr_type_intrinsic)
       return false;
@@ -224,9 +226,6 @@ lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data
        */
       .ray_flags = nir_ior_imm(b, nir_u2u16(b, ray_flags), key->pipeline_ray_flags),
       .ray_mask = cull_mask,
-      .hit_group_sr_base_ptr = hit_sbt_addr,
-      .hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B),
-      .miss_sr_ptr = miss_sbt_addr,
       .orig = ray_orig,
       .t_near = ray_t_min,
       .dir = ray_dir,
@@ -240,7 +239,17 @@ lower_shader_trace_ray_instr(struct nir_builder *b, nir_instr *instr, void *data
        */
       .inst_leaf_ptr = nir_u2u64(b, ray_flags),
    };
-   brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD);
+
+   if (devinfo->ver >= 30) {
+      ray_defs.hit_group_index = sbt_offset;
+      ray_defs.miss_shader_index = nir_u2u16(b, miss_index);
+   } else {
+      ray_defs.hit_group_sr_base_ptr = hit_sbt_addr;
+      ray_defs.hit_group_sr_stride = nir_u2u16(b, hit_sbt_stride_B);
+      ray_defs.miss_sr_ptr = miss_sbt_addr;
+   }
+
+   brw_nir_rt_store_mem_ray(b, &ray_defs, BRW_RT_BVH_LEVEL_WORLD, devinfo);
 
    nir_trace_ray_intel(b,
                        nir_load_btd_global_arg_addr_intel(b),
@@ -272,12 +281,13 @@ lower_shader_call_instr(nir_builder *b, nir_intrinsic_instr *call,
 }
 
 bool
-brw_nir_lower_shader_calls(nir_shader *shader, struct brw_bs_prog_key *key)
+brw_nir_lower_shader_calls(nir_shader *shader,
+                           struct brw_nir_lower_shader_calls_state *state)
 {
    bool a = nir_shader_instructions_pass(shader,
                                          lower_shader_trace_ray_instr,
                                          nir_metadata_none,
-                                         key);
+                                         state);
    bool b = nir_shader_intrinsics_pass(shader, lower_shader_call_instr,
                                        nir_metadata_control_flow,
                                        NULL);
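
Because nir_shader_instructions_pass hands its callback a single void *data
pointer, the pass can no longer take just the prog key now that it also needs
devinfo to choose the MemRay encoding; both are bundled in the new
brw_nir_lower_shader_calls_state (declared in the header below). A caller-side
sketch, mirroring the anv_pipeline.c hunk at the end of this commit (the
surrounding nir/devinfo/stage variables are assumed from that context):

   struct brw_nir_lower_shader_calls_state state = {
      .devinfo = devinfo,       /* intel_device_info for the target GPU */
      .key = &stage->key.bs,    /* bindless shader prog key, as in anv */
   };
   NIR_PASS(_, nir, brw_nir_lower_shader_calls, &state);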


@@ -52,12 +52,18 @@ void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
 /* We require the stack to be 8B aligned at the start of a shader */
 #define BRW_BTD_STACK_ALIGN 8
 
+struct brw_nir_lower_shader_calls_state {
+   const struct intel_device_info *devinfo;
+   struct brw_bs_prog_key *key;
+};
+
 bool brw_nir_lower_ray_queries(nir_shader *shader,
                                const struct intel_device_info *devinfo);
 
 void brw_nir_lower_shader_returns(nir_shader *shader);
-bool brw_nir_lower_shader_calls(nir_shader *shader, struct brw_bs_prog_key *key);
+bool brw_nir_lower_shader_calls(nir_shader *shader,
+                                struct brw_nir_lower_shader_calls_state *state);
 
 void brw_nir_lower_rt_intrinsics(nir_shader *shader,
                                  const struct brw_base_prog_key *key,


@@ -759,12 +759,17 @@ struct brw_nir_rt_mem_ray_defs {
    nir_def *shader_index_multiplier;
    nir_def *inst_leaf_ptr;
    nir_def *ray_mask;
+
+   /* Valid on Xe3+ */
+   nir_def *hit_group_index;
+   nir_def *miss_shader_index;
 };
 
 static inline void
 brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
                                        nir_def *ray_addr,
-                                       const struct brw_nir_rt_mem_ray_defs *defs)
+                                       const struct brw_nir_rt_mem_ray_defs *defs,
+                                       const struct intel_device_info *devinfo)
 {
    assert_def_size(defs->orig, 3, 32);
    assert_def_size(defs->dir, 3, 32);
@@ -784,15 +789,6 @@ brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
                        defs->t_far),
             ~0 /* write mask */);
 
-   assert_def_size(defs->root_node_ptr, 1, 64);
-   assert_def_size(defs->ray_flags, 1, 16);
-   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
-      nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
-                  nir_pack_32_2x16_split(b,
-                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
-                     defs->ray_flags)),
-      0x3 /* write mask */);
-
    /* leaf_ptr is optional */
    nir_def *inst_leaf_ptr;
    if (defs->inst_leaf_ptr) {
@@ -801,20 +797,47 @@ brw_nir_rt_store_mem_ray_query_at_addr(nir_builder *b,
       inst_leaf_ptr = nir_imm_int64(b, 0);
    }
 
+   assert_def_size(defs->root_node_ptr, 1, 64);
    assert_def_size(inst_leaf_ptr, 1, 64);
-   assert_def_size(defs->ray_mask, 1, 32);
-   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
-      nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
-                  nir_pack_32_2x16_split(b,
-                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
-                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
-      ~0 /* write mask */);
+   assert_def_size(defs->ray_flags, 1, 16);
+   if (devinfo->ver >= 30) {
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
+         nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
+                     nir_unpack_64_2x32_split_y(b, defs->root_node_ptr),
+                     nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
+                     nir_unpack_64_2x32_split_y(b, inst_leaf_ptr)),
+         ~0 /* write mask */);
+
+      assert_def_size(defs->ray_mask, 1, 32);
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 8,
+         nir_pack_32_2x16_split(b,
+            defs->ray_flags,
+            nir_unpack_32_2x16_split_x(b, defs->ray_mask)),
+         0x1 /* write mask */);
+   } else {
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
+         nir_vec2(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
+                     nir_pack_32_2x16_split(b,
+                        nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
+                        defs->ray_flags)),
+         0x3 /* write mask */);
+
+      assert_def_size(defs->ray_mask, 1, 32);
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 56), 8,
+         nir_vec2(b, nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
+                     nir_pack_32_2x16_split(b,
+                        nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
+                        nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
+         ~0 /* write mask */);
+   }
 }
 
 static inline void
 brw_nir_rt_store_mem_ray(nir_builder *b,
                          const struct brw_nir_rt_mem_ray_defs *defs,
-                         enum brw_rt_bvh_level bvh_level)
+                         enum brw_rt_bvh_level bvh_level,
+                         const struct intel_device_info *devinfo)
 {
    nir_def *ray_addr =
       brw_nir_rt_mem_ray_addr(b, brw_nir_rt_stack_addr(b), bvh_level);
@@ -837,21 +860,6 @@ brw_nir_rt_store_mem_ray(nir_builder *b,
                        defs->t_far),
             ~0 /* write mask */);
 
-   assert_def_size(defs->root_node_ptr, 1, 64);
-   assert_def_size(defs->ray_flags, 1, 16);
-   assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
-   assert_def_size(defs->hit_group_sr_stride, 1, 16);
-   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
-      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
-                  nir_pack_32_2x16_split(b,
-                     nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
-                     defs->ray_flags),
-                  nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
-                  nir_pack_32_2x16_split(b,
-                     nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
-                     defs->hit_group_sr_stride)),
-      ~0 /* write mask */);
-
    /* leaf_ptr is optional */
    nir_def *inst_leaf_ptr;
    if (defs->inst_leaf_ptr) {
@@ -860,33 +868,122 @@ brw_nir_rt_store_mem_ray(nir_builder *b,
       inst_leaf_ptr = nir_imm_int64(b, 0);
    }
 
-   assert_def_size(defs->miss_sr_ptr, 1, 64);
-   assert_def_size(defs->shader_index_multiplier, 1, 32);
+   assert_def_size(defs->root_node_ptr, 1, 64);
    assert_def_size(inst_leaf_ptr, 1, 64);
-   assert_def_size(defs->ray_mask, 1, 32);
-   brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
-      nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
-                  nir_pack_32_2x16_split(b,
-                     nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
-                     nir_unpack_32_2x16_split_x(b,
-                        nir_ishl(b, defs->shader_index_multiplier,
-                                    nir_imm_int(b, 8)))),
-                  nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
-                  nir_pack_32_2x16_split(b,
-                     nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
-                     nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
-      ~0 /* write mask */);
+   assert_def_size(defs->ray_flags, 1, 16);
+   if (devinfo->ver >= 30) {
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
+         nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
+                     nir_unpack_64_2x32_split_y(b, defs->root_node_ptr),
+                     nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
+                     nir_unpack_64_2x32_split_y(b, inst_leaf_ptr)),
+         ~0 /* write mask */);
+
+      assert_def_size(defs->ray_mask, 1, 32);
+      assert_def_size(defs->miss_shader_index, 1, 16);
+      assert_def_size(defs->shader_index_multiplier, 1, 32);
+      nir_def *packed0 = nir_pack_32_2x16_split(b,
+         defs->ray_flags,
+         nir_unpack_32_2x16_split_x(b, defs->ray_mask));
+      /* internalRayFlags are not used at the moment */
+      nir_def *packed1 = nir_pack_32_2x16_split(b,
+         defs->miss_shader_index,
+         nir_unpack_32_2x16_split_x(b, defs->shader_index_multiplier));
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
+         nir_vec3(b, packed0, defs->hit_group_index, packed1),
+         0x7 /* write mask */);
+   } else {
+      assert_def_size(defs->hit_group_sr_base_ptr, 1, 64);
+      assert_def_size(defs->hit_group_sr_stride, 1, 16);
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 32), 16,
+         nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->root_node_ptr),
+                     nir_pack_32_2x16_split(b,
+                        nir_unpack_64_4x16_split_z(b, defs->root_node_ptr),
+                        defs->ray_flags),
+                     nir_unpack_64_2x32_split_x(b, defs->hit_group_sr_base_ptr),
+                     nir_pack_32_2x16_split(b,
+                        nir_unpack_64_4x16_split_z(b, defs->hit_group_sr_base_ptr),
+                        defs->hit_group_sr_stride)),
+         ~0 /* write mask */);
+
+      assert_def_size(defs->miss_sr_ptr, 1, 64);
+      assert_def_size(defs->shader_index_multiplier, 1, 32);
+      assert_def_size(defs->ray_mask, 1, 32);
+      brw_nir_rt_store(b, nir_iadd_imm(b, ray_addr, 48), 16,
+         nir_vec4(b, nir_unpack_64_2x32_split_x(b, defs->miss_sr_ptr),
+                     nir_pack_32_2x16_split(b,
+                        nir_unpack_64_4x16_split_z(b, defs->miss_sr_ptr),
+                        nir_unpack_32_2x16_split_x(b,
+                           nir_ishl(b, defs->shader_index_multiplier,
+                                       nir_imm_int(b, 8)))),
+                     nir_unpack_64_2x32_split_x(b, inst_leaf_ptr),
+                     nir_pack_32_2x16_split(b,
+                        nir_unpack_64_4x16_split_z(b, inst_leaf_ptr),
+                        nir_unpack_32_2x16_split_x(b, defs->ray_mask))),
+         ~0 /* write mask */);
+   }
 }
 
+/* MemRay memory data structure (Bspec 56933).
+ *
+ * On Xe3+, the 64b version:
+ *
+ *    org_x                  32    the origin of the ray
+ *    org_y                  32    the origin of the ray
+ *    org_z                  32    the origin of the ray
+ *    dir_x                  32    the direction of the ray
+ *    dir_y                  32    the direction of the ray
+ *    dir_z                  32    the direction of the ray
+ *    tnear                  32    the start of the ray
+ *    tfar                   32    the end of the ray
+ *    rootNodePtr            64    root node to start traversal at
+ *                                 (64-byte alignment)
+ *    instLeafPtr            64    the pointer to instance leaf in case we
+ *                                 traverse an instance (64-byte alignment)
+ *    rayFlags               16    ray flags (see RayFlag structure)
+ *    rayMask                 8    ray mask used for ray masking
+ *    comparisonValue         7    to be compared with Instance.ComparisonMask
+ *    pad                     1
+ *    hitGroupIndex          32    hit group shader index
+ *    missShaderIndex        16    index of miss shader to invoke on a miss
+ *    shaderIndexMultiplier   4    shader index multiplier
+ *    pad2                    4
+ *    internalRayFlags        8    internal ray flags
+ *
+ * On older platforms (< Xe3), the 48b version:
+ *
+ *    org_x                  32    the origin of the ray
+ *    org_y                  32    the origin of the ray
+ *    org_z                  32    the origin of the ray
+ *    dir_x                  32    the direction of the ray
+ *    dir_y                  32    the direction of the ray
+ *    dir_z                  32    the direction of the ray
+ *    tnear                  32    the start of the ray
+ *    tfar                   32    the end of the ray
+ *    rootNodePtr            48    root node to start traversal at
+ *    rayFlags               16    ray flags (see RayFlag structure)
+ *    hitGroupSRBasePtr      48    base of hit group shader record array
+ *                                 (8-byte alignment)
+ *    hitGroupSRStride       16    stride of hit group shader record array
+ *                                 (8-byte alignment)
+ *    missSRPtr              48    pointer to miss shader record to invoke on
+ *                                 a miss (8-byte alignment)
+ *    pad                     8
+ *    shaderIndexMultiplier   8    shader index multiplier
+ *    instLeafPtr            48    the pointer to instance leaf in case we
+ *                                 traverse an instance (64-byte alignment)
+ *    rayMask                 8    ray mask used for ray masking
+ */
 static inline void
 brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
                                   struct brw_nir_rt_mem_ray_defs *defs,
                                   nir_def *ray_base_addr,
-                                  enum brw_rt_bvh_level bvh_level)
+                                  enum brw_rt_bvh_level bvh_level,
+                                  const struct intel_device_info *devinfo)
 {
-   nir_def *ray_addr = brw_nir_rt_mem_ray_addr(b,
-                                               ray_base_addr,
-                                               bvh_level);
+   nir_def *ray_addr = brw_nir_rt_mem_ray_addr(b, ray_base_addr, bvh_level);
 
    nir_def *data[4] = {
       brw_nir_rt_load(b, nir_iadd_imm(b, ray_addr, 0), 16, 4, 32),
@@ -901,40 +998,62 @@ brw_nir_rt_load_mem_ray_from_addr(nir_builder *b,
                          nir_channel(b, data[1], 1));
    defs->t_near = nir_channel(b, data[1], 2);
    defs->t_far = nir_channel(b, data[1], 3);
-   defs->root_node_ptr =
-      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
-                             nir_extract_i16(b, nir_channel(b, data[2], 1),
-                                             nir_imm_int(b, 0)));
-   defs->ray_flags =
-      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
-   defs->hit_group_sr_base_ptr =
-      nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
-                             nir_extract_i16(b, nir_channel(b, data[2], 3),
-                                             nir_imm_int(b, 0)));
-   defs->hit_group_sr_stride =
-      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
-   defs->miss_sr_ptr =
-      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
-                             nir_extract_i16(b, nir_channel(b, data[3], 1),
-                                             nir_imm_int(b, 0)));
-   defs->shader_index_multiplier =
-      nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
-                  nir_imm_int(b, 8));
-   defs->inst_leaf_ptr =
-      nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
-                             nir_extract_i16(b, nir_channel(b, data[3], 3),
-                                             nir_imm_int(b, 0)));
-   defs->ray_mask =
-      nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
+   if (devinfo->ver >= 30) {
+      defs->root_node_ptr =
+         nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
+                                nir_channel(b, data[2], 1));
+      defs->inst_leaf_ptr =
+         nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
+                                nir_channel(b, data[2], 3));
+      defs->ray_flags =
+         nir_unpack_32_2x16_split_x(b, nir_channel(b, data[3], 0));
+      defs->ray_mask =
+         nir_iand_imm(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 0)),
+                      0xff);
+      defs->hit_group_index = nir_channel(b, data[3], 1);
+      defs->miss_shader_index =
+         nir_unpack_32_2x16_split_x(b, nir_channel(b, data[3], 2));
+      defs->shader_index_multiplier =
+         nir_iand_imm(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 2)),
+                      0xf);
+   } else {
+      defs->root_node_ptr =
+         nir_pack_64_2x32_split(b, nir_channel(b, data[2], 0),
+                                nir_extract_i16(b, nir_channel(b, data[2], 1),
+                                                nir_imm_int(b, 0)));
+      defs->ray_flags =
+         nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 1));
+      defs->hit_group_sr_base_ptr =
+         nir_pack_64_2x32_split(b, nir_channel(b, data[2], 2),
+                                nir_extract_i16(b, nir_channel(b, data[2], 3),
+                                                nir_imm_int(b, 0)));
+      defs->hit_group_sr_stride =
+         nir_unpack_32_2x16_split_y(b, nir_channel(b, data[2], 3));
+      defs->miss_sr_ptr =
+         nir_pack_64_2x32_split(b, nir_channel(b, data[3], 0),
+                                nir_extract_i16(b, nir_channel(b, data[3], 1),
+                                                nir_imm_int(b, 0)));
+      defs->shader_index_multiplier =
+         nir_ushr(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 1)),
+                     nir_imm_int(b, 8));
+      defs->inst_leaf_ptr =
+         nir_pack_64_2x32_split(b, nir_channel(b, data[3], 2),
+                                nir_extract_i16(b, nir_channel(b, data[3], 3),
+                                                nir_imm_int(b, 0)));
+      defs->ray_mask =
+         nir_unpack_32_2x16_split_y(b, nir_channel(b, data[3], 3));
+   }
 }
 
 static inline void
 brw_nir_rt_load_mem_ray(nir_builder *b,
                         struct brw_nir_rt_mem_ray_defs *defs,
-                        enum brw_rt_bvh_level bvh_level)
+                        enum brw_rt_bvh_level bvh_level,
+                        const struct intel_device_info *devinfo)
 {
    brw_nir_rt_load_mem_ray_from_addr(b, defs, brw_nir_rt_stack_addr(b),
-                                     bvh_level);
+                                     bvh_level, devinfo);
 }
 
 struct brw_nir_rt_bvh_instance_leaf_defs {

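To make the Xe3+ tail packing concrete: the vec3 that brw_nir_rt_store_mem_ray
writes at byte offset 48 with write mask 0x7 covers bytes 48..59 of the layout
described in the Bspec comment above. An illustrative C bitfield view (not
part of the commit; assumes LSB-first bitfield allocation, as GCC/Clang use on
little-endian targets):

   #include <stdint.h>

   struct xe3_mem_ray_tail {
      /* dword at offset 48 -- "packed0" in the builder code */
      uint32_t ray_flags               : 16; /* rayFlags */
      uint32_t ray_mask                :  8; /* rayMask */
      uint32_t comparison_value        :  7; /* vs. Instance.ComparisonMask */
      uint32_t pad0                    :  1;
      /* dword at offset 52 */
      uint32_t hit_group_index;              /* hitGroupIndex */
      /* dword at offset 56 -- "packed1" in the builder code */
      uint32_t miss_shader_index       : 16; /* missShaderIndex */
      uint32_t shader_index_multiplier :  4; /* shaderIndexMultiplier */
      uint32_t pad2                    :  4;
      uint32_t internal_ray_flags      :  8; /* currently unused by the driver */
   };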

@@ -3374,6 +3374,11 @@ compile_upload_rt_shader(struct anv_ray_tracing_pipeline *pipeline,
       pipeline->base.device->physical->compiler;
    const struct intel_device_info *devinfo = compiler->devinfo;
 
+   struct brw_nir_lower_shader_calls_state lowering_state = {
+      .devinfo = devinfo,
+      .key = &stage->key.bs,
+   };
+
    nir_shader **resume_shaders = NULL;
    uint32_t num_resume_shaders = 0;
    if (nir->info.stage != MESA_SHADER_COMPUTE) {

@@ -3388,12 +3393,12 @@ compile_upload_rt_shader(struct anv_ray_tracing_pipeline *pipeline,
       NIR_PASS(_, nir, nir_lower_shader_calls, &opts,
               &resume_shaders, &num_resume_shaders, mem_ctx);
-      NIR_PASS(_, nir, brw_nir_lower_shader_calls, &stage->key.bs);
+      NIR_PASS(_, nir, brw_nir_lower_shader_calls, &lowering_state);
       NIR_PASS_V(nir, brw_nir_lower_rt_intrinsics, &stage->key.base, devinfo);
    }
 
    for (unsigned i = 0; i < num_resume_shaders; i++) {
-      NIR_PASS(_, resume_shaders[i], brw_nir_lower_shader_calls, &stage->key.bs);
+      NIR_PASS(_, resume_shaders[i], brw_nir_lower_shader_calls, &lowering_state);
       NIR_PASS_V(resume_shaders[i], brw_nir_lower_rt_intrinsics, &stage->key.base, devinfo);
    }