radv: Clean up the accel-struct build shaders

Signed-off-by: Konstantin Seurer <konstantin.seurer@gmail.com>
Reviewed-by: Bas Nieuwenhuizen <bas@basnieuwenhuizen.nl>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/15648>
This commit is contained in:
Konstantin Seurer 2022-04-02 14:47:37 +02:00
parent be57b085be
commit e4a6f09d12

View file

@ -1064,7 +1064,7 @@ id_to_node_id_offset(nir_builder *b, nir_ssa_def *global_id,
uint32_t stride = get_node_id_stride(
get_accel_struct_build(pdevice, VK_ACCELERATION_STRUCTURE_BUILD_TYPE_DEVICE_KHR));
return nir_umul24(b, global_id, nir_imm_int(b, stride));
return nir_imul_imm(b, global_id, stride);
}
static nir_ssa_def *
@ -1077,16 +1077,14 @@ id_to_morton_offset(nir_builder *b, nir_ssa_def *global_id,
uint32_t stride = get_node_id_stride(build_mode);
return nir_iadd_imm(b, nir_umul24(b, global_id, nir_imm_int(b, stride)), sizeof(uint32_t));
return nir_iadd_imm(b, nir_imul_imm(b, global_id, stride), sizeof(uint32_t));
}
static nir_shader *
build_leaf_shader(struct radv_device *dev)
{
const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
nir_builder b = radv_meta_init_shader(dev, MESA_SHADER_COMPUTE, "accel_build_leaf_shader");
b.shader->info.workgroup_size[0] = 64;
nir_builder b = create_accel_build_shader(dev, "accel_build_leaf_shader");
nir_ssa_def *pconst0 =
nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
@ -1096,20 +1094,20 @@ build_leaf_shader(struct radv_device *dev)
nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 32, .range = 16);
nir_ssa_def *pconst3 =
nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 48, .range = 16);
nir_ssa_def *pconst4 =
nir_ssa_def *index_format =
nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 64, .range = 4);
nir_ssa_def *geom_type = nir_channel(&b, pconst1, 2);
nir_ssa_def *node_dst_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 12));
nir_ssa_def *node_dst_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
nir_ssa_def *scratch_offset = nir_channel(&b, pconst1, 1);
nir_ssa_def *geom_type = nir_channel(&b, pconst1, 2);
nir_ssa_def *geometry_id = nir_channel(&b, pconst1, 3);
nir_ssa_def *global_id =
nir_iadd(&b,
nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
nir_imm_int(&b, b.shader->info.workgroup_size[0])),
nir_imul_imm(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
b.shader->info.workgroup_size[0]),
nir_channels(&b, nir_load_local_invocation_id(&b), 1));
nir_ssa_def *scratch_dst_addr =
nir_iadd(&b, scratch_addr,
@ -1123,16 +1121,14 @@ build_leaf_shader(struct radv_device *dev)
nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_TRIANGLES_KHR));
{ /* Triangles */
nir_ssa_def *vertex_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3));
nir_ssa_def *index_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 12));
nir_ssa_def *vertex_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011));
nir_ssa_def *index_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b1100));
nir_ssa_def *transform_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst3, 3));
nir_ssa_def *vertex_stride = nir_channel(&b, pconst3, 2);
nir_ssa_def *vertex_format = nir_channel(&b, pconst3, 3);
nir_ssa_def *index_format = nir_channel(&b, pconst4, 0);
unsigned repl_swizzle[4] = {0, 0, 0, 0};
nir_ssa_def *node_offset =
nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 64));
nir_ssa_def *triangle_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
nir_ssa_def *indices = get_indices(&b, index_addr, index_format, global_id);
@ -1201,11 +1197,10 @@ build_leaf_shader(struct radv_device *dev)
nir_push_else(&b, NULL);
nir_push_if(&b, nir_ieq_imm(&b, geom_type, VK_GEOMETRY_TYPE_AABBS_KHR));
{ /* AABBs */
nir_ssa_def *aabb_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3));
nir_ssa_def *aabb_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011));
nir_ssa_def *aabb_stride = nir_channel(&b, pconst2, 2);
nir_ssa_def *node_offset =
nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 64)));
nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 64));
nir_ssa_def *aabb_node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), 7);
@ -1240,8 +1235,8 @@ build_leaf_shader(struct radv_device *dev)
nir_variable_create(b.shader, nir_var_shader_temp, glsl_uint64_t_type(), "instance_addr");
nir_push_if(&b, nir_ine_imm(&b, nir_channel(&b, pconst2, 2), 0));
{
nir_ssa_def *ptr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 3)),
nir_u2u64(&b, nir_imul_imm(&b, global_id, 8)));
nir_ssa_def *ptr = nir_iadd(&b, nir_pack_64_2x32(&b, nir_channels(&b, pconst2, 0b0011)),
nir_u2u64(&b, nir_imul(&b, global_id, nir_imm_int(&b, 8))));
nir_ssa_def *addr =
nir_pack_64_2x32(&b, nir_build_load_global(&b, 2, 32, ptr, .align_mul = 8));
nir_store_var(&b, instance_addr_var, addr, 1);
@ -1261,8 +1256,7 @@ build_leaf_shader(struct radv_device *dev)
nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 32))};
nir_ssa_def *inst3 = nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_addr, 48));
nir_ssa_def *node_offset =
nir_iadd(&b, node_dst_offset, nir_umul24(&b, global_id, nir_imm_int(&b, 128)));
nir_ssa_def *node_offset = nir_iadd(&b, node_dst_offset, nir_imul_imm(&b, global_id, 128));
node_dst_addr = nir_iadd(&b, node_dst_addr, nir_u2u64(&b, node_offset));
nir_ssa_def *node_id = nir_iadd_imm(&b, nir_ushr_imm(&b, node_offset, 3), 6);
@ -1507,9 +1501,7 @@ static nir_shader *
build_internal_shader(struct radv_device *dev)
{
const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3);
nir_builder b = radv_meta_init_shader(dev, MESA_SHADER_COMPUTE, "accel_build_internal_shader");
b.shader->info.workgroup_size[0] = 64;
nir_builder b = create_accel_build_shader(dev, "accel_build_internal_shader");
/*
* push constants:
@ -1525,8 +1517,8 @@ build_internal_shader(struct radv_device *dev)
nir_ssa_def *pconst1 =
nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 16, .range = 16);
nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 12));
nir_ssa_def *node_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
nir_ssa_def *scratch_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
nir_ssa_def *node_dst_offset = nir_channel(&b, pconst1, 0);
nir_ssa_def *dst_scratch_offset = nir_channel(&b, pconst1, 1);
nir_ssa_def *src_scratch_offset = nir_channel(&b, pconst1, 2);
@ -1536,8 +1528,8 @@ build_internal_shader(struct radv_device *dev)
nir_ssa_def *global_id =
nir_iadd(&b,
nir_umul24(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
nir_imm_int(&b, b.shader->info.workgroup_size[0])),
nir_imul_imm(&b, nir_channels(&b, nir_load_workgroup_id(&b, 32), 1),
b.shader->info.workgroup_size[0]),
nir_channels(&b, nir_load_local_invocation_id(&b), 1));
nir_ssa_def *src_idx = nir_imul_imm(&b, global_id, 4);
nir_ssa_def *src_count = nir_umin(&b, nir_imm_int(&b, 4), nir_isub(&b, src_node_count, src_idx));
@ -1615,8 +1607,7 @@ struct copy_constants {
static nir_shader *
build_copy_shader(struct radv_device *dev)
{
nir_builder b = radv_meta_init_shader(dev, MESA_SHADER_COMPUTE, "accel_copy");
b.shader->info.workgroup_size[0] = 64;
nir_builder b = create_accel_build_shader(dev, "accel_copy");
nir_ssa_def *invoc_id = nir_load_local_invocation_id(&b);
nir_ssa_def *wg_id = nir_load_workgroup_id(&b, 32);
@ -1639,8 +1630,8 @@ build_copy_shader(struct radv_device *dev)
nir_load_push_constant(&b, 4, 32, nir_imm_int(&b, 0), .base = 0, .range = 16);
nir_ssa_def *pconst1 =
nir_load_push_constant(&b, 1, 32, nir_imm_int(&b, 0), .base = 16, .range = 4);
nir_ssa_def *src_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 3));
nir_ssa_def *dst_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0xc));
nir_ssa_def *src_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b0011));
nir_ssa_def *dst_base_addr = nir_pack_64_2x32(&b, nir_channels(&b, pconst0, 0b1100));
nir_ssa_def *mode = nir_channel(&b, pconst1, 0);
nir_variable *compacted_size_var =