From 411e23d389688455dc61f3b53b560740ceda3c08 Mon Sep 17 00:00:00 2001 From: Konstantin Seurer Date: Sat, 7 Feb 2026 12:15:47 +0100 Subject: [PATCH] vulkan: Remove bvh_state This will allow passing the whole array to the driver. Reviewed-by: Natalie Vock Part-of: --- .../runtime/vk_acceleration_structure.c | 323 +++++++++--------- .../runtime/vk_acceleration_structure.h | 16 + 2 files changed, 168 insertions(+), 171 deletions(-) diff --git a/src/vulkan/runtime/vk_acceleration_structure.c b/src/vulkan/runtime/vk_acceleration_structure.c index 9c4dec01882..c4d6eefc6bd 100644 --- a/src/vulkan/runtime/vk_acceleration_structure.c +++ b/src/vulkan/runtime/vk_acceleration_structure.c @@ -36,9 +36,6 @@ #include "bvh/vk_build_interface.h" #include "bvh/vk_bvh.h" -#include "radix_sort/common/vk/barrier.h" -#include "radix_sort/shaders/push.h" - #include "util/u_string.h" #include "util/bitset.h" @@ -253,23 +250,6 @@ vk_acceleration_structure_build_state_init(struct vk_acceleration_structure_buil } } -struct bvh_state { - struct vk_acceleration_structure_build_state vk; - - uint32_t build_flags; - - uint32_t scratch_offset; - - /* Radix sort state */ - uint32_t scatter_blocks; - uint32_t count_ru_scatter; - uint32_t histo_blocks; - uint32_t count_ru_histo; - struct rs_push_scatter push_scatter; - - uint32_t last_encode_pass; -}; - struct bvh_batch_state { bool any_updateable; bool any_non_updateable; @@ -543,8 +523,8 @@ static VkResult build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags) + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags) { VkPipeline pipeline; VkPipelineLayout layout; @@ -588,36 +568,36 @@ build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device, commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; - if ((bvh_states[i].build_flags & VK_BUILD_LEAVES_FLAGS) != build_flags) + if ((states[i].build_flags & VK_BUILD_LEAVES_FLAGS) != build_flags) continue; - const VkAccelerationStructureBuildGeometryInfoKHR *build_info = bvh_states[i].vk.build_info; + const VkAccelerationStructureBuildGeometryInfoKHR *build_info = states[i].build_info; uint64_t scratch_addr = build_info->scratchData.deviceAddress; struct leaf_args leaf_consts = { - .bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset, - .header = scratch_addr + bvh_states[i].vk.scratch.header_offset, - .ids = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0], + .bvh = scratch_addr + states[i].scratch.ir_offset, + .header = scratch_addr + states[i].scratch.header_offset, + .ids = scratch_addr + states[i].scratch.sort_buffer_offset[0], }; for (unsigned j = 0; j < build_info->geometryCount; ++j) { const VkAccelerationStructureGeometryKHR *geom = build_info->pGeometries ? &build_info->pGeometries[j] : build_info->ppGeometries[j]; - const VkAccelerationStructureBuildRangeInfoKHR *build_range_info = &bvh_states[i].vk.build_range_infos[j]; + const VkAccelerationStructureBuildRangeInfoKHR *build_range_info = &states[i].build_range_infos[j]; if (build_range_info->primitiveCount == 0) continue; - leaf_consts.geom_data = vk_fill_geometry_data(build_info->type, bvh_states[i].vk.leaf_node_count, j, geom, build_range_info); + leaf_consts.geom_data = vk_fill_geometry_data(build_info->type, states[i].leaf_node_count, j, geom, build_range_info); disp->CmdPushConstants(commandBuffer, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(leaf_consts), &leaf_consts); device->cmd_dispatch_unaligned(commandBuffer, build_range_info->primitiveCount, 1, 1); - bvh_states[i].vk.leaf_node_count += build_range_info->primitiveCount; + states[i].leaf_node_count += build_range_info->primitiveCount; } } @@ -635,8 +615,8 @@ static VkResult morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags) + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags) { VkPipeline pipeline; VkPipelineLayout layout; @@ -665,19 +645,19 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device, commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct morton_args consts = { - .bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset, - .header = scratch_addr + bvh_states[i].vk.scratch.header_offset, - .ids = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0], + .bvh = scratch_addr + states[i].scratch.ir_offset, + .header = scratch_addr + states[i].scratch.header_offset, + .ids = scratch_addr + states[i].scratch.sort_buffer_offset[0], }; disp->CmdPushConstants(commandBuffer, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts); - device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].vk.leaf_node_count, 1, 1); + device->cmd_dispatch_unaligned(commandBuffer, states[i].leaf_node_count, 1, 1); } if (args->emit_markers) { @@ -694,8 +674,8 @@ static VkResult morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags) + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags) { const struct vk_device_dispatch_table *disp = &device->dispatch_table; @@ -732,10 +712,10 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.leaf_node_count) - bvh_states[i].scratch_offset = bvh_states[i].vk.scratch.sort_buffer_offset[passes & 1]; + if (states[i].leaf_node_count) + states[i].scratch_offset = states[i].scratch.sort_buffer_offset[passes & 1]; else - bvh_states[i].scratch_offset = bvh_states[i].vk.scratch.sort_buffer_offset[0]; + states[i].scratch_offset = states[i].scratch.sort_buffer_offset[0]; } /* @@ -766,26 +746,26 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, uint32_t pass_idx = (keyval_bytes - passes); for (uint32_t i = 0; i < build_count; ++i) { - if (!bvh_states[i].vk.leaf_node_count) + if (!states[i].leaf_node_count) continue; - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; - uint64_t keyvals_even_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0]; - uint64_t internal_addr = scratch_addr + bvh_states[i].vk.scratch.sort_internal_offset; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; + uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0]; + uint64_t internal_addr = scratch_addr + states[i].scratch.sort_internal_offset; - bvh_states[i].scatter_blocks = (bvh_states[i].vk.leaf_node_count + scatter_block_kvs - 1) / scatter_block_kvs; - bvh_states[i].count_ru_scatter = bvh_states[i].scatter_blocks * scatter_block_kvs; + states[i].scatter_blocks = (states[i].leaf_node_count + scatter_block_kvs - 1) / scatter_block_kvs; + states[i].count_ru_scatter = states[i].scatter_blocks * scatter_block_kvs; - bvh_states[i].histo_blocks = (bvh_states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; - bvh_states[i].count_ru_histo = bvh_states[i].histo_blocks * histo_block_kvs; + states[i].histo_blocks = (states[i].count_ru_scatter + histo_block_kvs - 1) / histo_block_kvs; + states[i].count_ru_histo = states[i].histo_blocks * histo_block_kvs; /* Fill with max values */ - if (bvh_states[i].count_ru_histo > bvh_states[i].vk.leaf_node_count) { + if (states[i].count_ru_histo > states[i].leaf_node_count) { device->cmd_fill_buffer_addr(commandBuffer, keyvals_even_addr + - bvh_states[i].vk.leaf_node_count * keyval_bytes, - (bvh_states[i].count_ru_histo - bvh_states[i].vk.leaf_node_count) * keyval_bytes, + states[i].leaf_node_count * keyval_bytes, + (states[i].count_ru_histo - states[i].leaf_node_count) * keyval_bytes, 0xFFFFFFFF); } @@ -799,7 +779,7 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, * Note that the last workgroup doesn't read/write a partition so it doesn't * need to be initialized. */ - uint32_t histo_partition_count = passes + bvh_states[i].scatter_blocks - 1; + uint32_t histo_partition_count = passes + states[i].scatter_blocks - 1; uint32_t fill_base = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); @@ -821,14 +801,14 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, rs->pipelines.named.histogram); for (uint32_t i = 0; i < build_count; ++i) { - if (!bvh_states[i].vk.leaf_node_count) + if (!states[i].leaf_node_count) continue; - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; - uint64_t keyvals_even_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0]; - uint64_t internal_addr = scratch_addr + bvh_states[i].vk.scratch.sort_internal_offset; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; + uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0]; + uint64_t internal_addr = scratch_addr + states[i].scratch.sort_internal_offset; /* Dispatch histogram */ struct rs_push_histogram push_histogram = { @@ -840,7 +820,7 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, disp->CmdPushConstants(commandBuffer, rs->pipeline_layouts.named.histogram, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(push_histogram), &push_histogram); - disp->CmdDispatch(commandBuffer, bvh_states[i].histo_blocks, 1, 1); + disp->CmdDispatch(commandBuffer, states[i].histo_blocks, 1, 1); } /* @@ -854,13 +834,13 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, rs->pipelines.named.prefix); for (uint32_t i = 0; i < build_count; ++i) { - if (!bvh_states[i].vk.leaf_node_count) + if (!states[i].leaf_node_count) continue; - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; - uint64_t internal_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress + - bvh_states[i].vk.scratch.sort_internal_offset; + uint64_t internal_addr = states[i].build_info->scratchData.deviceAddress + + states[i].scratch.sort_internal_offset; struct rs_push_prefix push_prefix = { .devaddr_histograms = internal_addr + rs->internal.histograms.offset, @@ -878,12 +858,12 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, uint32_t histogram_offset = pass_idx * (RS_RADIX_SIZE * sizeof(uint32_t)); for (uint32_t i = 0; i < build_count; i++) { - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; - uint64_t keyvals_even_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[0]; - uint64_t keyvals_odd_addr = scratch_addr + bvh_states[i].vk.scratch.sort_buffer_offset[1]; - uint64_t internal_addr = scratch_addr + bvh_states[i].vk.scratch.sort_internal_offset; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; + uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0]; + uint64_t keyvals_odd_addr = scratch_addr + states[i].scratch.sort_buffer_offset[1]; + uint64_t internal_addr = scratch_addr + states[i].scratch.sort_internal_offset; - bvh_states[i].push_scatter = (struct rs_push_scatter){ + states[i].push_scatter = (struct rs_push_scatter){ .devaddr_keyvals_even = keyvals_even_addr, .devaddr_keyvals_odd = keyvals_odd_addr, .devaddr_partitions = internal_addr + rs->internal.partitions.offset, @@ -906,19 +886,19 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, : rs->pipeline_layouts.named.scatter[pass_dword].odd; for (uint32_t i = 0; i < build_count; i++) { - if (!bvh_states[i].vk.leaf_node_count) + if (!states[i].leaf_node_count) continue; - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; - bvh_states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; + states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; disp->CmdPushConstants(commandBuffer, pl, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(struct rs_push_scatter), - &bvh_states[i].push_scatter); + &states[i].push_scatter); - disp->CmdDispatch(commandBuffer, bvh_states[i].scatter_blocks, 1, 1); + disp->CmdDispatch(commandBuffer, states[i].scatter_blocks, 1, 1); - bvh_states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); + states[i].push_scatter.devaddr_histograms += (RS_RADIX_SIZE * sizeof(uint32_t)); } /* Continue? */ @@ -946,8 +926,8 @@ static VkResult lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags) + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags) { VkPipeline pipeline; VkPipelineLayout layout; @@ -976,23 +956,23 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH) + if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH) continue; - uint32_t src_scratch_offset = bvh_states[i].scratch_offset; + uint32_t src_scratch_offset = states[i].scratch_offset; - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct lbvh_main_args consts = { - .bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset, + .bvh = scratch_addr + states[i].scratch.ir_offset, .src_ids = scratch_addr + src_scratch_offset, - .node_info = scratch_addr + bvh_states[i].vk.scratch.lbvh_node_offset, - .header = scratch_addr + bvh_states[i].vk.scratch.header_offset, - .internal_node_base = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset, + .node_info = scratch_addr + states[i].scratch.lbvh_node_offset, + .header = scratch_addr + states[i].scratch.header_offset, + .internal_node_base = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset, }; disp->CmdPushConstants(commandBuffer, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts); - device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].vk.internal_node_count, 1, 1); + device->cmd_dispatch_unaligned(commandBuffer, states[i].internal_node_count, 1, 1); } vk_barrier_compute_w_to_compute_r(commandBuffer); @@ -1012,20 +992,20 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH) + if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH) continue; - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct lbvh_generate_ir_args consts = { - .bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset, - .node_info = scratch_addr + bvh_states[i].vk.scratch.lbvh_node_offset, - .header = scratch_addr + bvh_states[i].vk.scratch.header_offset, - .internal_node_base = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset, + .bvh = scratch_addr + states[i].scratch.ir_offset, + .node_info = scratch_addr + states[i].scratch.lbvh_node_offset, + .header = scratch_addr + states[i].scratch.header_offset, + .internal_node_base = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset, }; disp->CmdPushConstants(commandBuffer, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts); - device->cmd_dispatch_unaligned(commandBuffer, bvh_states[i].vk.internal_node_count, 1, 1); + device->cmd_dispatch_unaligned(commandBuffer, states[i].internal_node_count, 1, 1); } if (args->emit_markers) { @@ -1044,8 +1024,8 @@ static VkResult ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags) + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags) { VkPipeline pipeline; VkPipelineLayout layout; @@ -1073,27 +1053,27 @@ ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_PLOC) + if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_PLOC) continue; - uint32_t src_scratch_offset = bvh_states[i].scratch_offset; - uint32_t dst_scratch_offset = (src_scratch_offset == bvh_states[i].vk.scratch.sort_buffer_offset[0]) - ? bvh_states[i].vk.scratch.sort_buffer_offset[1] - : bvh_states[i].vk.scratch.sort_buffer_offset[0]; + uint32_t src_scratch_offset = states[i].scratch_offset; + uint32_t dst_scratch_offset = (src_scratch_offset == states[i].scratch.sort_buffer_offset[0]) + ? states[i].scratch.sort_buffer_offset[1] + : states[i].scratch.sort_buffer_offset[0]; - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct ploc_args consts = { - .bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset, - .header = scratch_addr + bvh_states[i].vk.scratch.header_offset, + .bvh = scratch_addr + states[i].scratch.ir_offset, + .header = scratch_addr + states[i].scratch.header_offset, .ids_0 = scratch_addr + src_scratch_offset, .ids_1 = scratch_addr + dst_scratch_offset, - .prefix_scan_partitions = scratch_addr + bvh_states[i].vk.scratch.ploc_prefix_sum_partition_offset, - .internal_node_offset = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset, + .prefix_scan_partitions = scratch_addr + states[i].scratch.ploc_prefix_sum_partition_offset, + .internal_node_offset = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset, }; disp->CmdPushConstants(commandBuffer, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts); - disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(bvh_states[i].vk.leaf_node_count, PLOC_WORKGROUP_SIZE), 1), 1, 1); + disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(states[i].leaf_node_count, PLOC_WORKGROUP_SIZE), 1), 1, 1); } if (args->emit_markers) { @@ -1112,8 +1092,8 @@ static VkResult hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags) + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags) { VkPipeline pipeline; VkPipelineLayout layout; @@ -1145,23 +1125,23 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline); for (uint32_t i = 0; i < build_count; ++i) { - if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_HPLOC) + if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_HPLOC) continue; assert(args->subgroup_size <= 64); - uint64_t scratch_addr = bvh_states[i].vk.build_info->scratchData.deviceAddress; + uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct hploc_args consts = { - .header = scratch_addr + bvh_states[i].vk.scratch.header_offset, - .bvh = scratch_addr + bvh_states[i].vk.scratch.ir_offset, - .ranges = scratch_addr + bvh_states[i].vk.scratch.hploc_ranges_offset, - .ids = scratch_addr + bvh_states[i].scratch_offset, - .internal_node_base = bvh_states[i].vk.scratch.internal_node_offset - bvh_states[i].vk.scratch.ir_offset, + .header = scratch_addr + states[i].scratch.header_offset, + .bvh = scratch_addr + states[i].scratch.ir_offset, + .ranges = scratch_addr + states[i].scratch.hploc_ranges_offset, + .ids = scratch_addr + states[i].scratch_offset, + .internal_node_base = states[i].scratch.internal_node_offset - states[i].scratch.ir_offset, }; disp->CmdPushConstants(commandBuffer, layout, VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts); - disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(bvh_states[i].vk.leaf_node_count, args->subgroup_size), 1), 1, 1); + disp->CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(states[i].leaf_node_count, args->subgroup_size), 1), 1, 1); } if (args->emit_markers) { @@ -1177,22 +1157,23 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, typedef VkResult (*vk_build_stage_cb)(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, - uint32_t build_flags); + struct vk_acceleration_structure_build_state *states, + uint32_t build_count, uint32_t build_flags); static VkResult vk_build_stage(vk_build_stage_cb cb, VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, const struct vk_acceleration_structure_build_args *args, - struct bvh_state *bvh_states, uint32_t build_count, uint32_t build_flags_mask) + struct vk_acceleration_structure_build_state *states, uint32_t build_count, + uint32_t build_flags_mask) { BITSET_DECLARE(flag_combinations, 1u << VK_BUILD_FLAG_COUNT); BITSET_ZERO(flag_combinations); for (uint32_t i = 0; i < build_count; i++) - BITSET_SET(flag_combinations, bvh_states[i].build_flags & build_flags_mask); + BITSET_SET(flag_combinations, states[i].build_flags & build_flags_mask); uint32_t build_flags; BITSET_FOREACH_SET(build_flags, flag_combinations, 1u << VK_BUILD_FLAG_COUNT) { - VkResult result = cb(commandBuffer, device, meta, args, bvh_states, build_count, build_flags); + VkResult result = cb(commandBuffer, device, meta, args, states, build_count, build_flags); if (result != VK_SUCCESS) return result; } @@ -1214,7 +1195,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, struct bvh_batch_state batch_state = {0}; - struct bvh_state *bvh_states = calloc(infoCount, sizeof(struct bvh_state)); + struct vk_acceleration_structure_build_state *states = calloc(infoCount, sizeof(struct vk_acceleration_structure_build_state)); struct vk_acceleration_structure_build_marker top_marker = { .step = VK_ACCELERATION_STRUCTURE_BUILD_STEP_TOP, @@ -1241,38 +1222,38 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, leaf_node_count += ppBuildRangeInfos[i][j].primitiveCount; } - vk_acceleration_structure_build_state_init(&bvh_states[i].vk, cmd_buffer->base.device, leaf_node_count, + vk_acceleration_structure_build_state_init(&states[i], cmd_buffer->base.device, leaf_node_count, pInfos + i, args); - bvh_states[i].vk.build_range_infos = ppBuildRangeInfos[i]; + states[i].build_range_infos = ppBuildRangeInfos[i]; /* The leaf node dispatch code uses leaf_node_count as a base index. */ - bvh_states[i].vk.leaf_node_count = 0; + states[i].leaf_node_count = 0; - if (bvh_states[i].vk.config.updateable) + if (states[i].config.updateable) batch_state.any_updateable = true; else batch_state.any_non_updateable = true; - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_PLOC) { + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_PLOC) { batch_state.any_ploc = true; - } else if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) { + } else if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) { batch_state.any_hploc = true; - } else if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_LBVH) { + } else if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_LBVH) { batch_state.any_lbvh = true; - } else if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) { + } else if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) { batch_state.any_update = true; /* For updates, the leaf node pass never runs, so set leaf_node_count here. */ - bvh_states[i].vk.leaf_node_count = leaf_node_count; + states[i].leaf_node_count = leaf_node_count; } else { UNREACHABLE("Unknown internal_build_type"); } - if (bvh_states[i].vk.config.updateable) - bvh_states[i].build_flags |= VK_BUILD_FLAG_ALWAYS_ACTIVE; + if (states[i].config.updateable) + states[i].build_flags |= VK_BUILD_FLAG_ALWAYS_ACTIVE; if (args->propagate_cull_flags) - bvh_states[i].build_flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS; + states[i].build_flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS; - if (bvh_states[i].vk.config.internal_type != VK_INTERNAL_BUILD_TYPE_UPDATE) { + if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_UPDATE) { /* The internal node count is updated in lbvh_build_internal for LBVH * and from the PLOC shader for PLOC. */ struct vk_ir_header header = { @@ -1288,10 +1269,10 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, }, }; - device->write_buffer_cp(commandBuffer, pInfos[i].scratchData.deviceAddress + bvh_states[i].vk.scratch.header_offset, + device->write_buffer_cp(commandBuffer, pInfos[i].scratchData.deviceAddress + states[i].scratch.header_offset, &header, sizeof(header)); } else { - ops->init_update_scratch(commandBuffer, &bvh_states[i].vk); + ops->init_update_scratch(commandBuffer, &states[i]); } } @@ -1311,19 +1292,19 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, }, 0, NULL, 0, NULL); if (batch_state.any_lbvh || batch_state.any_ploc || batch_state.any_hploc) { - VkResult result = vk_build_stage(build_leaves, commandBuffer, device, meta, args, bvh_states, infoCount, + VkResult result = vk_build_stage(build_leaves, commandBuffer, device, meta, args, states, infoCount, VK_BUILD_LEAVES_FLAGS); if (result != VK_SUCCESS) { - free(bvh_states); + free(states); vk_command_buffer_set_error(cmd_buffer, result); return; } if (batch_state.any_hploc) { for (uint32_t i = 0; i < infoCount; ++i) { - if (bvh_states[i].vk.config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) { - device->cmd_fill_buffer_addr(commandBuffer, pInfos[i].scratchData.deviceAddress + bvh_states[i].vk.scratch.hploc_ranges_offset, - sizeof(uint32_t) * bvh_states[i].vk.internal_node_count, 0xffffffff); + if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_HPLOC) { + device->cmd_fill_buffer_addr(commandBuffer, pInfos[i].scratchData.deviceAddress + states[i].scratch.hploc_ranges_offset, + sizeof(uint32_t) * states[i].internal_node_count, 0xffffffff); } } vk_barrier_transfer_w_to_compute_r(commandBuffer); @@ -1331,31 +1312,31 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, vk_barrier_compute_w_to_compute_r(commandBuffer); - result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, bvh_states, infoCount, 0); + result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, states, infoCount, 0); if (result != VK_SUCCESS) { - free(bvh_states); + free(states); vk_command_buffer_set_error(cmd_buffer, result); return; } vk_barrier_compute_w_to_compute_r(commandBuffer); - vk_build_stage(morton_sort, commandBuffer, device, meta, args, bvh_states, infoCount, 0); + vk_build_stage(morton_sort, commandBuffer, device, meta, args, states, infoCount, 0); vk_barrier_compute_w_to_compute_r(commandBuffer); if (batch_state.any_lbvh) { - result = vk_build_stage(lbvh_build_internal, commandBuffer, device, meta, args, bvh_states, infoCount, + result = vk_build_stage(lbvh_build_internal, commandBuffer, device, meta, args, states, infoCount, VK_LBVH_BUILD_INTERNAL_FLAGS); if (result != VK_SUCCESS) { - free(bvh_states); + free(states); vk_command_buffer_set_error(cmd_buffer, result); return; } } if (batch_state.any_ploc) { - result = vk_build_stage(ploc_build_internal, commandBuffer, device, meta, args, bvh_states, infoCount, + result = vk_build_stage(ploc_build_internal, commandBuffer, device, meta, args, states, infoCount, VK_PLOC_BUILD_INTERNAL_FLAGS); if (result != VK_SUCCESS) { vk_command_buffer_set_error(cmd_buffer, result); @@ -1364,7 +1345,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, } if (batch_state.any_hploc) { - result = vk_build_stage(hploc_build_internal, commandBuffer, device, meta, args, bvh_states, infoCount, + result = vk_build_stage(hploc_build_internal, commandBuffer, device, meta, args, states, infoCount, VK_HPLOC_BUILD_INTERNAL_FLAGS); if (result != VK_SUCCESS) { vk_command_buffer_set_error(cmd_buffer, result); @@ -1392,18 +1373,18 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, uint32_t encode_key = 0; uint32_t update_key = 0; for (uint32_t i = 0; i < infoCount; ++i) { - if (bvh_states[i].last_encode_pass == pass + 1) + if (states[i].last_encode_pass == pass + 1) continue; if (!progress) { - update = (bvh_states[i].vk.config.internal_type == + update = (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE); if (update && !ops->update_as[pass]) continue; if (!update && !ops->encode_as[pass]) continue; - encode_key = bvh_states[i].vk.config.encode_key[pass]; - update_key = bvh_states[i].vk.config.update_key[pass]; + encode_key = states[i].config.encode_key[pass]; + update_key = states[i].config.update_key[pass]; progress = true; if (args->emit_markers) { @@ -1417,14 +1398,14 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, encode_marker.encode.key = update ? update_key : encode_key; for (uint32_t j = 0; j < infoCount; j++) { - if (update != (bvh_states[j].vk.config.internal_type == + if (update != (states[j].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) || - encode_key != bvh_states[j].vk.config.encode_key[pass] || - update_key != bvh_states[j].vk.config.update_key[pass]) + encode_key != states[j].config.encode_key[pass] || + update_key != states[j].config.update_key[pass]) continue; - encode_marker.encode.leaf_node_count += bvh_states[j].vk.leaf_node_count; - encode_marker.encode.internal_node_count += bvh_states[i].vk.internal_node_count; + encode_marker.encode.leaf_node_count += states[j].leaf_node_count; + encode_marker.encode.internal_node_count += states[i].internal_node_count; } device->as_build_ops->begin_debug_marker(commandBuffer, &encode_marker); @@ -1433,26 +1414,26 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, } if (update) { - ops->update_prepare[pass](commandBuffer, &bvh_states[i].vk, - flushed_cp_after_init_update_scratch, - flushed_compute_after_init_update_scratch); + ops->update_prepare[pass](commandBuffer, &states[i], + flushed_cp_after_init_update_scratch, + flushed_compute_after_init_update_scratch); } else { - ops->encode_prepare[pass](commandBuffer, &bvh_states[i].vk); + ops->encode_prepare[pass](commandBuffer, &states[i]); } } else { - if (update != (bvh_states[i].vk.config.internal_type == + if (update != (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) || - encode_key != bvh_states[i].vk.config.encode_key[pass] || - update_key != bvh_states[i].vk.config.update_key[pass]) + encode_key != states[i].config.encode_key[pass] || + update_key != states[i].config.update_key[pass]) continue; } if (update) - ops->update_as[pass](commandBuffer, &bvh_states[i].vk); + ops->update_as[pass](commandBuffer, &states[i]); else - ops->encode_as[pass](commandBuffer, &bvh_states[i].vk); + ops->encode_as[pass](commandBuffer, &states[i]); - bvh_states[i].last_encode_pass = pass + 1; + states[i].last_encode_pass = pass + 1; } } while (progress); } @@ -1463,7 +1444,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, if (args->emit_markers) device->as_build_ops->end_debug_marker(commandBuffer, &top_marker); - free(bvh_states); + free(states); } void diff --git a/src/vulkan/runtime/vk_acceleration_structure.h b/src/vulkan/runtime/vk_acceleration_structure.h index ba6f713d687..a10540fc5d6 100644 --- a/src/vulkan/runtime/vk_acceleration_structure.h +++ b/src/vulkan/runtime/vk_acceleration_structure.h @@ -29,6 +29,8 @@ #include "vk_meta.h" #include "vk_object.h" #include "radix_sort/radix_sort_vk.h" +#include "radix_sort/common/vk/barrier.h" +#include "radix_sort/shaders/push.h" #include "bvh/vk_bvh.h" @@ -131,6 +133,20 @@ struct vk_acceleration_structure_build_state { uint32_t internal_node_count; struct vk_scratch_layout scratch; struct vk_build_config config; + + /* Internal state of vk_acceleration_structure.c */ + uint32_t build_flags; + + uint32_t scratch_offset; + + /* Radix sort state */ + uint32_t scatter_blocks; + uint32_t count_ru_scatter; + uint32_t histo_blocks; + uint32_t count_ru_histo; + struct rs_push_scatter push_scatter; + + uint32_t last_encode_pass; }; struct vk_acceleration_structure_build_ops {