mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-06 07:18:17 +02:00
radv: Perform multiple sorts in parallel
This was the last part that didn't scale with multiple infos. Reducing the number of barriers in this case improves DOOM Eternal performance by 50% (when running at low resolution). Reviewed-by: Friedrich Vock <friedrich.vock@gmx.de> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24720>
This commit is contained in:
parent
44c47054bc
commit
97b1caf9f6
1 changed files with 150 additions and 128 deletions
|
|
@ -600,6 +600,13 @@ struct bvh_state {
|
|||
struct acceleration_structure_layout accel_struct;
|
||||
struct scratch_layout scratch;
|
||||
struct build_config config;
|
||||
|
||||
/* Radix sort state */
|
||||
uint32_t scatter_blocks;
|
||||
uint32_t count_ru_scatter;
|
||||
uint32_t histo_blocks;
|
||||
uint32_t count_ru_histo;
|
||||
struct rs_push_scatter push_scatter;
|
||||
};
|
||||
|
||||
static uint32_t
|
||||
|
|
@ -726,75 +733,79 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
|
|||
|
||||
radix_sort_vk_t *rs = cmd_buffer->device->meta_state.accel_struct_build.radix_sort;
|
||||
|
||||
/*
|
||||
* OVERVIEW
|
||||
*
|
||||
* 1. Pad the keyvals in `scatter_even`.
|
||||
* 2. Zero the `histograms` and `partitions`.
|
||||
* --- BARRIER ---
|
||||
* 3. HISTOGRAM is dispatched before PREFIX.
|
||||
* --- BARRIER ---
|
||||
* 4. PREFIX is dispatched before the first SCATTER.
|
||||
* --- BARRIER ---
|
||||
* 5. One or more SCATTER dispatches.
|
||||
*
|
||||
* Note that the `partitions` buffer can be zeroed anytime before the first
|
||||
* scatter.
|
||||
*/
|
||||
|
||||
/* How many passes? */
|
||||
uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
|
||||
uint32_t keyval_bits = keyval_bytes * 8;
|
||||
uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
|
||||
uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
|
||||
|
||||
for (uint32_t i = 0; i < infoCount; ++i) {
|
||||
uint32_t count = bvh_states[i].node_count;
|
||||
if (bvh_states[i].node_count)
|
||||
bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
|
||||
else
|
||||
bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
|
||||
}
|
||||
|
||||
/*
|
||||
* PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
|
||||
*
|
||||
* Pad fractional blocks with max-valued keyvals.
|
||||
*
|
||||
* Zero the histograms and partitions buffer.
|
||||
*
|
||||
* This assumes the partitions follow the histograms.
|
||||
*/
|
||||
|
||||
/* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */
|
||||
|
||||
/* How many scatter blocks? */
|
||||
uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
|
||||
uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
|
||||
|
||||
/*
|
||||
* How many histogram blocks?
|
||||
*
|
||||
* Note that it's OK to have more max-valued digits counted by the histogram
|
||||
* than sorted by the scatters because the sort is stable.
|
||||
*/
|
||||
uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
|
||||
uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
|
||||
|
||||
uint32_t pass_idx = (keyval_bytes - passes);
|
||||
|
||||
for (uint32_t i = 0; i < infoCount; ++i) {
|
||||
if (!bvh_states[i].node_count)
|
||||
continue;
|
||||
|
||||
uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
|
||||
uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
|
||||
uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
|
||||
|
||||
/* Anything to do? */
|
||||
if (!count) {
|
||||
bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[0];
|
||||
continue;
|
||||
}
|
||||
bvh_states[i].scatter_blocks = (bvh_states[i].node_count + scatter_block_kvs - 1) / scatter_block_kvs;
|
||||
bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs;
|
||||
|
||||
/*
|
||||
* OVERVIEW
|
||||
*
|
||||
* 1. Pad the keyvals in `scatter_even`.
|
||||
* 2. Zero the `histograms` and `partitions`.
|
||||
* --- BARRIER ---
|
||||
* 3. HISTOGRAM is dispatched before PREFIX.
|
||||
* --- BARRIER ---
|
||||
* 4. PREFIX is dispatched before the first SCATTER.
|
||||
* --- BARRIER ---
|
||||
* 5. One or more SCATTER dispatches.
|
||||
*
|
||||
* Note that the `partitions` buffer can be zeroed anytime before the first
|
||||
* scatter.
|
||||
*/
|
||||
|
||||
/* How many passes? */
|
||||
uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
|
||||
uint32_t keyval_bits = keyval_bytes * 8;
|
||||
uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
|
||||
uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
|
||||
|
||||
bvh_states[i].scratch_offset = bvh_states[i].scratch.sort_buffer_offset[passes & 1];
|
||||
|
||||
/*
|
||||
* PAD KEYVALS AND ZERO HISTOGRAM/PARTITIONS
|
||||
*
|
||||
* Pad fractional blocks with max-valued keyvals.
|
||||
*
|
||||
* Zero the histograms and partitions buffer.
|
||||
*
|
||||
* This assumes the partitions follow the histograms.
|
||||
*/
|
||||
|
||||
/* FIXME(allanmac): Consider precomputing some of these values and hang them off `rs`. */
|
||||
|
||||
/* How many scatter blocks? */
|
||||
uint32_t scatter_wg_size = 1 << rs->config.scatter.workgroup_size_log2;
|
||||
uint32_t scatter_block_kvs = scatter_wg_size * rs->config.scatter.block_rows;
|
||||
uint32_t scatter_blocks = (count + scatter_block_kvs - 1) / scatter_block_kvs;
|
||||
uint32_t count_ru_scatter = scatter_blocks * scatter_block_kvs;
|
||||
|
||||
/*
|
||||
* How many histogram blocks?
|
||||
*
|
||||
* Note that it's OK to have more max-valued digits counted by the histogram
|
||||
* than sorted by the scatters because the sort is stable.
|
||||
*/
|
||||
uint32_t histo_wg_size = 1 << rs->config.histogram.workgroup_size_log2;
|
||||
uint32_t histo_block_kvs = histo_wg_size * rs->config.histogram.block_rows;
|
||||
uint32_t histo_blocks = (count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
|
||||
uint32_t count_ru_histo = histo_blocks * histo_block_kvs;
|
||||
bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
|
||||
bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs;
|
||||
|
||||
/* Fill with max values */
|
||||
if (count_ru_histo > count) {
|
||||
radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + count * keyval_bytes,
|
||||
(count_ru_histo - count) * keyval_bytes, 0xFFFFFFFF);
|
||||
if (bvh_states[i].count_ru_histo > bvh_states[i].node_count) {
|
||||
radv_fill_buffer(cmd_buffer, NULL, NULL, keyvals_even_addr + bvh_states[i].node_count * keyval_bytes,
|
||||
(bvh_states[i].count_ru_histo - bvh_states[i].node_count) * keyval_bytes, 0xFFFFFFFF);
|
||||
}
|
||||
|
||||
/*
|
||||
|
|
@ -807,28 +818,35 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
|
|||
* Note that the last workgroup doesn't read/write a partition so it doesn't
|
||||
* need to be initialized.
|
||||
*/
|
||||
uint32_t histo_partition_count = passes + scatter_blocks - 1;
|
||||
uint32_t pass_idx = (keyval_bytes - passes);
|
||||
uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1;
|
||||
|
||||
uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
|
||||
|
||||
radv_fill_buffer(cmd_buffer, NULL, NULL, internal_addr + rs->internal.histograms.offset + fill_base,
|
||||
histo_partition_count * (RS_RADIX_SIZE * sizeof(uint32_t)), 0);
|
||||
}
|
||||
|
||||
/*
|
||||
* Pipeline: HISTOGRAM
|
||||
*
|
||||
* TODO(allanmac): All subgroups should try to process approximately the same
|
||||
* number of blocks in order to minimize tail effects. This was implemented
|
||||
* and reverted but should be reimplemented and benchmarked later.
|
||||
*/
|
||||
vk_barrier_transfer_w_to_compute_r(commandBuffer);
|
||||
/*
|
||||
* Pipeline: HISTOGRAM
|
||||
*
|
||||
* TODO(allanmac): All subgroups should try to process approximately the same
|
||||
* number of blocks in order to minimize tail effects. This was implemented
|
||||
* and reverted but should be reimplemented and benchmarked later.
|
||||
*/
|
||||
vk_barrier_transfer_w_to_compute_r(commandBuffer);
|
||||
|
||||
uint64_t devaddr_histograms = internal_addr + rs->internal.histograms.offset;
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
|
||||
|
||||
for (uint32_t i = 0; i < infoCount; ++i) {
|
||||
if (!bvh_states[i].node_count)
|
||||
continue;
|
||||
|
||||
uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
|
||||
uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
|
||||
|
||||
/* Dispatch histogram */
|
||||
struct rs_push_histogram push_histogram = {
|
||||
.devaddr_histograms = devaddr_histograms,
|
||||
.devaddr_histograms = internal_addr + rs->internal.histograms.offset,
|
||||
.devaddr_keyvals = keyvals_even_addr,
|
||||
.passes = passes,
|
||||
};
|
||||
|
|
@ -836,83 +854,87 @@ morton_sort(VkCommandBuffer commandBuffer, uint32_t infoCount,
|
|||
radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
|
||||
sizeof(push_histogram), &push_histogram);
|
||||
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.histogram);
|
||||
vk_common_CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1);
|
||||
}
|
||||
|
||||
vk_common_CmdDispatch(commandBuffer, histo_blocks, 1, 1);
|
||||
/*
|
||||
* Pipeline: PREFIX
|
||||
*
|
||||
* Launch one workgroup per pass.
|
||||
*/
|
||||
vk_barrier_compute_w_to_compute_r(commandBuffer);
|
||||
|
||||
/*
|
||||
* Pipeline: PREFIX
|
||||
*
|
||||
* Launch one workgroup per pass.
|
||||
*/
|
||||
vk_barrier_compute_w_to_compute_r(commandBuffer);
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
|
||||
|
||||
for (uint32_t i = 0; i < infoCount; ++i) {
|
||||
if (!bvh_states[i].node_count)
|
||||
continue;
|
||||
|
||||
uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
|
||||
|
||||
struct rs_push_prefix push_prefix = {
|
||||
.devaddr_histograms = devaddr_histograms,
|
||||
.devaddr_histograms = internal_addr + rs->internal.histograms.offset,
|
||||
};
|
||||
|
||||
radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.prefix, VK_SHADER_STAGE_COMPUTE_BIT, 0,
|
||||
sizeof(push_prefix), &push_prefix);
|
||||
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, rs->pipelines.named.prefix);
|
||||
|
||||
vk_common_CmdDispatch(commandBuffer, passes, 1, 1);
|
||||
}
|
||||
|
||||
/* Pipeline: SCATTER */
|
||||
vk_barrier_compute_w_to_compute_r(commandBuffer);
|
||||
/* Pipeline: SCATTER */
|
||||
vk_barrier_compute_w_to_compute_r(commandBuffer);
|
||||
|
||||
uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
|
||||
uint64_t devaddr_partitions = internal_addr + rs->internal.partitions.offset;
|
||||
uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
|
||||
|
||||
struct rs_push_scatter push_scatter = {
|
||||
for (uint32_t i = 0; i < infoCount; i++) {
|
||||
uint64_t keyvals_even_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[0];
|
||||
uint64_t keyvals_odd_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_buffer_offset[1];
|
||||
uint64_t internal_addr = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.sort_internal_offset;
|
||||
|
||||
bvh_states[i].push_scatter = (struct rs_push_scatter){
|
||||
.devaddr_keyvals_even = keyvals_even_addr,
|
||||
.devaddr_keyvals_odd = keyvals_odd_addr,
|
||||
.devaddr_partitions = devaddr_partitions,
|
||||
.devaddr_histograms = devaddr_histograms + histogram_offset,
|
||||
.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2,
|
||||
.devaddr_partitions = internal_addr + rs->internal.partitions.offset,
|
||||
.devaddr_histograms = internal_addr + rs->internal.histograms.offset + histogram_offset,
|
||||
};
|
||||
}
|
||||
|
||||
{
|
||||
uint32_t pass_dword = pass_idx / 4;
|
||||
bool is_even = true;
|
||||
|
||||
radv_CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.scatter[pass_dword].even,
|
||||
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_scatter), &push_scatter);
|
||||
while (true) {
|
||||
uint32_t pass_dword = pass_idx / 4;
|
||||
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
|
||||
rs->pipelines.named.scatter[pass_dword].even);
|
||||
/* Bind new pipeline */
|
||||
VkPipeline p =
|
||||
is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
|
||||
|
||||
/* Update push constants that changed */
|
||||
VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
|
||||
: rs->pipeline_layouts.named.scatter[pass_dword].odd;
|
||||
|
||||
for (uint32_t i = 0; i < infoCount; i++) {
|
||||
if (!bvh_states[i].node_count)
|
||||
continue;
|
||||
|
||||
bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
|
||||
|
||||
radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter),
|
||||
&bvh_states[i].push_scatter);
|
||||
|
||||
vk_common_CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1);
|
||||
|
||||
bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
|
||||
}
|
||||
|
||||
bool is_even = true;
|
||||
/* Continue? */
|
||||
if (++pass_idx >= keyval_bytes)
|
||||
break;
|
||||
|
||||
while (true) {
|
||||
vk_common_CmdDispatch(commandBuffer, scatter_blocks, 1, 1);
|
||||
vk_barrier_compute_w_to_compute_r(commandBuffer);
|
||||
|
||||
/* Continue? */
|
||||
if (++pass_idx >= keyval_bytes)
|
||||
break;
|
||||
|
||||
vk_barrier_compute_w_to_compute_r(commandBuffer);
|
||||
|
||||
is_even ^= true;
|
||||
push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
|
||||
push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
|
||||
|
||||
uint32_t pass_dword = pass_idx / 4;
|
||||
|
||||
/* Update push constants that changed */
|
||||
VkPipelineLayout pl = is_even ? rs->pipeline_layouts.named.scatter[pass_dword].even
|
||||
: rs->pipeline_layouts.named.scatter[pass_dword].odd;
|
||||
radv_CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT,
|
||||
offsetof(struct rs_push_scatter, devaddr_histograms),
|
||||
sizeof(push_scatter.devaddr_histograms) + sizeof(push_scatter.pass_offset),
|
||||
&push_scatter.devaddr_histograms);
|
||||
|
||||
/* Bind new pipeline */
|
||||
VkPipeline p =
|
||||
is_even ? rs->pipelines.named.scatter[pass_dword].even : rs->pipelines.named.scatter[pass_dword].odd;
|
||||
|
||||
radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, p);
|
||||
}
|
||||
is_even ^= true;
|
||||
}
|
||||
|
||||
cmd_buffer->state.flush_bits |= flush_bits;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue