anv: refactor ray tracing dispatch

Preparing for vkCmdTraceRaysIndirect2KHR

Signed-off-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Reviewed-by: Ivan Briano <ivan.briano@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20011>
This commit is contained in:
Lionel Landwerlin 2022-11-25 22:01:10 +02:00 committed by Marge Bot
parent df38426072
commit 675c5bd4cc
2 changed files with 116 additions and 70 deletions

View file

@@ -73,6 +73,12 @@ genX_bits_included_symbols = [
'CLEAR_COLOR',
'VERTEX_BUFFER_STATE::Buffer Starting Address',
'CPS_STATE',
'RT_DISPATCH_GLOBALS::Hit Group Table',
'RT_DISPATCH_GLOBALS::Miss Group Table',
'RT_DISPATCH_GLOBALS::Callable Group Table',
'RT_DISPATCH_GLOBALS::Launch Width',
'RT_DISPATCH_GLOBALS::Launch Height',
'RT_DISPATCH_GLOBALS::Launch Depth',
]
genX_bits_h = custom_target(

View file

@@ -5617,17 +5617,71 @@ vk_sdar_to_shader_table(const VkStridedDeviceAddressRegionKHR *region)
};
}
struct trace_params {
   /* If is_sbt_indirect is set, the shader binding tables are not known on
    * the host: indirect_sbts_addr is used to build RT_DISPATCH_GLOBALS with
    * mi_builder instead of the host-side raygen/miss/hit/callable pointers
    * below.
    */
   bool is_sbt_indirect;
   const VkStridedDeviceAddressRegionKHR *raygen_sbt;
   const VkStridedDeviceAddressRegionKHR *miss_sbt;
   const VkStridedDeviceAddressRegionKHR *hit_sbt;
   const VkStridedDeviceAddressRegionKHR *callable_sbt;

   /* A pointer to a VkTraceRaysIndirectCommand2KHR structure */
   uint64_t indirect_sbts_addr;

   /* If is_launch_size_indirect is set, use launch_size_addr to program the
    * dispatch size; otherwise launch_size[] holds the host-provided
    * width/height/depth.
    */
   bool is_launch_size_indirect;
   uint32_t launch_size[3];

   /* A pointer to a uint32_t[3] */
   uint64_t launch_size_addr;
};
static struct anv_state
cmd_buffer_emit_rt_dispatch_globals(struct anv_cmd_buffer *cmd_buffer,
struct trace_params *params)
{
assert(!params->is_sbt_indirect);
assert(params->miss_sbt != NULL);
assert(params->hit_sbt != NULL);
assert(params->callable_sbt != NULL);
struct anv_cmd_ray_tracing_state *rt = &cmd_buffer->state.rt;
struct anv_state rtdg_state =
anv_cmd_buffer_alloc_dynamic_state(cmd_buffer,
BRW_RT_PUSH_CONST_OFFSET +
sizeof(struct anv_push_constants),
64);
struct GENX(RT_DISPATCH_GLOBALS) rtdg = {
.MemBaseAddress = (struct anv_address) {
.bo = rt->scratch.bo,
.offset = rt->scratch.layout.ray_stack_start,
},
.CallStackHandler = anv_shader_bin_get_bsr(
cmd_buffer->device->rt_trivial_return, 0),
.AsyncRTStackSize = rt->scratch.layout.ray_stack_stride / 64,
.NumDSSRTStacks = rt->scratch.layout.stack_ids_per_dss,
.MaxBVHLevels = BRW_RT_MAX_BVH_LEVELS,
.Flags = RT_DEPTH_TEST_LESS_EQUAL,
.HitGroupTable = vk_sdar_to_shader_table(params->hit_sbt),
.MissGroupTable = vk_sdar_to_shader_table(params->miss_sbt),
.SWStackSize = rt->scratch.layout.sw_stack_size / 64,
.LaunchWidth = params->launch_size[0],
.LaunchHeight = params->launch_size[1],
.LaunchDepth = params->launch_size[2],
.CallableGroupTable = vk_sdar_to_shader_table(params->callable_sbt),
};
GENX(RT_DISPATCH_GLOBALS_pack)(NULL, rtdg_state.map, &rtdg);
return rtdg_state;
}
static void
cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
const VkStridedDeviceAddressRegionKHR *raygen_sbt,
const VkStridedDeviceAddressRegionKHR *miss_sbt,
const VkStridedDeviceAddressRegionKHR *hit_sbt,
const VkStridedDeviceAddressRegionKHR *callable_sbt,
bool is_indirect,
uint32_t launch_width,
uint32_t launch_height,
uint32_t launch_depth,
uint64_t launch_size_addr)
struct trace_params *params)
{
struct anv_device *device = cmd_buffer->device;
struct anv_cmd_ray_tracing_state *rt = &cmd_buffer->state.rt;
@ -5637,8 +5691,10 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
return;
/* If we have a known degenerate launch size, just bail */
if (!is_indirect &&
(launch_width == 0 || launch_height == 0 || launch_depth == 0))
if (!params->is_launch_size_indirect &&
(params->launch_size[0] == 0 ||
params->launch_size[1] == 0 ||
params->launch_size[2] == 0))
return;
genX(cmd_buffer_config_l3)(cmd_buffer, pipeline->base.l3_config);
@ -5662,34 +5718,12 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
/* Allocate and set up our RT_DISPATCH_GLOBALS */
struct anv_state rtdg_state =
anv_cmd_buffer_alloc_dynamic_state(cmd_buffer,
BRW_RT_PUSH_CONST_OFFSET +
sizeof(struct anv_push_constants),
64);
cmd_buffer_emit_rt_dispatch_globals(cmd_buffer, params);
struct GENX(RT_DISPATCH_GLOBALS) rtdg = {
.MemBaseAddress = (struct anv_address) {
.bo = rt->scratch.bo,
.offset = rt->scratch.layout.ray_stack_start,
},
.CallStackHandler =
anv_shader_bin_get_bsr(cmd_buffer->device->rt_trivial_return, 0),
.AsyncRTStackSize = rt->scratch.layout.ray_stack_stride / 64,
.NumDSSRTStacks = rt->scratch.layout.stack_ids_per_dss,
.MaxBVHLevels = BRW_RT_MAX_BVH_LEVELS,
.Flags = RT_DEPTH_TEST_LESS_EQUAL,
.HitGroupTable = vk_sdar_to_shader_table(hit_sbt),
.MissGroupTable = vk_sdar_to_shader_table(miss_sbt),
.SWStackSize = rt->scratch.layout.sw_stack_size / 64,
.LaunchWidth = launch_width,
.LaunchHeight = launch_height,
.LaunchDepth = launch_depth,
.CallableGroupTable = vk_sdar_to_shader_table(callable_sbt),
};
GENX(RT_DISPATCH_GLOBALS_pack)(NULL, rtdg_state.map, &rtdg);
/* Push constants go after the RT_DISPATCH_GLOBALS */
assert(rtdg_state.alloc_size >= (BRW_RT_PUSH_CONST_OFFSET +
sizeof(struct anv_push_constants)));
assert(GENX(RT_DISPATCH_GLOBALS_length) * 4 <= BRW_RT_PUSH_CONST_OFFSET);
/* Push constants go after the RT_DISPATCH_GLOBALS */
memcpy(rtdg_state.map + BRW_RT_PUSH_CONST_OFFSET,
&cmd_buffer->state.rt.base.push_constants,
sizeof(struct anv_push_constants));
@ -5700,7 +5734,7 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
uint8_t local_size_log2[3];
uint32_t global_size[3] = {};
if (is_indirect) {
if (params->is_launch_size_indirect) {
/* Pick a local size that's probably ok. We assume most TraceRays calls
* will use a two-dimensional dispatch size. Worst case, our initial
* dispatch will be a little slower than it has to be.
@ -5713,21 +5747,20 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
mi_builder_init(&b, cmd_buffer->device->info, &cmd_buffer->batch);
struct mi_value launch_size[3] = {
mi_mem32(anv_address_from_u64(launch_size_addr + 0)),
mi_mem32(anv_address_from_u64(launch_size_addr + 4)),
mi_mem32(anv_address_from_u64(launch_size_addr + 8)),
mi_mem32(anv_address_from_u64(params->launch_size_addr + 0)),
mi_mem32(anv_address_from_u64(params->launch_size_addr + 4)),
mi_mem32(anv_address_from_u64(params->launch_size_addr + 8)),
};
/* Store the original launch size into RT_DISPATCH_GLOBALS
*
* TODO: Pull values from genX_bits.h once RT_DISPATCH_GLOBALS gets
* moved into a genX version.
*/
mi_store(&b, mi_mem32(anv_address_add(rtdg_addr, 52)),
/* Store the original launch size into RT_DISPATCH_GLOBALS */
mi_store(&b, mi_mem32(anv_address_add(rtdg_addr,
GENX(RT_DISPATCH_GLOBALS_LaunchWidth_start) / 8)),
mi_value_ref(&b, launch_size[0]));
mi_store(&b, mi_mem32(anv_address_add(rtdg_addr, 56)),
mi_store(&b, mi_mem32(anv_address_add(rtdg_addr,
GENX(RT_DISPATCH_GLOBALS_LaunchHeight_start) / 8)),
mi_value_ref(&b, launch_size[1]));
mi_store(&b, mi_mem32(anv_address_add(rtdg_addr, 60)),
mi_store(&b, mi_mem32(anv_address_add(rtdg_addr,
GENX(RT_DISPATCH_GLOBALS_LaunchDepth_start) / 8)),
mi_value_ref(&b, launch_size[2]));
/* Compute the global dispatch size */
@ -5752,15 +5785,14 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
mi_store(&b, mi_reg32(GPGPU_DISPATCHDIMY), launch_size[1]);
mi_store(&b, mi_reg32(GPGPU_DISPATCHDIMZ), launch_size[2]);
} else {
uint32_t launch_size[3] = { launch_width, launch_height, launch_depth };
calc_local_trace_size(local_size_log2, launch_size);
calc_local_trace_size(local_size_log2, params->launch_size);
for (unsigned i = 0; i < 3; i++) {
/* We have to be a bit careful here because DIV_ROUND_UP adds to the
* numerator value may overflow. Cast to uint64_t to avoid this.
*/
uint32_t local_size = 1 << local_size_log2[i];
global_size[i] = DIV_ROUND_UP((uint64_t)launch_size[i], local_size);
global_size[i] = DIV_ROUND_UP((uint64_t)params->launch_size[i], local_size);
}
}
@ -5799,7 +5831,7 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
brw_cs_get_dispatch_info(device->info, cs_prog_data, NULL);
anv_batch_emit(&cmd_buffer->batch, GENX(COMPUTE_WALKER), cw) {
cw.IndirectParameterEnable = is_indirect;
cw.IndirectParameterEnable = params->is_launch_size_indirect;
cw.PredicateEnable = cmd_buffer->state.conditional_render_enabled;
cw.SIMDSize = dispatch.simd_size / 16;
cw.LocalXMaximum = (1 << local_size_log2[0]) - 1;
@ -5828,7 +5860,7 @@ cmd_buffer_trace_rays(struct anv_cmd_buffer *cmd_buffer,
struct brw_rt_raygen_trampoline_params trampoline_params = {
.rt_disp_globals_addr = anv_address_physical(rtdg_addr),
.raygen_bsr_addr = raygen_sbt->deviceAddress,
.raygen_bsr_addr = params->raygen_sbt->deviceAddress,
.is_indirect = false, /* Only for raygen_bsr_addr */
.local_group_size_log2 = {
local_size_log2[0],
@ -5853,15 +5885,21 @@ genX(CmdTraceRaysKHR)(
uint32_t depth)
{
ANV_FROM_HANDLE(anv_cmd_buffer, cmd_buffer, commandBuffer);
struct trace_params params = {
.is_sbt_indirect = false,
.raygen_sbt = pRaygenShaderBindingTable,
.miss_sbt = pMissShaderBindingTable,
.hit_sbt = pHitShaderBindingTable,
.callable_sbt = pCallableShaderBindingTable,
.is_launch_size_indirect = false,
.launch_size = {
width,
height,
depth,
},
};
cmd_buffer_trace_rays(cmd_buffer,
pRaygenShaderBindingTable,
pMissShaderBindingTable,
pHitShaderBindingTable,
pCallableShaderBindingTable,
false /* is_indirect */,
width, height, depth,
0 /* launch_size_addr */);
cmd_buffer_trace_rays(cmd_buffer, &params);
}
void
@ -5874,15 +5912,17 @@ genX(CmdTraceRaysIndirectKHR)(
VkDeviceAddress indirectDeviceAddress)
{
ANV_FROM_HANDLE(anv_cmd_buffer, cmd_buffer, commandBuffer);
struct trace_params params = {
.is_sbt_indirect = false,
.raygen_sbt = pRaygenShaderBindingTable,
.miss_sbt = pMissShaderBindingTable,
.hit_sbt = pHitShaderBindingTable,
.callable_sbt = pCallableShaderBindingTable,
.is_launch_size_indirect = true,
.launch_size_addr = indirectDeviceAddress,
};
cmd_buffer_trace_rays(cmd_buffer,
pRaygenShaderBindingTable,
pMissShaderBindingTable,
pHitShaderBindingTable,
pCallableShaderBindingTable,
true /* is_indirect */,
0, 0, 0, /* width, height, depth, */
indirectDeviceAddress);
cmd_buffer_trace_rays(cmd_buffer, &params);
}
#endif /* GFX_VERx10 >= 125 */