diff --git a/src/amd/vulkan/radv_acceleration_structure.c b/src/amd/vulkan/radv_acceleration_structure.c
index 73c36808e3b..b6449a19bdd 100644
--- a/src/amd/vulkan/radv_acceleration_structure.c
+++ b/src/amd/vulkan/radv_acceleration_structure.c
@@ -61,6 +61,11 @@ static const uint32_t convert_internal_spv[] = {
 
 #define KEY_ID_PAIR_SIZE 8
 
+enum radv_accel_struct_build_type {
+   RADV_ACCEL_STRUCT_BUILD_TYPE_LBVH,
+   RADV_ACCEL_STRUCT_BUILD_TYPE_PLOC,
+};
+
 struct acceleration_structure_layout {
    uint32_t bvh_offset;
    uint32_t size;
@@ -74,9 +79,27 @@ struct scratch_layout {
    uint32_t sort_buffer_offset[2];
    uint32_t sort_internal_offset;
 
+   uint32_t ploc_prefix_sum_partition_offset;
+
    uint32_t ir_offset;
 };
 
+static enum radv_accel_struct_build_type
+build_type(uint32_t leaf_count, const VkAccelerationStructureBuildGeometryInfoKHR *build_info)
+{
+   if (leaf_count <= 4)
+      return RADV_ACCEL_STRUCT_BUILD_TYPE_LBVH;
+
+   if (build_info->type == VK_ACCELERATION_STRUCTURE_TYPE_TOP_LEVEL_KHR)
+      return RADV_ACCEL_STRUCT_BUILD_TYPE_LBVH;
+
+   if (!(build_info->flags & VK_BUILD_ACCELERATION_STRUCTURE_PREFER_FAST_BUILD_BIT_KHR) &&
+       !(build_info->flags & VK_BUILD_ACCELERATION_STRUCTURE_ALLOW_UPDATE_BIT_KHR))
+      return RADV_ACCEL_STRUCT_BUILD_TYPE_PLOC;
+   else
+      return RADV_ACCEL_STRUCT_BUILD_TYPE_LBVH;
+}
+
 static void
 get_build_layout(struct radv_device *device, uint32_t leaf_count,
                  const VkAccelerationStructureBuildGeometryInfoKHR *build_info,
@@ -145,6 +168,12 @@ get_build_layout(struct radv_device *device, uint32_t leaf_count,
 
    uint32_t offset = 0;
 
+   uint32_t ploc_scratch_space = 0;
+
+   if (build_type(leaf_count, build_info) == RADV_ACCEL_STRUCT_BUILD_TYPE_PLOC)
+      ploc_scratch_space = DIV_ROUND_UP(leaf_count, PLOC_WORKGROUP_SIZE) *
+                           sizeof(struct ploc_prefix_scan_partition);
+
    scratch->header_offset = offset;
    offset += sizeof(struct radv_ir_header);
 
@@ -155,7 +184,10 @@ get_build_layout(struct radv_device *device, uint32_t leaf_count,
    offset += requirements.keyvals_size;
 
    scratch->sort_internal_offset = offset;
-   offset += requirements.internal_size;
+   /* The internal sort data is no longer needed once PLOC runs,
+    * so save space by aliasing it with the PLOC prefix scan partitions. */
+   scratch->ploc_prefix_sum_partition_offset = offset;
+   offset += MAX2(requirements.internal_size, ploc_scratch_space);
 
    scratch->ir_offset = offset;
    offset += ir_leaf_size * leaf_count;
@@ -455,6 +487,7 @@ struct bvh_state {
 
    struct acceleration_structure_layout accel_struct;
    struct scratch_layout scratch;
+   enum radv_accel_struct_build_type type;
 };
 
 static void
@@ -629,6 +662,9 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, uint32_t infoCount,
    for (unsigned iter = 0; progress; ++iter) {
      progress = false;
      for (uint32_t i = 0; i < infoCount; ++i) {
+         if (bvh_states[i].type != RADV_ACCEL_STRUCT_BUILD_TYPE_LBVH)
+            continue;
+
         if (iter && bvh_states[i].node_count == 1)
            continue;
 
@@ -667,14 +703,62 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, uint32_t infoCount,
   }
 
   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (bvh_states[i].type != RADV_ACCEL_STRUCT_BUILD_TYPE_LBVH)
+         continue;
+
      radv_update_buffer_cp(cmd_buffer,
                            pInfos[i].scratchData.deviceAddress +
                               bvh_states[i].scratch.header_offset +
                               offsetof(struct radv_ir_header, ir_internal_node_count),
                            &bvh_states[i].internal_node_count, 4);
   }
+}
 
-   cmd_buffer->state.flush_bits |= flush_bits;
+static void
+ploc_build_internal(VkCommandBuffer commandBuffer, uint32_t infoCount,
+                    const VkAccelerationStructureBuildGeometryInfoKHR *pInfos,
+                    struct bvh_state *bvh_states)
+{
+   RADV_FROM_HANDLE(radv_cmd_buffer, cmd_buffer, commandBuffer);
+
+   radv_CmdBindPipeline(commandBuffer, VK_PIPELINE_BIND_POINT_COMPUTE,
+                        cmd_buffer->device->meta_state.accel_struct_build.ploc_pipeline);
+
+   for (uint32_t i = 0; i < infoCount; ++i) {
+      if (bvh_states[i].type != RADV_ACCEL_STRUCT_BUILD_TYPE_PLOC)
+         continue;
+
+      struct radv_global_sync_data initial_sync_data = {
+         .current_phase_end_counter = DIV_ROUND_UP(bvh_states[i].node_count, PLOC_WORKGROUP_SIZE),
+         /* Will be updated by the first PLOC shader invocation */
+         .task_counts = {TASK_INDEX_INVALID, TASK_INDEX_INVALID},
+      };
+      radv_update_buffer_cp(cmd_buffer,
+                            pInfos[i].scratchData.deviceAddress +
+                               bvh_states[i].scratch.header_offset +
+                               offsetof(struct radv_ir_header, sync_data),
+                            &initial_sync_data, sizeof(struct radv_global_sync_data));
+      uint32_t src_scratch_offset = bvh_states[i].scratch_offset;
+      uint32_t dst_scratch_offset =
+         (src_scratch_offset == bvh_states[i].scratch.sort_buffer_offset[0])
+            ? bvh_states[i].scratch.sort_buffer_offset[1]
+            : bvh_states[i].scratch.sort_buffer_offset[0];
+
+      const struct ploc_args consts = {
+         .bvh = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.ir_offset,
+         .header = pInfos[i].scratchData.deviceAddress + bvh_states[i].scratch.header_offset,
+         .ids_0 = pInfos[i].scratchData.deviceAddress + src_scratch_offset,
+         .ids_1 = pInfos[i].scratchData.deviceAddress + dst_scratch_offset,
+         .prefix_scan_partitions = pInfos[i].scratchData.deviceAddress +
+                                   bvh_states[i].scratch.ploc_prefix_sum_partition_offset,
+         .internal_node_offset = bvh_states[i].node_offset,
+      };
+
+      radv_CmdPushConstants(commandBuffer,
+                            cmd_buffer->device->meta_state.accel_struct_build.ploc_p_layout,
+                            VK_SHADER_STAGE_COMPUTE_BIT, 0, sizeof(consts), &consts);
+      radv_CmdDispatch(commandBuffer, MAX2(DIV_ROUND_UP(bvh_states[i].node_count, PLOC_WORKGROUP_SIZE), 1), 1, 1);
+   }
 }
 
 static void
@@ -776,6 +860,7 @@ radv_CmdBuildAccelerationStructuresKHR(
 
      get_build_layout(cmd_buffer->device, leaf_node_count, pInfos + i, &bvh_states[i].accel_struct,
                       &bvh_states[i].scratch);
+      bvh_states[i].type = build_type(leaf_node_count, pInfos + i);
 
      /* The internal node count is updated in lbvh_build_internal for LBVH
       * and from the PLOC shader for PLOC. */
@@ -799,7 +884,12 @@ radv_CmdBuildAccelerationStructuresKHR(
 
   morton_sort(commandBuffer, infoCount, pInfos, bvh_states, flush_bits);
 
+   cmd_buffer->state.flush_bits |= flush_bits;
+
   lbvh_build_internal(commandBuffer, infoCount, pInfos, bvh_states, flush_bits);
+   ploc_build_internal(commandBuffer, infoCount, pInfos, bvh_states);
+
+   cmd_buffer->state.flush_bits |= flush_bits;
 
   convert_leaf_nodes(commandBuffer, infoCount, pInfos, bvh_states);
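For reference, a minimal standalone sketch of the scratch aliasing that get_build_layout sets up above. The PLOC_WORKGROUP_SIZE value and the ploc_prefix_scan_partition layout below are placeholder assumptions (the real definitions live in RADV's BVH build headers and shaders); only the sizing logic mirrors the patch.

#include <stdint.h>
#include <stdio.h>

#define DIV_ROUND_UP(n, d) (((n) + (d) - 1) / (d))
#define MAX2(a, b)         ((a) > (b) ? (a) : (b))

/* Assumed value and layout, for illustration only. */
#define PLOC_WORKGROUP_SIZE 1024

struct ploc_prefix_scan_partition {
   uint32_t aggregate;
   uint32_t inclusive_sum;
};

/* Size of the region shared by the sort's internal data and the PLOC
 * prefix-scan partitions: the two are never live at the same time, so the
 * region only needs to be as large as the bigger of the two users. */
static uint32_t
aliased_scratch_size(uint32_t leaf_count, uint32_t sort_internal_size, int use_ploc)
{
   uint32_t ploc_scratch_space = 0;
   if (use_ploc)
      ploc_scratch_space = DIV_ROUND_UP(leaf_count, PLOC_WORKGROUP_SIZE) *
                           sizeof(struct ploc_prefix_scan_partition);
   return MAX2(sort_internal_size, ploc_scratch_space);
}

int
main(void)
{
   /* e.g. 100000 leaves and 4 KiB of sort-internal data */
   printf("aliased region: %u bytes\n", aliased_scratch_size(100000, 4096, 1));
   return 0;
}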
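Similarly, a small sketch of how ploc_build_internal picks its two key/ID buffers: the radix sort ping-pongs between the two sort buffers and leaves its result in one of them, so PLOC reads from that one and reuses the other as its output. The field names here are illustrative stand-ins for the bvh_state/scratch_layout members used in the patch.

#include <stdint.h>
#include <stdio.h>

/* Illustrative stand-ins for the relevant bvh_state/scratch_layout fields. */
struct sort_buffers {
   uint32_t sort_buffer_offset[2]; /* the two buffers the sort ping-pongs between */
   uint32_t result_offset;         /* buffer the sort finished in (sorted key/ID pairs) */
};

/* PLOC reads sorted IDs from result_offset and is free to clobber the other
 * sort buffer, so that one becomes the destination. */
static uint32_t
ploc_dst_offset(const struct sort_buffers *s)
{
   return s->result_offset == s->sort_buffer_offset[0] ? s->sort_buffer_offset[1]
                                                       : s->sort_buffer_offset[0];
}

int
main(void)
{
   struct sort_buffers s = {.sort_buffer_offset = {0x1000, 0x2000}, .result_offset = 0x2000};
   printf("src = 0x%x, dst = 0x%x\n", s.result_offset, ploc_dst_offset(&s));
   return 0;
}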