vulkan: Implement 64-bit morton codes

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/41300>
This commit is contained in:
Konstantin Seurer 2025-12-17 19:22:02 +01:00 committed by Marge Bot
parent 74e21c2c59
commit c432ffc5ce
10 changed files with 192 additions and 59 deletions

View file

@ -39,10 +39,20 @@ delta(uint32_t index)
uint32_t left_index = index;
uint32_t right_index = index + 1;
uint32_t left_key = DEREF(INDEX(key32_id_pair, args.ids, left_index)).key;
uint32_t right_key = DEREF(INDEX(key32_id_pair, args.ids, right_index)).key;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
uint64_t left_key = packUint2x32(u32vec2(
DEREF(INDEX(key64_id_pair, args.ids, left_index)).key_lo,
DEREF(INDEX(key64_id_pair, args.ids, left_index)).key_hi));
uint64_t right_key = packUint2x32(u32vec2(
DEREF(INDEX(key64_id_pair, args.ids, right_index)).key_lo,
DEREF(INDEX(key64_id_pair, args.ids, right_index)).key_hi));
return left_key != right_key ? (32 + uint32_t(findMSB(left_key ^ right_key))) : uint32_t(findMSB(left_index ^ right_index));
} else {
uint32_t left_key = DEREF(INDEX(key32_id_pair, args.ids, left_index)).key;
uint32_t right_key = DEREF(INDEX(key32_id_pair, args.ids, right_index)).key;
return left_key != right_key ? (32 + findMSB(left_key ^ right_key)) : findMSB(left_index ^ right_index);
return left_key != right_key ? (32 + findMSB(left_key ^ right_key)) : findMSB(left_index ^ right_index);
}
}
#define SEARCH_RADIUS 16
@ -67,8 +77,13 @@ main(void)
uint32_t child_id = VK_BVH_INVALID_NODE;
vk_ir_node child = vk_ir_node(vk_aabb(vec3(0.0), vec3(0.0)));
if (active_leaf_count > 0) {
REF(key32_id_pair) key_id = INDEX(key32_id_pair, args.ids, global_id);
child_id = DEREF(key_id).id;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
REF(key64_id_pair) key_id = INDEX(key64_id_pair, args.ids, global_id);
child_id = DEREF(key_id).id;
} else {
REF(key32_id_pair) key_id = INDEX(key32_id_pair, args.ids, global_id);
child_id = DEREF(key_id).id;
}
child = DEREF(REF(vk_ir_node)(OFFSET(args.bvh, ir_id_to_offset(child_id))));
}
@ -142,8 +157,13 @@ main(void)
}
uint32_t node_id_index = load_base + load_index;
uint32_t node_id = VK_BVH_INVALID_NODE;
if (load_index < index_range)
node_id = DEREF(INDEX(key32_id_pair, args.ids, node_id_index)).id;
if (load_index < index_range) {
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
node_id = DEREF(INDEX(key64_id_pair, args.ids, node_id_index)).id;
} else {
node_id = DEREF(INDEX(key32_id_pair, args.ids, node_id_index)).id;
}
}
uvec4 node_valid_mask = subgroupBallot(node_id != VK_BVH_INVALID_NODE);
uint32_t node_prefix_sum = subgroupBallotExclusiveBitCount(node_valid_mask);
@ -225,7 +245,11 @@ main(void)
if (gl_SubgroupInvocationID < min(end - start + 1, cluster_threshold)) {
uint32_t dst_node = gl_SubgroupInvocationID < node_count ? node_ids[gl_SubgroupInvocationID] : VK_BVH_INVALID_NODE;
DEREF(INDEX(key32_id_pair, args.ids, start + gl_SubgroupInvocationID)).id = dst_node;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
DEREF(INDEX(key64_id_pair, args.ids, start + gl_SubgroupInvocationID)).id = dst_node;
} else {
DEREF(INDEX(key32_id_pair, args.ids, start + gl_SubgroupInvocationID)).id = dst_node;
}
}
memoryBarrier(gl_ScopeDevice, gl_StorageSemanticsBuffer,

View file

@ -35,21 +35,37 @@ layout(push_constant) uniform CONSTS
};
int32_t
longest_common_prefix(int32_t i, uint32_t key_i, int32_t j, uint32_t active_leaf_count)
longest_common_prefix(int32_t i, uint32_t key_i_lo, uint32_t key_i_hi, int32_t j, uint32_t active_leaf_count)
{
if (j < 0 || j >= active_leaf_count)
return -1;
uint32_t key_j = DEREF(INDEX(key32_id_pair, args.src_ids, j)).key;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
uint64_t key_i = packUint2x32(u32vec2(key_i_lo, key_i_hi));
key64_id_pair key_id_j = DEREF(INDEX(key64_id_pair, args.src_ids, j));
uint64_t key_j = packUint2x32(u32vec2(key_id_j.key_lo, key_id_j.key_hi));
uint32_t diff = key_i ^ key_j;
int32_t ret = 0;
if (key_i == key_j) {
ret += 32;
diff = i ^ j;
uint64_t diff = key_i ^ key_j;
int32_t ret = 0;
if (key_i == key_j) {
ret += 64;
diff = i ^ j;
}
return ret + 63 - int32_t(findMSB(diff));
} else {
uint32_t key_i = key_i_lo;
uint32_t key_j = DEREF(INDEX(key32_id_pair, args.src_ids, j)).key;
uint32_t diff = key_i ^ key_j;
int32_t ret = 0;
if (key_i == key_j) {
ret += 32;
diff = i ^ j;
}
return ret + 31 - findMSB(diff);
}
return ret + 31 - findMSB(diff);
}
/*
@ -84,8 +100,9 @@ main()
REF(lbvh_node_info) dst = REF(lbvh_node_info)(args.node_info);
DEREF(dst).parent = VK_BVH_INVALID_NODE;
DEREF(dst).path_count = 2;
DEREF(dst).children[0] =
active_leaf_count == 1 ? DEREF(INDEX(key32_id_pair, args.src_ids, 0)).id : VK_BVH_INVALID_NODE;
uint32_t id = VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS) ? DEREF(INDEX(key64_id_pair, args.src_ids, 0)).id
: DEREF(INDEX(key32_id_pair, args.src_ids, 0)).id;
DEREF(dst).children[0] = active_leaf_count == 1 ? id : VK_BVH_INVALID_NODE;
DEREF(dst).children[1] = VK_BVH_INVALID_NODE;
return;
}
@ -95,10 +112,18 @@ main()
if (id >= internal_node_count)
return;
uint32_t id_key = DEREF(INDEX(key32_id_pair, args.src_ids, id)).key;
uint32_t key_lo;
uint32_t key_hi;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
key64_id_pair key_id = DEREF(INDEX(key64_id_pair, args.src_ids, id));
key_lo = key_id.key_lo;
key_hi = key_id.key_hi;
} else {
key_lo = DEREF(INDEX(key32_id_pair, args.src_ids, id)).key;
}
int32_t left_lcp = longest_common_prefix(id, id_key, id - 1, active_leaf_count);
int32_t right_lcp = longest_common_prefix(id, id_key, id + 1, active_leaf_count);
int32_t left_lcp = longest_common_prefix(id, key_lo, key_hi, id - 1, active_leaf_count);
int32_t right_lcp = longest_common_prefix(id, key_lo, key_hi, id + 1, active_leaf_count);
int32_t dir = right_lcp > left_lcp ? 1 : -1;
int32_t lcp_min = min(left_lcp, right_lcp);
@ -106,13 +131,13 @@ main()
* this subtree is going to own.
*/
int32_t lmax = 128;
while (longest_common_prefix(id, id_key, id + dir * lmax, active_leaf_count) > lcp_min) {
while (longest_common_prefix(id, key_lo, key_hi, id + dir * lmax, active_leaf_count) > lcp_min) {
lmax *= 2;
}
int32_t length = 0;
for (int32_t t = lmax / 2; t >= 1; t /= 2) {
if (longest_common_prefix(id, id_key, id + (length + t) * dir, active_leaf_count) > lcp_min)
if (longest_common_prefix(id, key_lo, key_hi, id + (length + t) * dir, active_leaf_count) > lcp_min)
length += t;
}
int32_t other_end = id + length * dir;
@ -120,11 +145,11 @@ main()
/* The number of bits in the prefix that is the same for all elements in the
* range.
*/
int32_t lcp_node = longest_common_prefix(id, id_key, other_end, active_leaf_count);
int32_t lcp_node = longest_common_prefix(id, key_lo, key_hi, other_end, active_leaf_count);
int32_t child_range = 0;
for (int32_t diff = 2; diff < 2 * length; diff *= 2) {
int32_t t = DIV_ROUND_UP(length, diff);
if (longest_common_prefix(id, id_key, id + (child_range + t) * dir, active_leaf_count) > lcp_node)
if (longest_common_prefix(id, key_lo, key_hi, id + (child_range + t) * dir, active_leaf_count) > lcp_node)
child_range += t;
}
@ -145,8 +170,13 @@ main()
REF(lbvh_node_info) dst = INDEX(lbvh_node_info, args.node_info, id);
DEREF(dst).path_count = (left_leaf ? 1 : 0) + (right_leaf ? 1 : 0);
DEREF(dst).children[0] = DEREF(INDEX(key32_id_pair, args.src_ids, left)).id;
DEREF(dst).children[1] = DEREF(INDEX(key32_id_pair, args.src_ids, right)).id;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
DEREF(dst).children[0] = DEREF(INDEX(key64_id_pair, args.src_ids, left)).id;
DEREF(dst).children[1] = DEREF(INDEX(key64_id_pair, args.src_ids, right)).id;
} else {
DEREF(dst).children[0] = DEREF(INDEX(key32_id_pair, args.src_ids, left)).id;
DEREF(dst).children[1] = DEREF(INDEX(key32_id_pair, args.src_ids, right)).id;
}
if (id == 0)
DEREF(dst).parent = VK_BVH_INVALID_NODE;
}

View file

@ -204,7 +204,6 @@ main(void)
uint32_t global_id = gl_GlobalInvocationID.x;
uint32_t primitive_id = args.geom_data.first_id + global_id;
REF(key32_id_pair) id_ptr = INDEX(key32_id_pair, args.ids, primitive_id);
uint32_t src_offset = global_id * args.geom_data.stride;
uint32_t dst_stride;
@ -243,7 +242,11 @@ main(void)
if (VK_BUILD_FLAG(VK_BUILD_FLAG_ALWAYS_ACTIVE))
is_active = true;
DEREF(id_ptr).id = is_active ? pack_ir_node_id(dst_offset, node_type) : VK_BVH_INVALID_NODE;
uint32_t id = is_active ? pack_ir_node_id(dst_offset, node_type) : VK_BVH_INVALID_NODE;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS))
DEREF(INDEX(key64_id_pair, args.ids, primitive_id)).id = id;
else
DEREF(INDEX(key32_id_pair, args.ids, primitive_id)).id = id;
uvec4 ballot = subgroupBallot(is_active);
if (subgroupElect())

View file

@ -34,34 +34,61 @@ layout(push_constant) uniform CONSTS {
/* Spreads the low bits of x so consecutive input bits land three positions
 * apart (Morton "part-by-2", magic-multiplication form). The callers quantize
 * to 8 bits (see lbvh_key: uint32_t(x01 * 255.0)); the final mask 0x49249249
 * keeps every third bit of the result.
 */
uint32_t
morton_component(uint32_t x)
{
x = (x * 0x00010001u) & 0xFF0000FFu;
x = (x * 0x00000101u) & 0x0F00F00Fu;
x = (x * 0x00000011u) & 0xC30C30C3u;
x = (x * 0x00000005u) & 0x49249249u;
return x;
}
/* 64-bit variant of morton_component: spreads the low 21 bits of x three
 * positions apart using the shift-or form. The final mask
 * 0x1249249249249249 keeps every third bit, so three interleaved components
 * fill 63 bits and the top bit stays clear (see lbvh_key64, which quantizes
 * to 0x1fffff = 21 bits).
 */
uint64_t
morton_component64(uint64_t x)
{
x = (x | x << 32) & 0x1f00000000fffful;
x = (x | x << 16) & 0x1f0000ff0000fful;
x = (x | x << 8) & 0x100f00f00f00f00ful;
x = (x | x << 4) & 0x10c30c30c30c30c3ul;
x = (x | x << 2) & 0x1249249249249249ul;
return x;
}
/* Interleaves three spread components into one Morton code. x occupies the
 * most significant bit of each 3-bit group, then y, then z.
 */
uint32_t
morton_code(uint32_t x, uint32_t y, uint32_t z)
{
return (morton_component(x) << 2) | (morton_component(y) << 1) | morton_component(z);
}
/* 64-bit counterpart of morton_code: interleaves three 21-bit components into
 * a 63-bit Morton code (x in the most significant slot of each group).
 */
uint64_t
morton_code64(uint64_t x, uint64_t y, uint64_t z)
{
return (morton_component64(x) << 2) | (morton_component64(y) << 1) | morton_component64(z);
}
/* 32-bit LBVH sort key: quantizes each normalized coordinate (expected in
 * [0, 1]) to 8 bits and interleaves them into a 24-bit Morton code.
 * NOTE(review): the << 8 places the code in the upper bits of the 32-bit key
 * -- presumably to match the radix sort's configured 24-bit key width and the
 * keyval layout; confirm against the sort pass setup before relying on it.
 */
uint32_t
lbvh_key(float x01, float y01, float z01)
{
return morton_code(uint32_t(x01 * 255.0), uint32_t(y01 * 255.0), uint32_t(z01 * 255.0)) << 8;
}
/* 64-bit LBVH sort key: quantizes each normalized coordinate (expected in
 * [0, 1]) to 21 bits (0x1fffff) and interleaves them into a 63-bit Morton
 * code. No shift is applied: the top bit stays clear, so every valid key is
 * strictly below the 0xFFFFFFFFFFFFFFFF key assigned to inactive nodes and
 * those always sort last.
 */
uint64_t
lbvh_key64(float x01, float y01, float z01)
{
return morton_code64(uint64_t(x01 * float(0x1fffff)), uint64_t(y01 * float(0x1fffff)), uint64_t(z01 * float(0x1fffff)));
}
void
main(void)
{
uint32_t global_id = gl_GlobalInvocationID.x;
REF(key32_id_pair) key_id = INDEX(key32_id_pair, args.ids, global_id);
uint32_t id = DEREF(key_id).id;
uint32_t id;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS))
id = DEREF(INDEX(key64_id_pair, args.ids, global_id)).id;
else
id = DEREF(INDEX(key32_id_pair, args.ids, global_id)).id;
uint32_t key;
uint64_t key64;
if (id != VK_BVH_INVALID_NODE) {
vk_aabb bounds = DEREF(REF(vk_ir_node)OFFSET(args.bvh, ir_id_to_offset(id))).aabb;
vec3 center = (bounds.min + bounds.max) * 0.5;
@ -76,12 +103,22 @@ main(void)
vec3 normalized_center = (center - bvh_bounds.min) / (bvh_bounds.max - bvh_bounds.min);
key = lbvh_key(normalized_center.x, normalized_center.y, normalized_center.z);
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS))
key64 = lbvh_key64(normalized_center.x, normalized_center.y, normalized_center.z);
else
key = lbvh_key(normalized_center.x, normalized_center.y, normalized_center.z);
} else {
/* Move null instances to the end to avoid mixing null instances with active instances. This
* way, we can skip early during traversal.
*/
key = 0xFFFFFFFF;
key = 0xFFFFFFFFu;
key64 = 0xFFFFFFFFFFFFFFFFul;
}
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS)) {
REF(key64_id_pair) key_id = INDEX(key64_id_pair, args.ids, global_id);
DEREF(key_id).key_lo = unpackUint2x32(key64).x;
DEREF(key_id).key_hi = unpackUint2x32(key64).y;
} else {
DEREF(INDEX(key32_id_pair, args.ids, global_id)).key = key;
}
DEREF(key_id).key = key;
}

View file

@ -177,10 +177,14 @@ shared uint32_t nearest_neighbour_indices[NUM_PLOC_LDS_ITEMS];
uint32_t
load_id(VOID_REF ids, uint32_t iter, uint32_t index)
{
if (iter == 0)
return DEREF(REF(key32_id_pair)(INDEX(key32_id_pair, ids, index))).id;
else
if (iter == 0) {
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS))
return DEREF(REF(key64_id_pair)(INDEX(key64_id_pair, ids, index))).id;
else
return DEREF(REF(key32_id_pair)(INDEX(key32_id_pair, ids, index))).id;
} else {
return DEREF(REF(uint32_t)(INDEX(uint32_t, ids, index)));
}
}
void
@ -231,7 +235,11 @@ main(void)
uint32_t i = 0;
for (; i < DEREF(args.header).active_leaf_count; i++) {
uint32_t child_id = DEREF(INDEX(key32_id_pair, src_ids, i)).id;
uint32_t child_id;
if (VK_BUILD_FLAG(VK_BUILD_FLAG_64BIT_KEYS))
child_id = DEREF(INDEX(key64_id_pair, src_ids, i)).id;
else
child_id = DEREF(INDEX(key32_id_pair, src_ids, i)).id;
if (child_id != VK_BVH_INVALID_NODE) {
VOID_REF node = OFFSET(args.bvh, ir_id_to_offset(child_id));

View file

@ -235,6 +235,7 @@ from_emulated_float(int32_t bits)
TYPE(vk_aabb, 4);
TYPE(key32_id_pair, 4);
TYPE(key64_id_pair, 4);
TYPE(vk_accel_struct_serialization_header, 8);

View file

@ -46,7 +46,8 @@ layout (constant_id = ROOT_FLAGS_OFFSET_ID) const int ROOT_FLAGS_OFFSET = -1;
#define VK_BUILD_FLAG_ALWAYS_ACTIVE (1u << 0)
#define VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS (1u << 1)
#define VK_BUILD_FLAG_COUNT 2
#define VK_BUILD_FLAG_64BIT_KEYS (1u << 2)
#define VK_BUILD_FLAG_COUNT 3
#define VK_BUILD_FLAG(flag) ((BUILD_FLAGS & flag) != 0)

View file

@ -174,4 +174,10 @@ struct key32_id_pair {
uint32_t key;
};
/* An IR node id paired with a 64-bit Morton sort key. The key is stored as
 * two 32-bit words and recombined with packUint2x32(u32vec2(key_lo, key_hi)),
 * i.e. key_lo is the low half and key_hi the high half.
 * NOTE(review): the word order (id, key_lo, key_hi) presumably matches the
 * 96-bit radix-sort keyval layout (radix_sort_96, least-significant word
 * first) -- confirm against the sort configuration.
 */
struct key64_id_pair {
uint32_t id;       /* packed IR node id, or VK_BVH_INVALID_NODE */
uint32_t key_lo;   /* low 32 bits of the 64-bit Morton key */
uint32_t key_hi;   /* high 32 bits of the 64-bit Morton key */
};
#endif

View file

@ -135,8 +135,6 @@ vk_common_GetAccelerationStructureDeviceAddressKHR(
return vk_acceleration_structure_get_va(accel_struct);
}
#define MORTON_BIT_SIZE 24
static void
vk_acceleration_structure_build_state_init(struct vk_acceleration_structure_build_state *state,
struct vk_device *device, uint32_t leaf_count,
@ -173,7 +171,7 @@ vk_acceleration_structure_build_state_init(struct vk_acceleration_structure_buil
radix_sort_vk_memory_requirements_t requirements = {
0,
};
radix_sort_vk_get_memory_requirements(args->radix_sort_64, leaf_count,
radix_sort_vk_get_memory_requirements(state->config.u64_keys ? args->radix_sort_96 : args->radix_sort_64, leaf_count,
&requirements);
uint32_t ir_leaf_size;
@ -516,7 +514,7 @@ vk_accel_struct_cmd_end_debug_marker(VkCommandBuffer commandBuffer,
device->dispatch_table.CmdDebugMarkerEndEXT(commandBuffer);
}
#define VK_BUILD_LEAVES_FLAGS (VK_BUILD_FLAG_ALWAYS_ACTIVE | VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS)
#define VK_BUILD_LEAVES_FLAGS (VK_BUILD_FLAG_ALWAYS_ACTIVE | VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS)
static VkResult
build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device,
@ -610,6 +608,8 @@ build_leaves(VkCommandBuffer commandBuffer, struct vk_device *device,
return VK_SUCCESS;
}
#define VK_MORTON_GENERATE_FLAGS (VK_BUILD_FLAG_64BIT_KEYS)
static VkResult
morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
@ -622,8 +622,8 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
VkResult result = vk_get_bvh_build_pipeline_spv(device, meta, VK_META_OBJECT_KEY_MORTON,
morton_spv, sizeof(morton_spv),
sizeof(struct morton_args), args, 0,
&pipeline,
sizeof(struct morton_args), args,
build_flags, &pipeline,
true /* unaligned_dispatch */);
if (result != VK_SUCCESS)
return result;
@ -646,6 +646,8 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
for (uint32_t i = 0; i < build_count; ++i) {
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
if ((states[i].build_flags & VK_MORTON_GENERATE_FLAGS) != build_flags)
continue;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct morton_args consts = {
@ -669,6 +671,8 @@ morton_generate(VkCommandBuffer commandBuffer, struct vk_device *device,
return VK_SUCCESS;
}
#define VK_MORTON_SORT_FLAGS (VK_BUILD_FLAG_64BIT_KEYS)
static VkResult
morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
struct vk_meta_device *meta,
@ -686,7 +690,12 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
}
/* Copyright 2019 The Fuchsia Authors. */
uint32_t key_bits = 24;
const radix_sort_vk_t *rs = args->radix_sort_64;
if (build_flags & VK_BUILD_FLAG_64BIT_KEYS) {
key_bits = 64;
rs = args->radix_sort_96;
}
/*
* OVERVIEW
@ -706,8 +715,6 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
/* How many passes? */
uint32_t keyval_bytes = rs->config.keyval_dwords * (uint32_t)sizeof(uint32_t);
uint32_t keyval_bits = keyval_bytes * 8;
uint32_t key_bits = MIN2(MORTON_BIT_SIZE, keyval_bits);
uint32_t passes = (key_bits + RS_RADIX_LOG2 - 1) / RS_RADIX_LOG2;
for (uint32_t i = 0; i < build_count; ++i) {
@ -749,6 +756,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
continue;
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags)
continue;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0];
@ -804,6 +813,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
continue;
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags)
continue;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
uint64_t keyvals_even_addr = scratch_addr + states[i].scratch.sort_buffer_offset[0];
@ -837,6 +848,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
continue;
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags)
continue;
uint64_t internal_addr = states[i].build_info->scratchData.deviceAddress +
states[i].scratch.sort_internal_offset;
@ -889,6 +902,8 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
continue;
if (states[i].config.internal_type == VK_INTERNAL_BUILD_TYPE_UPDATE)
continue;
if ((states[i].build_flags & VK_MORTON_SORT_FLAGS) != build_flags)
continue;
states[i].push_scatter.pass_offset = (pass_idx & 3) * RS_RADIX_LOG2;
@ -919,7 +934,7 @@ morton_sort(VkCommandBuffer commandBuffer, struct vk_device *device,
return VK_SUCCESS;
}
#define VK_LBVH_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS)
#define VK_LBVH_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS)
static VkResult
lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
@ -957,6 +972,8 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
for (uint32_t i = 0; i < build_count; ++i) {
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH)
continue;
if ((states[i].build_flags & VK_LBVH_BUILD_INTERNAL_FLAGS) != build_flags)
continue;
uint32_t src_scratch_offset = states[i].scratch_offset;
@ -993,6 +1010,8 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
for (uint32_t i = 0; i < build_count; ++i) {
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_LBVH)
continue;
if ((states[i].build_flags & VK_LBVH_BUILD_INTERNAL_FLAGS) != build_flags)
continue;
uint64_t scratch_addr = states[i].build_info->scratchData.deviceAddress;
const struct lbvh_generate_ir_args consts = {
@ -1017,7 +1036,7 @@ lbvh_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
return VK_SUCCESS;
}
#define VK_PLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS)
#define VK_PLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS)
static VkResult
ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
@ -1054,6 +1073,8 @@ ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
for (uint32_t i = 0; i < build_count; ++i) {
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_PLOC)
continue;
if ((states[i].build_flags & VK_PLOC_BUILD_INTERNAL_FLAGS) != build_flags)
continue;
uint32_t src_scratch_offset = states[i].scratch_offset;
uint32_t dst_scratch_offset = (src_scratch_offset == states[i].scratch.sort_buffer_offset[0])
@ -1085,7 +1106,7 @@ ploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
return VK_SUCCESS;
}
#define VK_HPLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS)
#define VK_HPLOC_BUILD_INTERNAL_FLAGS (VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS | VK_BUILD_FLAG_64BIT_KEYS)
static VkResult
hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
@ -1097,13 +1118,9 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
VkPipeline pipeline;
VkPipelineLayout layout;
uint32_t flags = 0;
if (args->propagate_cull_flags)
flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS;
VkResult result = vk_get_bvh_build_pipeline_spv(device, meta, VK_META_OBJECT_KEY_HPLOC, hploc_spv,
sizeof(hploc_spv), sizeof(struct hploc_args),
args, flags, &pipeline,
args, build_flags, &pipeline,
false /* unaligned_dispatch */);
if (result != VK_SUCCESS)
return result;
@ -1126,6 +1143,8 @@ hploc_build_internal(VkCommandBuffer commandBuffer, struct vk_device *device,
for (uint32_t i = 0; i < build_count; ++i) {
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_HPLOC)
continue;
if ((states[i].build_flags & VK_HPLOC_BUILD_INTERNAL_FLAGS) != build_flags)
continue;
assert(args->subgroup_size <= 64);
@ -1251,6 +1270,8 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
states[i].build_flags |= VK_BUILD_FLAG_ALWAYS_ACTIVE;
if (args->propagate_cull_flags)
states[i].build_flags |= VK_BUILD_FLAG_PROPAGATE_CULL_FLAGS;
if (states[i].config.u64_keys)
states[i].build_flags |= VK_BUILD_FLAG_64BIT_KEYS;
if (states[i].config.internal_type != VK_INTERNAL_BUILD_TYPE_UPDATE) {
/* The internal node count is updated in lbvh_build_internal for LBVH
@ -1312,7 +1333,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
vk_barrier_compute_w_to_compute_r(commandBuffer);
result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, states, infoCount, 0);
result = vk_build_stage(morton_generate, commandBuffer, device, meta, args, states, infoCount, VK_MORTON_GENERATE_FLAGS);
if (result != VK_SUCCESS) {
free(states);
vk_command_buffer_set_error(cmd_buffer, result);
@ -1321,7 +1342,7 @@ vk_cmd_build_acceleration_structures(VkCommandBuffer commandBuffer,
vk_barrier_compute_w_to_compute_r(commandBuffer);
vk_build_stage(morton_sort, commandBuffer, device, meta, args, states, infoCount, 0);
vk_build_stage(morton_sort, commandBuffer, device, meta, args, states, infoCount, VK_MORTON_SORT_FLAGS);
vk_barrier_compute_w_to_compute_r(commandBuffer);

View file

@ -103,6 +103,7 @@ enum vk_internal_build_type {
struct vk_build_config {
enum vk_internal_build_type internal_type;
bool updateable;
bool u64_keys;
uint32_t encode_key[MAX_ENCODE_PASSES];
uint32_t update_key[MAX_ENCODE_PASSES];
};
@ -180,6 +181,7 @@ struct vk_acceleration_structure_build_args {
bool propagate_cull_flags;
bool emit_markers;
const radix_sort_vk_t *radix_sort_64;
const radix_sort_vk_t *radix_sort_96;
};
VkResult vk_get_bvh_build_pipeline_layout(struct vk_device *device, struct vk_meta_device *meta,