vulkan: Remove bvh_state

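Move the fields that only lived in the private bvh_state wrapper into
vk_acceleration_structure_build_state. With a single state type, the whole
array of build states can be passed to the driver.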

Reviewed-by: Natalie Vock <natalie.vock@gmx.de>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39752>
This commit is contained in:
Konstantin Seurer 2026-02-07 12:15:47 +01:00 committed by Marge Bot
parent 50b1becdde
commit 411e23d389
2 changed files with 168 additions and 171 deletions

View file

@ -36,9 +36,6 @@
#include "bvh/vk_build_interface.h"
#include "bvh/vk_bvh.h"
#include "radix_sort/common/vk/barrier.h"
#include "radix_sort/shaders/push.h"
#include "util/u_string.h"
#include "util/bitset.h"
@ -253,23 +250,6 @@ vk_acceleration_structure_build_state_init(struct vk_acceleration_structure_buil
}
}
/* Per-build bookkeeping for one entry of a vk_cmd_build_acceleration_structures
 * batch. The array of these is allocated with calloc(), so every field below
 * starts out as zero.
 */
struct bvh_state {
/* Common build state; this is the part handed to
 * vk_acceleration_structure_build_state_init() and to the driver callbacks
 * (init_update_scratch, encode_prepare/encode_as, update_prepare/update_as).
 */
struct vk_acceleration_structure_build_state vk;
/* VK_BUILD_FLAG_* bits for this build (ALWAYS_ACTIVE when the build is
 * updateable, PROPAGATE_CULL_FLAGS when requested by the build args).
 * vk_build_stage() masks it to group builds by flag combination.
 */
uint32_t build_flags;
/* Scratch offset of the id buffer the internal-node builders read from.
 * morton_sort() sets it to sort_buffer_offset[passes & 1] when the build has
 * leaves and to sort_buffer_offset[0] otherwise.
 */
uint32_t scratch_offset;
/* Radix sort state */
/* leaf_node_count divided by scatter_block_kvs, rounded up; also the
 * workgroup count of the scatter dispatch.
 */
uint32_t scatter_blocks;
/* scatter_blocks * scatter_block_kvs, i.e. the key count rounded up to whole
 * scatter blocks.
 */
uint32_t count_ru_scatter;
/* count_ru_scatter divided by histo_block_kvs, rounded up; also the
 * workgroup count of the histogram dispatch.
 */
uint32_t histo_blocks;
/* histo_blocks * histo_block_kvs. Keys between leaf_node_count and this count
 * are filled with 0xFFFFFFFF before sorting.
 */
uint32_t count_ru_histo;
/* Push constants for the scatter passes. pass_offset is rewritten and
 * devaddr_histograms is advanced by one histogram after every scatter
 * dispatch.
 */
struct rs_push_scatter push_scatter;
/* Set to pass + 1 once the encode/update callback of that pass has run for
 * this build; entries already equal to pass + 1 are skipped.
 */
uint32_t last_encode_pass;
};
struct bvh_batch_state {
bool any_updateable;
bool any_non_updateable;
@ -543,8 +523,8 @@ static VkResult
build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags)
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags)
{
VkPipeline pipeline;
VkPipelineLayout layout;
@ -588,36 +568,36 @@ build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device,
commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
if ((bvh_states[i].build_flags & VK_BUILD_LEAVES_FLAGS) != build_flags)
if ((states[i].build_flags & VK_BUILD_LEAVES_FLAGS) != build_flags)
continue;
const VkAccelerationStructureBuildGeometryInfoKHR *build_info = bvh_states[i].vk.build_info;
const VkAccelerationStructureBuildGeometryInfoKHR *build_info = states[i].build_info;
uint64_t scratch_addr = build_info->scratchData.deviceAddress;
struct leaf_args leaf_consts = {
.bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset,
.header = scratch_addr + bvh_states[i].vk.scratch.header_offset,
.ids = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0],
.bvh = scratch_addr + states[i].scratch.ir_offset,
.header = scratch_addr + states[i].scratch.header_offset,
.ids = scratch_addr + states[i].scratch.sort_buffer_offset[0],
};
for (unsigned j = 0; j < build_info->geometryCount; ++j) {
const VkAccelerationStructureGeometryKHR *geom =
build_info->pGeometries ? &build_info->pGeometries[j] : build_info->ppGeometries[j];
const VkAccelerationStructureBuildRangeInfoKHR *build_range_info = &bvh_states[i].vk.build_range_infos[j];
const VkAccelerationStructureBuildRangeInfoKHR *build_range_info = &states[i].build_range_infos[j];
if (build_range_info->primitiveCount == 0)
continue;
leaf_consts.geom_data = vk_fill_geometry_data(build_info->type, bvh_states[i].vk.leaf_node_count, j, geom, build_range_info);
leaf_consts.geom_data = vk_fill_geometry_data(build_info->type, states[i].leaf_node_count, j, geom, build_range_info);
disp->CmdPushConstants(commandBuffer, layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(leaf_consts), &leaf_consts);
device->cmd_dispatch_unaligned(commandBuffer, build_range_info->primitiveCount, 1, 1);
bvh_states[i].vk.leaf_node_count += build_range_info->primitiveCount;
states[i].leaf_node_count += build_range_info->primitiveCount;
}
}
@ -635,8 +615,8 @@ static VkResult
morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags)
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags)
{
VkPipeline pipeline;
VkPipelineLayout layout;
@ -665,19 +645,19 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct morton_args consts = {
.bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset,
.header = scratch_addr + bvh_states[i].vk.scratch.header_offset,
.ids = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0],
.bvh = scratch_addr + states[i].scratch.ir_offset,
.header = scratch_addr + states[i].scratch.header_offset,
.ids = scratch_addr + states[i].scratch.sort_buffer_offset[0],
};
disp->CmdPushConstants(commandBuffer, layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].vk.leaf_node_count, 1, 1);
device->cmd_dispatch_unaligned(commandBuffer, states[i].leaf_node_count, 1, 1);
}
if (args->emit_markers) {
@ -694,8 +674,8 @@ static VkResult
morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags)
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags)
{
const struct vk_device_dispatch_table *disp = &device->dispatch_table;
@ -732,10 +712,10 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.leaf_node_count)
bvh_states[i].scratch_offset = bvh_states[i].vk.scratch.sort_buffer_offset[passes & 1];
if (states[i].leaf_node_count)
states[i].scratch_offset = states[i].scratch.sort_buffer_offset[passes & 1];
else
bvh_states[i].scratch_offset = bvh_states[i].vk.scratch.sort_buffer_offset[0];
states[i].scratch_offset = states[i].scratch.sort_buffer_offset[0];
}
/*
@ -766,26 +746,26 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
uint32_t pass_idx = (keyval_bytes - passes);
for (uint32_t i = 0; i < build_count; ++i) {
if (!bvh_states[i].vk.leaf_node_count)
if (!states[i].leaf_node_count)
continue;
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0];
uint64_t internal_addr = scratch_addr + bvh_states[i].vk.scratch.sort_internal_offset;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0];
uint64_t internal_addr = scratch_addr + states[i].scratch.sort_internal_offset;
bvh_states[i].scatter_blocks = (bvh_states[i].vk.leaf_node_count + scatter_block_kvs - 1) / scatter_block_kvs;
bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs;
states[i].scatter_blocks = (states[i].leaf_node_count + scatter_block_kvs - 1) / scatter_block_kvs;
states[i].count_ru_scatter = states[i].scatter_blocks * scatter_block_kvs;
bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs;
states[i].histo_blocks = (states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs;
states[i].count_ru_histo = states[i].histo_blocks * histo_block_kvs;
/* Fill with max values */
if (bvh_states[i].count_ru_histo > bvh_states[i].vk.leaf_node_count) {
if (states[i].count_ru_histo > states[i].leaf_node_count) {
device->cmd_fill_buffer_addr(commandBuffer, keyvals_even_addr +
bvh_states[i].vk.leaf_node_count * keyval_bytes,
(bvh_states[i].count_ru_histo - bvh_states[i].vk.leaf_node_count) * keyval_bytes,
states[i].leaf_node_count * keyval_bytes,
(states[i].count_ru_histo - states[i].leaf_node_count) * keyval_bytes,
0xFFFFFFFF);
}
@ -799,7 +779,7 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
* Note that the last workgroup doesn't read/write a partition so it doesn't
* need to be initialized.
*/
uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1;
uint32_t histo_partition_count = passes + states[i].scatter_blocks - 1;
uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
@ -821,14 +801,14 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
rs->pipelines.named.histogram);
for (uint32_t i = 0; i < build_count; ++i) {
if (!bvh_states[i].vk.leaf_node_count)
if (!states[i].leaf_node_count)
continue;
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0];
uint64_t internal_addr = scratch_addr + bvh_states[i].vk.scratch.sort_internal_offset;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0];
uint64_t internal_addr = scratch_addr + states[i].scratch.sort_internal_offset;
/* Dispatch histogram */
struct rs_push_histogram push_histogram = {
@ -840,7 +820,7 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
disp->CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0,
sizeof(push_histogram), &push_histogram);
disp->CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1);
disp->CmdDispatch(commandBuffer, states[i].histo_blocks, 1, 1);
}
/*
@ -854,13 +834,13 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
rs->pipelines.named.prefix);
for (uint32_t i = 0; i < build_count; ++i) {
if (!bvh_states[i].vk.leaf_node_count)
if (!states[i].leaf_node_count)
continue;
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
uint64_t internal_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress +
bvh_states[i].vk.scratch.sort_internal_offset;
uint64_t internal_addr = states[i].build_info->scratchData.deviceAddress +
states[i].scratch.sort_internal_offset;
struct rs_push_prefix push_prefix = {
.devaddr_histograms = internal_addr + rs->internal.histograms.offset,
@ -878,12 +858,12 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t));
for (uint32_t i = 0; i < build_count; i++) {
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0];
uint64_t keyvals_odd_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[1];
uint64_t internal_addr = scratch_addr + bvh_states[i].vk.scratch.sort_internal_offset;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0];
uint64_t keyvals_odd_addr = scratch_addr + states[i].scratch.sort_buffer_offset[1];
uint64_t internal_addr = scratch_addr + states[i].scratch.sort_internal_offset;
bvh_states[i].push_scatter = (struct rs_push_scatter){
states[i].push_scatter = (struct rs_push_scatter){
.devaddr_keyvals_even = keyvals_even_addr,
.devaddr_keyvals_odd = keyvals_odd_addr,
.devaddr_partitions = internal_addr + rs->internal.partitions.offset,
@ -906,19 +886,19 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
: rs->pipeline_layouts.named.scatter[pass_dword].odd;
for (uint32_t i = 0; i < build_count; i++) {
if (!bvh_states[i].vk.leaf_node_count)
if (!states[i].leaf_node_count)
continue;
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
disp->CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter),
&bvh_states[i].push_scatter);
&states[i].push_scatter);
disp->CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1);
disp->CmdDispatch(commandBuffer, states[i].scatter_blocks, 1, 1);
bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t));
}
/* Continue? */
@ -946,8 +926,8 @@ static VkResult
lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags)
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags)
{
VkPipeline pipeline;
VkPipelineLayout layout;
@ -976,23 +956,23 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH)
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH)
continue;
uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
uint32_t src_scratch_offset = states[i].scratch_offset;
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct lbvh_main_args consts = {
.bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset,
.bvh = scratch_addr + states[i].scratch.ir_offset,
.src_ids = scratch_addr + src_scratch_offset,
.node_info = scratch_addr + bvh_states[i].vk.scratch.lbvh_node_offset,
.header = scratch_addr + bvh_states[i].vk.scratch.header_offset,
.internal_node_base = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset,
.node_info = scratch_addr + states[i].scratch.lbvh_node_offset,
.header = scratch_addr + states[i].scratch.header_offset,
.internal_node_base = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset,
};
disp->CmdPushConstants(commandBuffer, layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].vk.internal_node_count, 1, 1);
device->cmd_dispatch_unaligned(commandBuffer, states[i].internal_node_count, 1, 1);
}
vk_barrier_compute_w_to_compute_r(commandBuffer);
@ -1012,20 +992,20 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH)
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH)
continue;
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct lbvh_generate_ir_args consts = {
.bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset,
.node_info = scratch_addr + bvh_states[i].vk.scratch.lbvh_node_offset,
.header = scratch_addr + bvh_states[i].vk.scratch.header_offset,
.internal_node_base = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset,
.bvh = scratch_addr + states[i].scratch.ir_offset,
.node_info = scratch_addr + states[i].scratch.lbvh_node_offset,
.header = scratch_addr + states[i].scratch.header_offset,
.internal_node_base = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset,
};
disp->CmdPushConstants(commandBuffer, layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].vk.internal_node_count, 1, 1);
device->cmd_dispatch_unaligned(commandBuffer, states[i].internal_node_count, 1, 1);
}
if (args->emit_markers) {
@ -1044,8 +1024,8 @@ static VkResult
ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags)
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags)
{
VkPipeline pipeline;
VkPipelineLayout layout;
@ -1073,27 +1053,27 @@ ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_PLOC)
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_PLOC)
continue;
uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
uint32_t dst_scratch_offset = (src_scratch_offset == bvh_states[i].vk.scratch.sort_buffer_offset[0])
? bvh_states[i].vk.scratch.sort_buffer_offset[1]
: bvh_states[i].vk.scratch.sort_buffer_offset[0];
uint32_t src_scratch_offset = states[i].scratch_offset;
uint32_t dst_scratch_offset = (src_scratch_offset == states[i].scratch.sort_buffer_offset[0])
? states[i].scratch.sort_buffer_offset[1]
: states[i].scratch.sort_buffer_offset[0];
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct ploc_args consts = {
.bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset,
.header = scratch_addr + bvh_states[i].vk.scratch.header_offset,
.bvh = scratch_addr + states[i].scratch.ir_offset,
.header = scratch_addr + states[i].scratch.header_offset,
.ids_0 = scratch_addr + src_scratch_offset,
.ids_1 = scratch_addr + dst_scratch_offset,
.prefix_scan_partitions = scratch_addr + bvh_states[i].vk.scratch.ploc_prefix_sum_partition_offset,
.internal_node_offset = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset,
.prefix_scan_partitions = scratch_addr + states[i].scratch.ploc_prefix_sum_partition_offset,
.internal_node_offset = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset,
};
disp->CmdPushConstants(commandBuffer, layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(bvh_states[i].vk.leaf_node_count, PLOC_WORKGROUP_SIZE), 1), 1, 1);
disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(states[i].leaf_node_count, PLOC_WORKGROUP_SIZE), 1), 1, 1);
}
if (args->emit_markers) {
@ -1112,8 +1092,8 @@ static VkResult
hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags)
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags)
{
VkPipeline pipeline;
VkPipelineLayout layout;
@ -1145,23 +1125,23 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline);
for (uint32_t i = 0; i < build_count; ++i) {
if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_HPLOC)
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_HPLOC)
continue;
assert(args->subgroup_size <= 64);
uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct hploc_args consts = {
.header = scratch_addr + bvh_states[i].vk.scratch.header_offset,
.bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset,
.ranges = scratch_addr + bvh_states[i].vk.scratch.hploc_ranges_offset,
.ids = scratch_addr + bvh_states[i].scratch_offset,
.internal_node_base = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset,
.header = scratch_addr + states[i].scratch.header_offset,
.bvh = scratch_addr + states[i].scratch.ir_offset,
.ranges = scratch_addr + states[i].scratch.hploc_ranges_offset,
.ids = scratch_addr + states[i].scratch_offset,
.internal_node_base = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset,
};
disp->CmdPushConstants(commandBuffer, layout,
VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(bvh_states[i].vk.leaf_node_count, args->subgroup_size), 1), 1, 1);
disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(states[i].leaf_node_count, args->subgroup_size), 1), 1, 1);
}
if (args->emit_markers) {
@ -1177,22 +1157,23 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
typedef VkResult (*vk_build_stage_cb)(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count,
uint32_t build_flags);
struct vk_acceleration_structure_build_state *states,
uint32_t build_count, uint32_t build_flags);
static VkResult
vk_build_stage(vk_build_stage_cb cb, VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args,
struct bvh_state *bvh_states, uint32_t build_count, uint32_t build_flags_mask)
struct vk_acceleration_structure_build_state *states, uint32_t build_count,
uint32_t build_flags_mask)
{
BITSET_DECLARE(flag_combinations, 1u << VK_BUILD_FLAG_COUNT);
BITSET_ZERO(flag_combinations);
for (uint32_t i = 0; i < build_count; i++)
BITSET_SET(flag_combinations, bvh_states[i].build_flags & build_flags_mask);
BITSET_SET(flag_combinations, states[i].build_flags & build_flags_mask);
uint32_t build_flags;
BITSET_FOREACH_SET(build_flags, flag_combinations, 1u << VK_BUILD_FLAG_COUNT) {
VkResult result = cb(commandBuffer, device, meta, args, bvh_states, build_count, build_flags);
VkResult result = cb(commandBuffer, device, meta, args, states, build_count, build_flags);
if (result != VK_SUCCESS)
return result;
}
@ -1214,7 +1195,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
struct bvh_batch_state batch_state = {0};
struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state));
struct vk_acceleration_structure_build_state *states = calloc(infoCount, sizeof(struct vk_acceleration_structure_build_state));
struct vk_acceleration_structure_build_marker top_marker = {
.step = VK_ACCELERATION_STRUCTURE_BUILD_STEP_TOP,
@ -1241,38 +1222,38 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
leaf_node_count += ppBuildRangeInfos[i][j].primitiveCount;
}
vk_acceleration_structure_build_state_init(&bvh_states[i].vk, cmd_buffer->base.device, leaf_node_count,
vk_acceleration_structure_build_state_init(&states[i], cmd_buffer->base.device, leaf_node_count,
pInfos + i, args);
bvh_states[i].vk.build_range_infos = ppBuildRangeInfos[i];
states[i].build_range_infos = ppBuildRangeInfos[i];
/* The leaf node dispatch code uses leaf_node_count as a base index. */
bvh_states[i].vk.leaf_node_count = 0;
states[i].leaf_node_count = 0;
if (bvh_states[i].vk.config.updateable)
if (states[i].config.updateable)
batch_state.any_updateable = true;
else
batch_state.any_non_updateable = true;
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_PLOC) {
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_PLOC) {
batch_state.any_ploc = true;
} else if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) {
} else if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) {
batch_state.any_hploc = true;
} else if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_LBVH) {
} else if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_LBVH) {
batch_state.any_lbvh = true;
} else if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) {
} else if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) {
batch_state.any_update = true;
/* For updates, the leaf node pass never runs, so set leaf_node_count here. */
bvh_states[i].vk.leaf_node_count = leaf_node_count;
states[i].leaf_node_count = leaf_node_count;
} else {
UNREACHABLE("Unknown internal_build_type");
}
if (bvh_states[i].vk.config.updateable)
bvh_states[i].build_flags |= VK_BUILD_FLAG_ALWAYS_ACTIVE;
if (states[i].config.updateable)
states[i].build_flags |= VK_BUILD_FLAG_ALWAYS_ACTIVE;
if (args->propagate_cull_flags)
bvh_states[i].build_flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS;
states[i].build_flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS;
if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_UPDATE) {
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_UPDATE) {
/* The internal node count is updated in lbvh_build_internal for LBVH
* and from the PLOC shader for PLOC. */
struct vk_ir_header header = {
@ -1288,10 +1269,10 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
},
};
device->write_buffer_cp(commandBuffer, pInfos[i].scratchData.deviceAddress + bvh_states[i].vk.scratch.header_offset,
device->write_buffer_cp(commandBuffer, pInfos[i].scratchData.deviceAddress + states[i].scratch.header_offset,
&header, sizeof(header));
} else {
ops->init_update_scratch(commandBuffer, &bvh_states[i].vk);
ops->init_update_scratch(commandBuffer, &states[i]);
}
}
@ -1311,19 +1292,19 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
}, 0, NULL, 0, NULL);
if (batch_state.any_lbvh || batch_state.any_ploc || batch_state.any_hploc) {
VkResult result = vk_build_stage(build_leaves, commandBuffer, device, meta, args, bvh_states, infoCount,
VkResult result = vk_build_stage(build_leaves, commandBuffer, device, meta, args, states, infoCount,
VK_BUILD_LEAVES_FLAGS);
if (result != VK_SUCCESS) {
free(bvh_states);
free(states);
vk_command_buffer_set_error(cmd_buffer, result);
return;
}
if (batch_state.any_hploc) {
for (uint32_t i = 0; i < infoCount; ++i) {
if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) {
device->cmd_fill_buffer_addr(commandBuffer, pInfos[i].scratchData.deviceAddress + bvh_states[i].vk.scratch.hploc_ranges_offset,
sizeof(uint32_t) * bvh_states[i].vk.internal_node_count, 0xffffffff);
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) {
device->cmd_fill_buffer_addr(commandBuffer, pInfos[i].scratchData.deviceAddress + states[i].scratch.hploc_ranges_offset,
sizeof(uint32_t) * states[i].internal_node_count, 0xffffffff);
}
}
vk_barrier_transfer_w_to_compute_r(commandBuffer);
@ -1331,31 +1312,31 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
vk_barrier_compute_w_to_compute_r(commandBuffer);
result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, bvh_states, infoCount, 0);
result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, states, infoCount, 0);
if (result != VK_SUCCESS) {
free(bvh_states);
free(states);
vk_command_buffer_set_error(cmd_buffer, result);
return;
}
vk_barrier_compute_w_to_compute_r(commandBuffer);
vk_build_stage(morton_sort, commandBuffer, device, meta, args, bvh_states, infoCount, 0);
vk_build_stage(morton_sort, commandBuffer, device, meta, args, states, infoCount, 0);
vk_barrier_compute_w_to_compute_r(commandBuffer);
if (batch_state.any_lbvh) {
result = vk_build_stage(lbvh_build_internal, commandBuffer, device, meta, args, bvh_states, infoCount,
result = vk_build_stage(lbvh_build_internal, commandBuffer, device, meta, args, states, infoCount,
VK_LBVH_BUILD_INTERNAL_FLAGS);
if (result != VK_SUCCESS) {
free(bvh_states);
free(states);
vk_command_buffer_set_error(cmd_buffer, result);
return;
}
}
if (batch_state.any_ploc) {
result = vk_build_stage(ploc_build_internal, commandBuffer, device, meta, args, bvh_states, infoCount,
result = vk_build_stage(ploc_build_internal, commandBuffer, device, meta, args, states, infoCount,
VK_PLOC_BUILD_INTERNAL_FLAGS);
if (result != VK_SUCCESS) {
vk_command_buffer_set_error(cmd_buffer, result);
@ -1364,7 +1345,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
}
if (batch_state.any_hploc) {
result = vk_build_stage(hploc_build_internal, commandBuffer, device, meta, args, bvh_states, infoCount,
result = vk_build_stage(hploc_build_internal, commandBuffer, device, meta, args, states, infoCount,
VK_HPLOC_BUILD_INTERNAL_FLAGS);
if (result != VK_SUCCESS) {
vk_command_buffer_set_error(cmd_buffer, result);
@ -1392,18 +1373,18 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
uint32_t encode_key = 0;
uint32_t update_key = 0;
for (uint32_t i = 0; i < infoCount; ++i) {
if (bvh_states[i].last_encode_pass == pass + 1)
if (states[i].last_encode_pass == pass + 1)
continue;
if (!progress) {
update = (bvh_states[i].vk.config.internal_type ==
update = (states[i].config.internal_type ==
VK_INTERNAL_BUILD_TYPE_UPDATE);
if (update && !ops->update_as[pass])
continue;
if (!update && !ops->encode_as[pass])
continue;
encode_key = bvh_states[i].vk.config.encode_key[pass];
update_key = bvh_states[i].vk.config.update_key[pass];
encode_key = states[i].config.encode_key[pass];
update_key = states[i].config.update_key[pass];
progress = true;
if (args->emit_markers) {
@ -1417,14 +1398,14 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
encode_marker.encode.key = update ? update_key : encode_key;
for (uint32_t j = 0; j < infoCount; j++) {
if (update != (bvh_states[j].vk.config.internal_type ==
if (update != (states[j].config.internal_type ==
VK_INTERNAL_BUILD_TYPE_UPDATE) ||
encode_key != bvh_states[j].vk.config.encode_key[pass] ||
update_key != bvh_states[j].vk.config.update_key[pass])
encode_key != states[j].config.encode_key[pass] ||
update_key != states[j].config.update_key[pass])
continue;
encode_marker.encode.leaf_node_count += bvh_states[j].vk.leaf_node_count;
encode_marker.encode.internal_node_count += bvh_states[i].vk.internal_node_count;
encode_marker.encode.leaf_node_count += states[j].leaf_node_count;
encode_marker.encode.internal_node_count += states[i].internal_node_count;
}
device->as_build_ops->begin_debug_marker(commandBuffer, &encode_marker);
@ -1433,26 +1414,26 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
}
if (update) {
ops->update_prepare[pass](commandBuffer, &bvh_states[i].vk,
flushed_cp_after_init_update_scratch,
flushed_compute_after_init_update_scratch);
ops->update_prepare[pass](commandBuffer, &states[i],
flushed_cp_after_init_update_scratch,
flushed_compute_after_init_update_scratch);
} else {
ops->encode_prepare[pass](commandBuffer, &bvh_states[i].vk);
ops->encode_prepare[pass](commandBuffer, &states[i]);
}
} else {
if (update != (bvh_states[i].vk.config.internal_type ==
if (update != (states[i].config.internal_type ==
VK_INTERNAL_BUILD_TYPE_UPDATE) ||
encode_key != bvh_states[i].vk.config.encode_key[pass] ||
update_key != bvh_states[i].vk.config.update_key[pass])
encode_key != states[i].config.encode_key[pass] ||
update_key != states[i].config.update_key[pass])
continue;
}
if (update)
ops->update_as[pass](commandBuffer, &bvh_states[i].vk);
ops->update_as[pass](commandBuffer, &states[i]);
else
ops->encode_as[pass](commandBuffer, &bvh_states[i].vk);
ops->encode_as[pass](commandBuffer, &states[i]);
bvh_states[i].last_encode_pass = pass + 1;
states[i].last_encode_pass = pass + 1;
}
} while (progress);
}
@ -1463,7 +1444,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
if (args->emit_markers)
device->as_build_ops->end_debug_marker(commandBuffer, &top_marker);
free(bvh_states);
free(states);
}
void

View file

@ -29,6 +29,8 @@
#include "vk_meta.h"
#include "vk_object.h"
#include "radix_sort/radix_sort_vk.h"
#include "radix_sort/common/vk/barrier.h"
#include "radix_sort/shaders/push.h"
#include "bvh/vk_bvh.h"
@ -131,6 +133,20 @@ struct vk_acceleration_structure_build_state {
uint32_t internal_node_count;
struct vk_scratch_layout scratch;
struct vk_build_config config;
/* Internal state of vk_acceleration_structure.c */
uint32_t build_flags;
uint32_t scratch_offset;
/* Radix sort state */
uint32_t scatter_blocks;
uint32_t count_ru_scatter;
uint32_t histo_blocks;
uint32_t count_ru_histo;
struct rs_push_scatter push_scatter;
uint32_t last_encode_pass;
};
struct vk_acceleration_structure_build_ops {