diff --git a/src/amd/vulkan/radv_rra_gfx12.c b/src/amd/vulkan/radv_rra_gfx12.c index 2349ae382e7..d0029d8be42 100644 --- a/src/amd/vulkan/radv_rra_gfx12.c +++ b/src/amd/vulkan/radv_rra_gfx12.c @@ -38,6 +38,24 @@ static const char *node_type_names[16] = { [15] = "invalid15", }; +static uint32_t +get_geometry_id(const void *node, uint32_t triangle_index) +{ + uint32_t geometry_index_base_bits = BITSET_EXTRACT(node, 20, 4) * 2; + uint32_t geometry_index_bits = BITSET_EXTRACT(node, 24, 4) * 2; + + uint32_t indices_midpoint = BITSET_EXTRACT(node, 42, 10); + uint32_t geometry_id_base = + BITSET_EXTRACT(node, indices_midpoint - geometry_index_base_bits, geometry_index_base_bits); + + if (triangle_index == 0) + return geometry_id_base; + + return (geometry_id_base & ~BITFIELD64_MASK(geometry_index_bits)) | + BITSET_EXTRACT(node, indices_midpoint - geometry_index_base_bits - geometry_index_bits * triangle_index, + geometry_index_bits); +} + bool rra_validate_node_gfx12(struct hash_table_u64 *accel_struct_vas, uint8_t *data, void *node, uint32_t geometry_count, uint32_t size, bool is_bottom_level, uint32_t depth) @@ -98,13 +116,25 @@ rra_validate_node_gfx12(struct hash_table_u64 *accel_struct_vas, uint8_t *data, rra_validation_fail(&child_ctx, "Invalid blas_addr(0x%llx)", (unsigned long long)blas_va); } else { uint32_t indices_midpoint = BITSET_EXTRACT(child_node, 42, 10); - if (indices_midpoint < 54 + 28) { + if (indices_midpoint < 54) { rra_validation_fail(&child_ctx, "Invalid indices_midpoint(%u)", indices_midpoint); } else { - uint32_t geometry_id = BITSET_EXTRACT(child_node, indices_midpoint - 28, 28); - if (geometry_id >= geometry_count) { - rra_validation_fail(&child_ctx, "Invalid geometry_id(%u) >= geometry_count(%u)", geometry_id, - geometry_count); + uint32_t pair_index = (child_type & 0x3) | ((child_type & 0x8) >> 1); + + if (BITSET_EXTRACT(node, 1024 - 29 * (pair_index + 1) + 17, 12)) { + uint32_t geometry_id = get_geometry_id(node, pair_index * 2 + 0); + if (geometry_id >= geometry_count) { + rra_validation_fail(&child_ctx, "Invalid geometry_id(%u) >= geometry_count(%u)", geometry_id, + geometry_count); + } + } + + if (BITSET_EXTRACT(node, 1024 - 29 * (pair_index + 1) + 3, 12)) { + uint32_t geometry_id = get_geometry_id(node, pair_index * 2 + 1); + if (geometry_id >= geometry_count) { + rra_validation_fail(&child_ctx, "Invalid geometry_id(%u) >= geometry_count(%u)", geometry_id, + geometry_count); + } } } if (!BITSET_TEST((BITSET_WORD *)child_node, 1024 - 29)) @@ -115,24 +145,6 @@ rra_validate_node_gfx12(struct hash_table_u64 *accel_struct_vas, uint8_t *data, return ctx.failed; } -static uint32_t -get_geometry_id(const void *node, uint32_t triangle_index) -{ - uint32_t geometry_index_base_bits = BITSET_EXTRACT(node, 20, 4) * 2; - uint32_t geometry_index_bits = BITSET_EXTRACT(node, 24, 4) * 2; - - uint32_t indices_midpoint = BITSET_EXTRACT(node, 42, 10); - uint32_t geometry_id_base = - BITSET_EXTRACT(node, indices_midpoint - geometry_index_base_bits, geometry_index_base_bits); - - if (triangle_index == 0) - return geometry_id_base; - - return (geometry_id_base & ~BITFIELD64_MASK(geometry_index_bits)) | - BITSET_EXTRACT(node, indices_midpoint - geometry_index_base_bits - geometry_index_bits * triangle_index, - geometry_index_bits); -} - void rra_gather_bvh_info_gfx12(const uint8_t *bvh, uint32_t node_id, struct rra_bvh_info *dst) {