diff --git a/src/vulkan/runtime/bvh/hploc_internal.comp b/src/vulkan/runtime/bvh/hploc_internal.comp index 98a554e9181..abdef8347b7 100644 --- a/src/vulkan/runtime/bvh/hploc_internal.comp +++ b/src/vulkan/runtime/bvh/hploc_internal.comp @@ -39,10 +39,20 @@ delta(uint32_t index) uint32_t left_index = index; uint32_t right_index = index + 1; - uint32_t left_key = DEREF(INDEX(key32_id_pair, args.ids, left_index)).key; - uint32_t right_key = DEREF(INDEX(key32_id_pair, args.ids, right_index)).key; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + uint64_t left_key = packUint2x32(u32vec2( + DEREF(INDEX(key64_id_pair, args.ids, left_index)).key_lo, + DEREF(INDEX(key64_id_pair, args.ids, left_index)).key_hi)); + uint64_t right_key = packUint2x32(u32vec2( + DEREF(INDEX(key64_id_pair, args.ids, right_index)).key_lo, + DEREF(INDEX(key64_id_pair, args.ids, right_index)).key_hi)); + return left_key != right_key ? (32 + uint32_t(findMSB(left_key ^ right_key))) : uint32_t(findMSB(left_index ^ right_index)); + } else { + uint32_t left_key = DEREF(INDEX(key32_id_pair, args.ids, left_index)).key; + uint32_t right_key = DEREF(INDEX(key32_id_pair, args.ids, right_index)).key; - return left_key != right_key ? (32 + findMSB(left_key ^ right_key)) : findMSB(left_index ^ right_index); + return left_key != right_key ? 
(32 + findMSB(left_key ^ right_key)) : findMSB(left_index ^ right_index); + } } #define SEARCH_RADIUS 16 @@ -67,8 +77,13 @@ main(void) uint32_t child_id = VK_BVH_INVALID_NODE; vk_ir_node child = vk_ir_node(vk_aabb(vec3(0.0), vec3(0.0))); if (active_leaf_count > 0) { - REF(key32_id_pair) key_id = INDEX(key32_id_pair, args.ids, global_id); - child_id = DEREF(key_id).id; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + REF(key64_id_pair) key_id = INDEX(key64_id_pair, args.ids, global_id); + child_id = DEREF(key_id).id; + } else { + REF(key32_id_pair) key_id = INDEX(key32_id_pair, args.ids, global_id); + child_id = DEREF(key_id).id; + } child = DEREF(REF(vk_ir_node)(OFFSET(args.bvh, ir_id_to_offset(child_id)))); } @@ -142,8 +157,13 @@ main(void) } uint32_t node_id_index = load_base + load_index; uint32_t node_id = VK_BVH_INVALID_NODE; - if (load_index < index_range) - node_id = DEREF(INDEX(key32_id_pair, args.ids, node_id_index)).id; + if (load_index < index_range) { + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + node_id = DEREF(INDEX(key64_id_pair, args.ids, node_id_index)).id; + } else { + node_id = DEREF(INDEX(key32_id_pair, args.ids, node_id_index)).id; + } + } uvec4 node_valid_mask = subgroupBallot(node_id != VK_BVH_INVALID_NODE); uint32_t node_prefix_sum = subgroupBallotExclusiveBitCount(node_valid_mask); @@ -225,7 +245,11 @@ main(void) if (gl_SubgroupInvocationID < min(end - start + 1, cluster_threshold)) { uint32_t dst_node = gl_SubgroupInvocationID < node_count ? 
node_ids[gl_SubgroupInvocationID] : VK_BVH_INVALID_NODE; - DEREF(INDEX(key32_id_pair, args.ids, start + gl_SubgroupInvocationID)).id = dst_node; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + DEREF(INDEX(key64_id_pair, args.ids, start + gl_SubgroupInvocationID)).id = dst_node; + } else { + DEREF(INDEX(key32_id_pair, args.ids, start + gl_SubgroupInvocationID)).id = dst_node; + } } memoryBarrier(gl_ScopeDevice, gl_StorageSemanticsBuffer, diff --git a/src/vulkan/runtime/bvh/lbvh_main.comp b/src/vulkan/runtime/bvh/lbvh_main.comp index fa021c49f47..2c3ad67b0c3 100644 --- a/src/vulkan/runtime/bvh/lbvh_main.comp +++ b/src/vulkan/runtime/bvh/lbvh_main.comp @@ -35,21 +35,37 @@ layout(push_constant) uniform CONSTS }; int32_t -longest_common_prefix(int32_t i, uint32_t key_i, int32_t j, uint32_t active_leaf_count) +longest_common_prefix(int32_t i, uint32_t key_i_lo, uint32_t key_i_hi, int32_t j, uint32_t active_leaf_count) { if (j < 0 || j >= active_leaf_count) return -1; - uint32_t key_j = DEREF(INDEX(key32_id_pair, args.src_ids, j)).key; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + uint64_t key_i = packUint2x32(u32vec2(key_i_lo, key_i_hi)); + key64_id_pair key_id_j = DEREF(INDEX(key64_id_pair, args.src_ids, j)); + uint64_t key_j = packUint2x32(u32vec2(key_id_j.key_lo, key_id_j.key_hi)); - uint32_t diff = key_i ^ key_j; - int32_t ret = 0; - if (key_i == key_j) { - ret += 32; - diff = i ^ j; + uint64_t diff = key_i ^ key_j; + int32_t ret = 0; + if (key_i == key_j) { + ret += 64; + diff = i ^ j; + } + + return ret + 63 - int32_t(findMSB(diff)); + } else { + uint32_t key_i = key_i_lo; + uint32_t key_j = DEREF(INDEX(key32_id_pair, args.src_ids, j)).key; + + uint32_t diff = key_i ^ key_j; + int32_t ret = 0; + if (key_i == key_j) { + ret += 32; + diff = i ^ j; + } + + return ret + 31 - findMSB(diff); } - - return ret + 31 - findMSB(diff); } /* @@ -84,8 +100,9 @@ main() REF(lbvh_node_info) dst = REF(lbvh_node_info)(args.node_info); DEREF(dst).parent = VK_BVH_INVALID_NODE; 
DEREF(dst).path_count = 2; - DEREF(dst).children[0] = - active_leaf_count == 1 ? DEREF(INDEX(key32_id_pair, args.src_ids, 0)).id : VK_BVH_INVALID_NODE; + DEREF(dst).children[0] = active_leaf_count != 1 ? VK_BVH_INVALID_NODE /* as before this change: only load src_ids[0] when exactly one leaf exists */ + : VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS) ? DEREF(INDEX(key64_id_pair, args.src_ids, 0)).id + : DEREF(INDEX(key32_id_pair, args.src_ids, 0)).id; DEREF(dst).children[1] = VK_BVH_INVALID_NODE; return; } @@ -95,10 +112,18 @@ main() if (id >= internal_node_count) return; - uint32_t id_key = DEREF(INDEX(key32_id_pair, args.src_ids, id)).key; + uint32_t key_lo; + uint32_t key_hi = 0; /* stays 0 with 32-bit keys; only the 64-bit path assigns it */ + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + key64_id_pair key_id = DEREF(INDEX(key64_id_pair, args.src_ids, id)); + key_lo = key_id.key_lo; + key_hi = key_id.key_hi; + } else { + key_lo = DEREF(INDEX(key32_id_pair, args.src_ids, id)).key; + } - int32_t left_lcp = longest_common_prefix(id, id_key, id - 1, active_leaf_count); - int32_t right_lcp = longest_common_prefix(id, id_key, id + 1, active_leaf_count); + int32_t left_lcp = longest_common_prefix(id, key_lo, key_hi, id - 1, active_leaf_count); + int32_t right_lcp = longest_common_prefix(id, key_lo, key_hi, id + 1, active_leaf_count); int32_t dir = right_lcp > left_lcp ? 1 : -1; int32_t lcp_min = min(left_lcp, right_lcp); @@ -106,13 +131,13 @@ main() * this subtree is going to own. 
*/ int32_t lmax = 128; - while (longest_common_prefix(id, id_key, id + dir * lmax, active_leaf_count) > lcp_min) { + while (longest_common_prefix(id, key_lo, key_hi, id + dir * lmax, active_leaf_count) > lcp_min) { lmax *= 2; } int32_t length = 0; for (int32_t t = lmax / 2; t >= 1; t /= 2) { - if (longest_common_prefix(id, id_key, id + (length + t) * dir, active_leaf_count) > lcp_min) + if (longest_common_prefix(id, key_lo, key_hi, id + (length + t) * dir, active_leaf_count) > lcp_min) length += t; } int32_t other_end = id + length * dir; @@ -120,11 +145,11 @@ main() /* The number of bits in the prefix that is the same for all elements in the * range. */ - int32_t lcp_node = longest_common_prefix(id, id_key, other_end, active_leaf_count); + int32_t lcp_node = longest_common_prefix(id, key_lo, key_hi, other_end, active_leaf_count); int32_t child_range = 0; for (int32_t diff = 2; diff < 2 * length; diff *= 2) { int32_t t = DIV_ROUND_UP(length, diff); - if (longest_common_prefix(id, id_key, id + (child_range + t) * dir, active_leaf_count) > lcp_node) + if (longest_common_prefix(id, key_lo, key_hi, id + (child_range + t) * dir, active_leaf_count) > lcp_node) child_range += t; } @@ -145,8 +170,13 @@ main() REF(lbvh_node_info) dst = INDEX(lbvh_node_info, args.node_info, id); DEREF(dst).path_count = (left_leaf ? 1 : 0) + (right_leaf ? 
1 : 0); - DEREF(dst).children[0] = DEREF(INDEX(key32_id_pair, args.src_ids, left)).id; - DEREF(dst).children[1] = DEREF(INDEX(key32_id_pair, args.src_ids, right)).id; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + DEREF(dst).children[0] = DEREF(INDEX(key64_id_pair, args.src_ids, left)).id; + DEREF(dst).children[1] = DEREF(INDEX(key64_id_pair, args.src_ids, right)).id; + } else { + DEREF(dst).children[0] = DEREF(INDEX(key32_id_pair, args.src_ids, left)).id; + DEREF(dst).children[1] = DEREF(INDEX(key32_id_pair, args.src_ids, right)).id; + } if (id == 0) DEREF(dst).parent = VK_BVH_INVALID_NODE; } diff --git a/src/vulkan/runtime/bvh/leaf.h b/src/vulkan/runtime/bvh/leaf.h index 5ced463fecd..cfd13c51edd 100644 --- a/src/vulkan/runtime/bvh/leaf.h +++ b/src/vulkan/runtime/bvh/leaf.h @@ -204,7 +204,6 @@ main(void) uint32_t global_id = gl_GlobalInvocationID.x; uint32_t primitive_id = args.geom_data.first_id + global_id; - REF(key32_id_pair) id_ptr = INDEX(key32_id_pair, args.ids, primitive_id); uint32_t src_offset = global_id * args.geom_data.stride; uint32_t dst_stride; @@ -243,7 +242,11 @@ main(void) if (VK_BUILD_FLAG(VK_BUILD_FLAG_ALWAYS_ACTIVE)) is_active = true; - DEREF(id_ptr).id = is_active ? pack_ir_node_id(dst_offset, node_type) : VK_BVH_INVALID_NODE; + uint32_t id = is_active ? 
pack_ir_node_id(dst_offset, node_type) : VK_BVH_INVALID_NODE; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) + DEREF(INDEX(key64_id_pair, args.ids, primitive_id)).id = id; + else + DEREF(INDEX(key32_id_pair, args.ids, primitive_id)).id = id; uvec4 ballot = subgroupBallot(is_active); if (subgroupElect()) diff --git a/src/vulkan/runtime/bvh/morton.comp b/src/vulkan/runtime/bvh/morton.comp index c3b39e6e3fa..5a7076a564c 100644 --- a/src/vulkan/runtime/bvh/morton.comp +++ b/src/vulkan/runtime/bvh/morton.comp @@ -34,34 +34,61 @@ layout(push_constant) uniform CONSTS { uint32_t morton_component(uint32_t x) { + x = (x * 0x00010001u) & 0xFF0000FFu; x = (x * 0x00000101u) & 0x0F00F00Fu; x = (x * 0x00000011u) & 0xC30C30C3u; x = (x * 0x00000005u) & 0x49249249u; return x; } +uint64_t +morton_component64(uint64_t x) +{ + x = (x | x << 32) & 0x1f00000000fffful; + x = (x | x << 16) & 0x1f0000ff0000fful; + x = (x | x << 8) & 0x100f00f00f00f00ful; + x = (x | x << 4) & 0x10c30c30c30c30c3ul; + x = (x | x << 2) & 0x1249249249249249ul; + return x; +} + uint32_t morton_code(uint32_t x, uint32_t y, uint32_t z) { return (morton_component(x) << 2) | (morton_component(y) << 1) | morton_component(z); } +uint64_t +morton_code64(uint64_t x, uint64_t y, uint64_t z) +{ + return (morton_component64(x) << 2) | (morton_component64(y) << 1) | morton_component64(z); +} + uint32_t lbvh_key(float x01, float y01, float z01) { return morton_code(uint32_t(x01 * 255.0), uint32_t(y01 * 255.0), uint32_t(z01 * 255.0)) << 8; } +uint64_t +lbvh_key64(float x01, float y01, float z01) +{ + return morton_code64(uint64_t(x01 * float(0x1fffff)), uint64_t(y01 * float(0x1fffff)), uint64_t(z01 * float(0x1fffff))); +} + void main(void) { uint32_t global_id = gl_GlobalInvocationID.x; - REF(key32_id_pair) key_id = INDEX(key32_id_pair, args.ids, global_id); - - uint32_t id = DEREF(key_id).id; + uint32_t id; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) + id = DEREF(INDEX(key64_id_pair, args.ids, global_id)).id; + else + id = 
DEREF(INDEX(key32_id_pair, args.ids, global_id)).id; uint32_t key; + uint64_t key64; if (id != VK_BVH_INVALID_NODE) { vk_aabb bounds = DEREF(REF(vk_ir_node)OFFSET(args.bvh, ir_id_to_offset(id))).aabb; vec3 center = (bounds.min + bounds.max) * 0.5; @@ -76,12 +103,22 @@ main(void) vec3 normalized_center = (center - bvh_bounds.min) / (bvh_bounds.max - bvh_bounds.min); - key = lbvh_key(normalized_center.x, normalized_center.y, normalized_center.z); + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) + key64 = lbvh_key64(normalized_center.x, normalized_center.y, normalized_center.z); + else + key = lbvh_key(normalized_center.x, normalized_center.y, normalized_center.z); } else { /* Move null instances to the end to avoid mixing null instances with active instances. This * way, we can skip early during traversal. */ - key = 0xFFFFFFFF; + key = 0xFFFFFFFFu; + key64 = 0xFFFFFFFFFFFFFFFFul; + } + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) { + REF(key64_id_pair) key_id = INDEX(key64_id_pair, args.ids, global_id); + DEREF(key_id).key_lo = unpackUint2x32(key64).x; + DEREF(key_id).key_hi = unpackUint2x32(key64).y; + } else { + DEREF(INDEX(key32_id_pair, args.ids, global_id)).key = key; } - DEREF(key_id).key = key; } diff --git a/src/vulkan/runtime/bvh/ploc_internal.comp b/src/vulkan/runtime/bvh/ploc_internal.comp index a5dc4c89214..051faf608f0 100644 --- a/src/vulkan/runtime/bvh/ploc_internal.comp +++ b/src/vulkan/runtime/bvh/ploc_internal.comp @@ -177,10 +177,14 @@ shared uint32_t nearest_neighbour_indices[NUM_PLOC_LDS_ITEMS]; uint32_t load_id(VOID_REF ids, uint32_t iter, uint32_t index) { - if (iter == 0) - return DEREF(REF(key32_id_pair)(INDEX(key32_id_pair, ids, index))).id; - else + if (iter == 0) { + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) + return DEREF(REF(key64_id_pair)(INDEX(key64_id_pair, ids, index))).id; + else + return DEREF(REF(key32_id_pair)(INDEX(key32_id_pair, ids, index))).id; + } else { return DEREF(REF(uint32_t)(INDEX(uint32_t, ids, index))); + } } void 
@@ -231,7 +235,11 @@ main(void) uint32_t i = 0; for (; i < DEREF(args.header).active_leaf_count; i++) { - uint32_t child_id = DEREF(INDEX(key32_id_pair, src_ids, i)).id; + uint32_t child_id; + if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) + child_id = DEREF(INDEX(key64_id_pair, src_ids, i)).id; + else + child_id = DEREF(INDEX(key32_id_pair, src_ids, i)).id; if (child_id != VK_BVH_INVALID_NODE) { VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(child_id)); diff --git a/src/vulkan/runtime/bvh/vk_build_helpers.h b/src/vulkan/runtime/bvh/vk_build_helpers.h index e67e2194b20..c15d38d7988 100644 --- a/src/vulkan/runtime/bvh/vk_build_helpers.h +++ b/src/vulkan/runtime/bvh/vk_build_helpers.h @@ -235,6 +235,7 @@ from_emulated_float(int32_t bits) TYPE(vk_aabb, 4); TYPE(key32_id_pair, 4); +TYPE(key64_id_pair, 4); TYPE(vk_accel_struct_serialization_header, 8); diff --git a/src/vulkan/runtime/bvh/vk_build_interface.h b/src/vulkan/runtime/bvh/vk_build_interface.h index f9fc0504c45..bf70e5da807 100644 --- a/src/vulkan/runtime/bvh/vk_build_interface.h +++ b/src/vulkan/runtime/bvh/vk_build_interface.h @@ -46,7 +46,8 @@ layout (constant_id = ROOT_FLAGS_OFFSET_ID) const int ROOT_FLAGS_OFFSET = -1; #define VK_BUILD_FLAG_ALWAYS_ACTIVE (1u << 0) #define VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS (1u << 1) -#define VK_BUILD_FLAG_COUNT 2 +#define VK_BUILD_FLAG_64BIT_KEYS (1u << 2) +#define VK_BUILD_FLAG_COUNT 3 #define VK_BUILD_FLAG(flag) ((BUILD_FLAGS & flag) != 0) diff --git a/src/vulkan/runtime/bvh/vk_bvh.h b/src/vulkan/runtime/bvh/vk_bvh.h index 40e69c7c250..93de61cfa97 100644 --- a/src/vulkan/runtime/bvh/vk_bvh.h +++ b/src/vulkan/runtime/bvh/vk_bvh.h @@ -174,4 +174,10 @@ struct key32_id_pair { uint32_t key; }; +struct key64_id_pair { + uint32_t id; + uint32_t key_lo; + uint32_t key_hi; +}; + #endif diff --git a/src/vulkan/runtime/vk_acceleration_structure.c b/src/vulkan/runtime/vk_acceleration_structure.c index a4874a995b3..4c4732a64e0 100644 --- 
a/src/vulkan/runtime/vk_acceleration_structure.c +++ b/src/vulkan/runtime/vk_acceleration_structure.c @@ -135,8 +135,6 @@ vk_common_GetAccelerationStructureDeviceAddressKHR( return vk_acceleration_structure_get_va(accel_struct); } -#define MORTON_BIT_SIZE 24 - static void vk_acceleration_structure_build_state_init(struct vk_acceleration_structure_build_state *state, struct vk_device *device, uint32_t leaf_count, @@ -173,7 +171,7 @@ vk_acceleration_structure_build_state_init(struct vk_acceleration_structure_buil radix_sort_vk_memory_requirements_t requirements = { 0, }; - radix_sort_vk_get_memory_requirements(args->radix_sort_64, leaf_count, + radix_sort_vk_get_memory_requirements(state->config.u64_keys ? args->radix_sort_96 : args->radix_sort_64, leaf_count, &requirements); uint32_t ir_leaf_size; @@ -516,7 +514,7 @@ vk_accel_struct_cmd_end_debug_marker(VkCommandBuffer commandBuffer, device->dispatch_table.CmdDebugMarkerEndEXT(commandBuffer); } -#define VK_BUILD_LEAVES_FLAGS (VK_BUILD_FLAG_ALWAYS_ACTIVE | VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS) +#define VK_BUILD_LEAVES_FLAGS (VK_BUILD_FLAG_ALWAYS_ACTIVE | VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS) static VkResult build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device, @@ -610,6 +608,8 @@ build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device, return VK_SUCCESS; } +#define VK_MORTON_GENERATE_FLAGS (VK_BUILD_FLAG_64BIT_KEYS) + static VkResult morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, @@ -622,8 +622,8 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device, VkResult result = vk_get_bvh_build_pipeline_spv(device, meta, VK_META_OBJECT_KEY_MORTON, morton_spv, sizeof(morton_spv), - sizeof(struct morton_args), args, 0, - &pipeline, + sizeof(struct morton_args), args, + build_flags, &pipeline, true /* unaligned_dispatch */); if (result != VK_SUCCESS) return result; @@ -646,6 +646,8 @@ 
morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device, for (uint32_t i = 0; i < build_count; ++i) { if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; + if ((states[i].build_flags & VK_MORTON_GENERATE_FLAGS) != build_flags) + continue; uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct morton_args consts = { @@ -669,6 +671,8 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device, return VK_SUCCESS; } +#define VK_MORTON_SORT_FLAGS (VK_BUILD_FLAG_64BIT_KEYS) + static VkResult morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, struct vk_meta_device *meta, @@ -686,7 +690,12 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, } /* Copyright 2019 The Fuchsia Authors. */ + uint32_t key_bits = 24; const radix_sort_vk_t *rs = args->radix_sort_64; + if (build_flags & VK_BUILD_FLAG_64BIT_KEYS) { + key_bits = 64; + rs = args->radix_sort_96; + } /* * OVERVIEW @@ -706,8 +715,6 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, /* How many passes? 
*/ uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t); - uint32_t keyval_bits = keyval_bytes * 8; - uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits); uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2; for (uint32_t i = 0; i < build_count; ++i) { @@ -749,6 +756,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, continue; if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; + if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags) + continue; uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0]; @@ -804,6 +813,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, continue; if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; + if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags) + continue; uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0]; @@ -837,6 +848,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, continue; if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; + if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags) + continue; uint64_t internal_addr = states[i].build_info->scratchData.deviceAddress + states[i].scratch.sort_internal_offset; @@ -889,6 +902,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, continue; if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE) continue; + if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags) + continue; states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2; @@ -919,7 +934,7 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device, return VK_SUCCESS; } -#define VK_LBVH_BUILD_INTERNAL_FLAGS 
(VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS) +#define VK_LBVH_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS) static VkResult lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, @@ -957,6 +972,8 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, for (uint32_t i = 0; i < build_count; ++i) { if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH) continue; + if ((states[i].build_flags & VK_LBVH_BUILD_INTERNAL_FLAGS) != build_flags) + continue; uint32_t src_scratch_offset = states[i].scratch_offset; @@ -993,6 +1010,8 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, for (uint32_t i = 0; i < build_count; ++i) { if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH) continue; + if ((states[i].build_flags & VK_LBVH_BUILD_INTERNAL_FLAGS) != build_flags) + continue; uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress; const struct lbvh_generate_ir_args consts = { @@ -1017,7 +1036,7 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, return VK_SUCCESS; } -#define VK_PLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS) +#define VK_PLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS) static VkResult ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, @@ -1054,6 +1073,8 @@ ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, for (uint32_t i = 0; i < build_count; ++i) { if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_PLOC) continue; + if ((states[i].build_flags & VK_PLOC_BUILD_INTERNAL_FLAGS) != build_flags) + continue; uint32_t src_scratch_offset = states[i].scratch_offset; uint32_t dst_scratch_offset = (src_scratch_offset == states[i].scratch.sort_buffer_offset[0]) @@ -1085,7 +1106,7 @@ ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, return VK_SUCCESS; } 
-#define VK_HPLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS) +#define VK_HPLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS) static VkResult hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, @@ -1097,13 +1118,9 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, VkPipeline pipeline; VkPipelineLayout layout; - uint32_t flags = 0; - if (args->propagate_cull_flags) - flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS; - VkResult result = vk_get_bvh_build_pipeline_spv(device, meta, VK_META_OBJECT_KEY_HPLOC, hploc_spv, sizeof(hploc_spv), sizeof(struct hploc_args), - args, flags, &pipeline, + args, build_flags, &pipeline, false /* unaligned_dispatch */); if (result != VK_SUCCESS) return result; @@ -1126,6 +1143,8 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device, for (uint32_t i = 0; i < build_count; ++i) { if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_HPLOC) continue; + if ((states[i].build_flags & VK_HPLOC_BUILD_INTERNAL_FLAGS) != build_flags) + continue; assert(args->subgroup_size <= 64); @@ -1251,6 +1270,8 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, states[i].build_flags |= VK_BUILD_FLAG_ALWAYS_ACTIVE; if (args->propagate_cull_flags) states[i].build_flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS; + if (states[i].config.u64_keys) + states[i].build_flags |= VK_BUILD_FLAG_64BIT_KEYS; if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_UPDATE) { /* The internal node count is updated in lbvh_build_internal for LBVH @@ -1312,7 +1333,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, vk_barrier_compute_w_to_compute_r(commandBuffer); - result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, states, infoCount, 0); + result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, states, infoCount, VK_MORTON_GENERATE_FLAGS); if (result != 
VK_SUCCESS) { free(states); vk_command_buffer_set_error(cmd_buffer, result); @@ -1321,7 +1342,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer, vk_barrier_compute_w_to_compute_r(commandBuffer); - vk_build_stage(morton_sort, commandBuffer, device, meta, args, states, infoCount, 0); + vk_build_stage(morton_sort, commandBuffer, device, meta, args, states, infoCount, VK_MORTON_SORT_FLAGS); vk_barrier_compute_w_to_compute_r(commandBuffer); diff --git a/src/vulkan/runtime/vk_acceleration_structure.h b/src/vulkan/runtime/vk_acceleration_structure.h index cfbd1aa9ff3..4f113494656 100644 --- a/src/vulkan/runtime/vk_acceleration_structure.h +++ b/src/vulkan/runtime/vk_acceleration_structure.h @@ -103,6 +103,7 @@ enum vk_internal_build_type { struct vk_build_config { enum vk_internal_build_type internal_type; bool updateable; + bool u64_keys; uint32_t encode_key[MAX_ENCODE_PASSES]; uint32_t update_key[MAX_ENCODE_PASSES]; }; @@ -180,6 +181,7 @@ struct vk_acceleration_structure_build_args { bool propagate_cull_flags; bool emit_markers; const radix_sort_vk_t *radix_sort_64; + const radix_sort_vk_t *radix_sort_96; }; VkResult vk_get_bvh_build_pipeline_layout(struct vk_device *device, struct vk_meta_device *meta,