anv: Update RT dispatch globals to use 64-bit data structure

Rework (Kevin)
- Fix Hit/Miss/Resume shader group table value

Signed-off-by: Sagar Ghuge <sagar.ghuge@intel.com>
Reviewed-by: Kevin Chuang <kaiwenjon23@gmail.com>
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33047>
This commit is contained in:
Sagar Ghuge 2023-10-17 21:54:39 -07:00 committed by Marge Bot
parent fcd5fe4a75
commit 6deb1950a4
6 changed files with 99 additions and 32 deletions

View file

@ -508,7 +508,8 @@ lower_ray_query_impl(nir_function_impl *impl, struct lowering_state *state)
state->rq_globals = nir_load_ray_query_global_intel(b);
brw_nir_rt_load_globals_addr(b, &state->globals, state->rq_globals);
brw_nir_rt_load_globals_addr(b, &state->globals, state->rq_globals,
state->devinfo);
nir_foreach_block_safe(block, impl) {
nir_foreach_instr_safe(instr, block) {

View file

@ -73,7 +73,7 @@ lower_rt_intrinsics_impl(nir_function_impl *impl,
nir_builder *b = &build;
struct brw_nir_rt_globals_defs globals;
brw_nir_rt_load_globals(b, &globals);
brw_nir_rt_load_globals(b, &globals, devinfo);
nir_def *hotzone_addr = brw_nir_rt_sw_hotzone_addr(b, devinfo);
nir_def *hotzone = nir_load_global(b, hotzone_addr, 16, 4, 32);

View file

@ -305,7 +305,8 @@ struct brw_nir_rt_globals_defs {
static inline void
brw_nir_rt_load_globals_addr(nir_builder *b,
struct brw_nir_rt_globals_defs *defs,
nir_def *addr)
nir_def *addr,
const struct intel_device_info *devinfo)
{
nir_def *data;
data = brw_nir_rt_load_const(b, 16, addr);
@ -316,38 +317,78 @@ brw_nir_rt_load_globals_addr(nir_builder *b,
defs->hw_stack_size = nir_channel(b, data, 4);
defs->num_dss_rt_stacks = nir_iand_imm(b, nir_channel(b, data, 5), 0xffff);
defs->hit_sbt_addr =
nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
nir_extract_i16(b, nir_channel(b, data, 9),
nir_imm_int(b, 0)));
defs->hit_sbt_stride =
nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
defs->miss_sbt_addr =
nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
nir_extract_i16(b, nir_channel(b, data, 11),
nir_imm_int(b, 0)));
defs->miss_sbt_stride =
nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
if (devinfo->ver >= 30) {
/* maxBVHLevels is not used yet. */
defs->hit_sbt_stride =
nir_iand_imm(b, nir_ishr_imm(b, nir_channel(b, data, 6), 0x3), 0x1fff);
defs->miss_sbt_stride =
nir_iand_imm(b, nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 6)),
0x1fff);
/* Per-context control flags are not used yet. */
/* Bspec 56933 (r58935):
*
* hitGroupBasePtr: [63:4] Canonical address with 58b address-space,16B
* aligned GPUVA : base pointer of hit group shader
* record array (16-bytes alignment)
*/
defs->hit_sbt_addr = nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 8));
/* Bspec 56933 (r58935):
*
* missShaderBasePtr: [63:3] Canonical address with 58b address-space,8B
* aligned GPUVA: base pointer of miss shader record
* array (8-bytes alignment)
*/
defs->miss_sbt_addr = nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 10));
} else {
defs->hit_sbt_addr =
nir_pack_64_2x32_split(b, nir_channel(b, data, 8),
nir_extract_i16(b, nir_channel(b, data, 9),
nir_imm_int(b, 0)));
defs->hit_sbt_stride =
nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 9));
defs->miss_sbt_addr =
nir_pack_64_2x32_split(b, nir_channel(b, data, 10),
nir_extract_i16(b, nir_channel(b, data, 11),
nir_imm_int(b, 0)));
defs->miss_sbt_stride =
nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 11));
}
defs->sw_stack_size = nir_channel(b, data, 12);
defs->launch_size = nir_channels(b, data, 0x7u << 13);
data = brw_nir_rt_load_const(b, 8, nir_iadd_imm(b, addr, 64));
defs->call_sbt_addr =
nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
nir_extract_i16(b, nir_channel(b, data, 1),
nir_imm_int(b, 0)));
defs->call_sbt_stride =
nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));
defs->resume_sbt_addr =
nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
if (devinfo->ver >= 30) {
defs->call_sbt_addr = nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
nir_channel(b, data, 1));
defs->call_sbt_stride =
nir_iand_imm(b, nir_unpack_32_2x16_split_x(b, nir_channel(b, data, 2)),
0x1fff);
defs->resume_sbt_addr =
nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 3));
} else {
defs->call_sbt_addr =
nir_pack_64_2x32_split(b, nir_channel(b, data, 0),
nir_extract_i16(b, nir_channel(b, data, 1),
nir_imm_int(b, 0)));
defs->call_sbt_stride =
nir_unpack_32_2x16_split_y(b, nir_channel(b, data, 1));
defs->resume_sbt_addr =
nir_pack_64_2x32(b, nir_channels(b, data, 0x3 << 2));
}
}
static inline void
brw_nir_rt_load_globals(nir_builder *b,
struct brw_nir_rt_globals_defs *defs)
struct brw_nir_rt_globals_defs *defs,
const struct intel_device_info *devinfo)
{
brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b));
brw_nir_rt_load_globals_addr(b, defs, nir_load_btd_global_arg_addr_intel(b),
devinfo);
}
static inline nir_def *

View file

@ -17,22 +17,25 @@
<field name="Kernel Start Pointer" start="6" end="31" type="offset" />
<field name="Registers Per Thread" start="59" end="62" type="uint" />
</struct>
<struct name="RT_DISPATCH_GLOBALS" length="20">
<struct name="RT_DISPATCH_GLOBALS" length="21">
<field name="Mem Base Address" start="0" end="63" type="address" />
<field name="Call Stack Handler" start="64" end="127" type="CALL_STACK_HANDLER" />
<field name="Async RT Stack Size" start="128" end="159" type="uint" />
<field name="Num DSS RT Stacks" start="160" end="175" type="uint" />
<field name="Max BVH Levels" start="192" end="194" type="uint" />
<field name="Hit Group Stride" start="195" end="207" type="uint" />
<field name="Miss Group Stride" start="208" end="220" type="uint" />
<field name="Flags" start="224" end="224" type="uint">
<value name="RT_DEPTH_TEST_LESS_EQUAL" value="1" />
</field>
<field name="Hit Group Table" start="256" end="319" type="RT_SHADER_TABLE" />
<field name="Miss Group Table" start="320" end="383" type="RT_SHADER_TABLE" />
<field name="Hit Group Table" start="256" end="319" type="address" />
<field name="Miss Group Table" start="320" end="383" type="address" />
<field name="SW Stack Size" start="384" end="415" type="uint" />
<field name="Launch Width" start="416" end="447" type="uint" />
<field name="Launch Height" start="448" end="479" type="uint" />
<field name="Launch Depth" start="480" end="511" type="uint" />
<field name="Callable Group Table" start="512" end="575" type="RT_SHADER_TABLE" />
<field name="Resume Shader Table" start="576" end="639" type="address" />
<field name="Callable Group Table" start="512" end="575" type="address" />
<field name="Callable Group Stride" start="576" end="588" type="uint" />
<field name="Resume Shader Table" start="608" end="671" type="address" />
</struct>
</genxml>

View file

@ -1497,8 +1497,15 @@ get_properties(const struct anv_physical_device *pdevice,
/* TODO */
props->shaderGroupHandleSize = 32;
props->maxRayRecursionDepth = 31;
/* MemRay::hitGroupSRStride is 16 bits */
props->maxShaderGroupStride = UINT16_MAX;
if (pdevice->info.ver >= 30) {
/* RTDispatchGlobals::missShaderStride is 13 bits wide, so the maximum
* stride is the largest 13-bit value (2^13 - 1).
*/
props->maxShaderGroupStride = (1U << 13) - 1;
} else {
/* MemRay::hitGroupSRStride is 16 bits */
props->maxShaderGroupStride = UINT16_MAX;
}
/* MemRay::hitGroupSRBasePtr requires 16B alignment */
props->shaderGroupBaseAlignment = 16;
props->shaderGroupHandleAlignment = 16;

View file

@ -1057,13 +1057,28 @@ cmd_buffer_emit_rt_dispatch_globals(struct anv_cmd_buffer *cmd_buffer,
.NumDSSRTStacks = rt->scratch.layout.stack_ids_per_dss,
.MaxBVHLevels = BRW_RT_MAX_BVH_LEVELS,
.Flags = RT_DEPTH_TEST_LESS_EQUAL,
#if GFX_VER >= 30
.HitGroupStride = params->hit_sbt->stride,
.MissGroupStride = params->miss_sbt->stride,
.HitGroupTable =
anv_address_from_u64(params->hit_sbt->deviceAddress),
.MissGroupTable =
anv_address_from_u64(params->miss_sbt->deviceAddress),
#else
.HitGroupTable = vk_sdar_to_shader_table(params->hit_sbt),
.MissGroupTable = vk_sdar_to_shader_table(params->miss_sbt),
#endif
.SWStackSize = rt->scratch.layout.sw_stack_size / 64,
.LaunchWidth = params->launch_size[0],
.LaunchHeight = params->launch_size[1],
.LaunchDepth = params->launch_size[2],
#if GFX_VER >= 30
.CallableGroupTable =
anv_address_from_u64(params->callable_sbt->deviceAddress),
.CallableGroupStride = params->callable_sbt->stride,
#else
.CallableGroupTable = vk_sdar_to_shader_table(params->callable_sbt),
#endif
};
GENX(RT_DISPATCH_GLOBALS_pack)(NULL, rtdg_state.map, &rtdg);