diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c
index 79f4651a60d..eeb481e30b8 100644
--- a/src/amd/vulkan/radv_pipeline_rt.c
+++ b/src/amd/vulkan/radv_pipeline_rt.c
@@ -1422,150 +1422,157 @@ insert_traversal_aabb_case(struct radv_device *device,
    nir_pop_if(b, NULL);
 }
 
-static void
-insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
-                 nir_builder *b, const struct rt_variables *vars)
+static nir_shader *
+build_traversal_shader(struct radv_device *device,
+                       const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                       const struct rt_variables *dst_vars,
+                       struct hash_table *var_remap)
 {
+   nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_traversal");
+   b.shader->info.internal = false;
+   b.shader->info.workgroup_size[0] = 8;
+   b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
+   struct rt_variables vars = create_rt_variables(b.shader, dst_vars->stack_sizes);
+   map_rt_variables(var_remap, &vars, dst_vars);
+
    unsigned stack_entry_size = 4;
-   unsigned lanes = b->shader->info.workgroup_size[0] * b->shader->info.workgroup_size[1] *
-                    b->shader->info.workgroup_size[2];
+   unsigned lanes = device->physical_device->rt_wave_size;
    unsigned stack_entry_stride = stack_entry_size * lanes;
-   nir_ssa_def *stack_entry_stride_def = nir_imm_int(b, stack_entry_stride);
+   nir_ssa_def *stack_entry_stride_def = nir_imm_int(&b, stack_entry_stride);
    nir_ssa_def *stack_base =
-      nir_iadd_imm(b, nir_imul_imm(b, nir_load_local_invocation_index(b), stack_entry_size),
-                   b->shader->info.shared_size);
+      nir_iadd_imm(&b, nir_imul_imm(&b, nir_load_local_invocation_index(&b), stack_entry_size),
+                   b.shader->info.shared_size);
 
-   b->shader->info.shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
-   assert(b->shader->info.shared_size <= 32768);
+   b.shader->info.shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
 
-   nir_ssa_def *accel_struct = nir_load_var(b, vars->accel_struct);
+   nir_ssa_def *accel_struct = nir_load_var(&b, vars.accel_struct);
 
-   struct rt_traversal_vars trav_vars = init_traversal_vars(b);
+   struct rt_traversal_vars trav_vars = init_traversal_vars(&b);
 
    /* Initialize the follow-up shader idx to 0, to be replaced by the miss shader
    * if we actually miss. */
-   nir_store_var(b, vars->idx, nir_imm_int(b, 0), 1);
+   nir_store_var(&b, vars.idx, nir_imm_int(&b, 0), 1);
 
-   nir_store_var(b, trav_vars.should_return, nir_imm_bool(b, false), 1);
+   nir_store_var(&b, trav_vars.should_return, nir_imm_bool(&b, false), 1);
 
-   nir_push_if(b, nir_ine_imm(b, accel_struct, 0));
+   nir_push_if(&b, nir_ine_imm(&b, accel_struct, 0));
    {
-      nir_store_var(b, trav_vars.bvh_base, build_addr_to_node(b, accel_struct), 1);
+      nir_store_var(&b, trav_vars.bvh_base, build_addr_to_node(&b, accel_struct), 1);
 
       nir_ssa_def *bvh_root = nir_build_load_global(
-         b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64);
+         &b, 1, 32, accel_struct, .access = ACCESS_NON_WRITEABLE, .align_mul = 64);
 
-      nir_ssa_def *desc = create_bvh_descriptor(b);
-      nir_ssa_def *vec3ones = nir_channels(b, nir_imm_vec4(b, 1.0, 1.0, 1.0, 1.0), 0x7);
+      nir_ssa_def *desc = create_bvh_descriptor(&b);
+      nir_ssa_def *vec3ones = nir_channels(&b, nir_imm_vec4(&b, 1.0, 1.0, 1.0, 1.0), 0x7);
 
-      nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
-      nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
-      nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
-      nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_imm_int(b, 0), 1);
-      nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
+      nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
+      nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
+      nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
+      nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_imm_int(&b, 0), 1);
+      nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
 
-      nir_store_var(b, trav_vars.stack, nir_iadd(b, stack_base, stack_entry_stride_def), 1);
-      nir_store_shared(b, bvh_root, stack_base, .base = 0, .align_mul = stack_entry_size);
+      nir_store_var(&b, trav_vars.stack, nir_iadd(&b, stack_base, stack_entry_stride_def), 1);
+      nir_store_shared(&b, bvh_root, stack_base, .base = 0, .align_mul = stack_entry_size);
 
-      nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
+      nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
 
-      nir_push_loop(b);
+      nir_push_loop(&b);
 
-      nir_push_if(b, nir_ieq(b, nir_load_var(b, trav_vars.stack), stack_base));
-      nir_jump(b, nir_jump_break);
-      nir_pop_if(b, NULL);
+      nir_push_if(&b, nir_ieq(&b, nir_load_var(&b, trav_vars.stack), stack_base));
+      nir_jump(&b, nir_jump_break);
+      nir_pop_if(&b, NULL);
 
       nir_push_if(
-         b, nir_uge(b, nir_load_var(b, trav_vars.top_stack), nir_load_var(b, trav_vars.stack)));
-      nir_store_var(b, trav_vars.top_stack, nir_imm_int(b, 0), 1);
-      nir_store_var(b, trav_vars.bvh_base,
-                    build_addr_to_node(b, nir_load_var(b, vars->accel_struct)), 1);
-      nir_store_var(b, trav_vars.origin, nir_load_var(b, vars->origin), 7);
-      nir_store_var(b, trav_vars.dir, nir_load_var(b, vars->direction), 7);
-      nir_store_var(b, trav_vars.inv_dir, nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
-      nir_store_var(b, trav_vars.instance_addr, nir_imm_int64(b, 0), 1);
+         &b, nir_uge(&b, nir_load_var(&b, trav_vars.top_stack), nir_load_var(&b, trav_vars.stack)));
+      nir_store_var(&b, trav_vars.top_stack, nir_imm_int(&b, 0), 1);
+      nir_store_var(&b, trav_vars.bvh_base,
+                    build_addr_to_node(&b, nir_load_var(&b, vars.accel_struct)), 1);
+      nir_store_var(&b, trav_vars.origin, nir_load_var(&b, vars.origin), 7);
+      nir_store_var(&b, trav_vars.dir, nir_load_var(&b, vars.direction), 7);
+      nir_store_var(&b, trav_vars.inv_dir, nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
+      nir_store_var(&b, trav_vars.instance_addr, nir_imm_int64(&b, 0), 1);
 
-      nir_pop_if(b, NULL);
+      nir_pop_if(&b, NULL);
 
-      nir_store_var(b, trav_vars.stack,
-                    nir_isub(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
+      nir_store_var(&b, trav_vars.stack,
+                    nir_isub(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
 
-      nir_ssa_def *bvh_node = nir_load_shared(b, 1, 32, nir_load_var(b, trav_vars.stack), .base = 0,
+      nir_ssa_def *bvh_node = nir_load_shared(&b, 1, 32, nir_load_var(&b, trav_vars.stack), .base = 0,
                                               .align_mul = stack_entry_size);
-      nir_ssa_def *bvh_node_type = nir_iand_imm(b, bvh_node, 7);
+      nir_ssa_def *bvh_node_type = nir_iand_imm(&b, bvh_node, 7);
 
-      bvh_node = nir_iadd(b, nir_load_var(b, trav_vars.bvh_base), nir_u2u(b, bvh_node, 64));
+      bvh_node = nir_iadd(&b, nir_load_var(&b, trav_vars.bvh_base), nir_u2u(&b, bvh_node, 64));
       nir_ssa_def *intrinsic_result = NULL;
       if (!radv_emulate_rt(device->physical_device)) {
         intrinsic_result = nir_bvh64_intersect_ray_amd(
-            b, 32, desc, nir_unpack_64_2x32(b, bvh_node), nir_load_var(b, vars->tmax),
-            nir_load_var(b, trav_vars.origin), nir_load_var(b, trav_vars.dir),
-            nir_load_var(b, trav_vars.inv_dir));
+            &b, 32, desc, nir_unpack_64_2x32(&b, bvh_node), nir_load_var(&b, vars.tmax),
+            nir_load_var(&b, trav_vars.origin), nir_load_var(&b, trav_vars.dir),
+            nir_load_var(&b, trav_vars.inv_dir));
       }
 
-      nir_push_if(b, nir_ine_imm(b, nir_iand_imm(b, bvh_node_type, 4), 0));
+      nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 4), 0));
       {
-         nir_push_if(b, nir_ine_imm(b, nir_iand_imm(b, bvh_node_type, 2), 0));
+         nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 2), 0));
         {
            /* custom */
-            nir_push_if(b, nir_ine_imm(b, nir_iand_imm(b, bvh_node_type, 1), 0));
+            nir_push_if(&b, nir_ine_imm(&b, nir_iand_imm(&b, bvh_node_type, 1), 0));
            if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_AABBS_BIT_KHR)) {
-               insert_traversal_aabb_case(device, pCreateInfo, b, vars, &trav_vars, bvh_node);
+               insert_traversal_aabb_case(device, pCreateInfo, &b, &vars, &trav_vars, bvh_node);
            }
-            nir_push_else(b, NULL);
+            nir_push_else(&b, NULL);
            {
              /* instance */
-               nir_ssa_def *instance_node_addr = build_node_to_addr(device, b, bvh_node);
+               nir_ssa_def *instance_node_addr = build_node_to_addr(device, &b, bvh_node);
               nir_ssa_def *instance_data =
-                  nir_build_load_global(b, 4, 32, instance_node_addr, .align_mul = 64);
+                  nir_build_load_global(&b, 4, 32, instance_node_addr, .align_mul = 64);
               nir_ssa_def *wto_matrix[] = {
-                  nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_node_addr, 16),
+                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 16),
                                         .align_mul = 64, .align_offset = 16),
-                  nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_node_addr, 32),
+                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 32),
                                         .align_mul = 64, .align_offset = 32),
-                  nir_build_load_global(b, 4, 32, nir_iadd_imm(b, instance_node_addr, 48),
+                  nir_build_load_global(&b, 4, 32, nir_iadd_imm(&b, instance_node_addr, 48),
                                         .align_mul = 64, .align_offset = 48)};
               nir_ssa_def *instance_id =
-                  nir_build_load_global(b, 1, 32, nir_iadd_imm(b, instance_node_addr, 88));
-               nir_ssa_def *instance_and_mask = nir_channel(b, instance_data, 2);
-               nir_ssa_def *instance_mask = nir_ushr_imm(b, instance_and_mask, 24);
+                  nir_build_load_global(&b, 1, 32, nir_iadd_imm(&b, instance_node_addr, 88));
+               nir_ssa_def *instance_and_mask = nir_channel(&b, instance_data, 2);
+               nir_ssa_def *instance_mask = nir_ushr_imm(&b, instance_and_mask, 24);
 
               nir_push_if(
-                  b,
-                  nir_ieq_imm(b, nir_iand(b, instance_mask, nir_load_var(b, vars->cull_mask)), 0));
-               nir_jump(b, nir_jump_continue);
-               nir_pop_if(b, NULL);
+                  &b,
+                  nir_ieq_imm(&b, nir_iand(&b, instance_mask, nir_load_var(&b, vars.cull_mask)), 0));
+               nir_jump(&b, nir_jump_continue);
+               nir_pop_if(&b, NULL);
 
-               nir_store_var(b, trav_vars.top_stack, nir_load_var(b, trav_vars.stack), 1);
-               nir_store_var(b, trav_vars.bvh_base,
+               nir_store_var(&b, trav_vars.top_stack, nir_load_var(&b, trav_vars.stack), 1);
+               nir_store_var(&b, trav_vars.bvh_base,
                             build_addr_to_node(
-                                b, nir_pack_64_2x32(b, nir_channels(b, instance_data, 0x3))),
+                                &b, nir_pack_64_2x32(&b, nir_channels(&b, instance_data, 0x3))),
                             1);
 
-               nir_store_shared(b, nir_iand_imm(b, nir_channel(b, instance_data, 0), 63),
-                                nir_load_var(b, trav_vars.stack), .base = 0,
+               nir_store_shared(&b, nir_iand_imm(&b, nir_channel(&b, instance_data, 0), 63),
+                                nir_load_var(&b, trav_vars.stack), .base = 0,
                                 .align_mul = stack_entry_size);
-               nir_store_var(b, trav_vars.stack,
-                             nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def),
+               nir_store_var(&b, trav_vars.stack,
+                             nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def),
                             1);
 
               nir_store_var(
-                  b, trav_vars.origin,
-                  nir_build_vec3_mat_mult_pre(b, nir_load_var(b, vars->origin), wto_matrix), 7);
+                  &b, trav_vars.origin,
+                  nir_build_vec3_mat_mult_pre(&b, nir_load_var(&b, vars.origin), wto_matrix), 7);
               nir_store_var(
-                  b, trav_vars.dir,
-                  nir_build_vec3_mat_mult(b, nir_load_var(b, vars->direction), wto_matrix, false),
+                  &b, trav_vars.dir,
+                  nir_build_vec3_mat_mult(&b, nir_load_var(&b, vars.direction), wto_matrix, false),
                  7);
-               nir_store_var(b, trav_vars.inv_dir,
-                             nir_fdiv(b, vec3ones, nir_load_var(b, trav_vars.dir)), 7);
-               nir_store_var(b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
-               nir_store_var(b, trav_vars.sbt_offset_and_flags, nir_channel(b, instance_data, 3),
+               nir_store_var(&b, trav_vars.inv_dir,
+                             nir_fdiv(&b, vec3ones, nir_load_var(&b, trav_vars.dir)), 7);
+               nir_store_var(&b, trav_vars.custom_instance_and_mask, instance_and_mask, 1);
+               nir_store_var(&b, trav_vars.sbt_offset_and_flags, nir_channel(&b, instance_data, 3),
                             1);
-               nir_store_var(b, trav_vars.instance_id, instance_id, 1);
-               nir_store_var(b, trav_vars.instance_addr, instance_node_addr, 1);
+               nir_store_var(&b, trav_vars.instance_id, instance_id, 1);
+               nir_store_var(&b, trav_vars.instance_addr, instance_node_addr, 1);
            }
-            nir_pop_if(b, NULL);
+            nir_pop_if(&b, NULL);
        }
-         nir_push_else(b, NULL);
+         nir_push_else(&b, NULL);
        {
           /* box */
           nir_ssa_def *result = intrinsic_result;
@@ -1573,61 +1580,85 @@ insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInf
             /* If we didn't run the intrinsic because the hardware didn't support it,
              * emulate ray/box intersection here */
             result = intersect_ray_amd_software_box(device,
-               b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin),
-               nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir));
+               &b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
+               nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
          }
 
          for (unsigned i = 4; i-- > 0; ) {
-             nir_ssa_def *new_node = nir_channel(b, result, i);
-             nir_push_if(b, nir_ine_imm(b, new_node, 0xffffffff));
+             nir_ssa_def *new_node = nir_channel(&b, result, i);
+             nir_push_if(&b, nir_ine_imm(&b, new_node, 0xffffffff));
             {
-                nir_store_shared(b, new_node, nir_load_var(b, trav_vars.stack), .base = 0,
+                nir_store_shared(&b, new_node, nir_load_var(&b, trav_vars.stack), .base = 0,
                                  .align_mul = stack_entry_size);
                nir_store_var(
-                   b, trav_vars.stack,
-                   nir_iadd(b, nir_load_var(b, trav_vars.stack), stack_entry_stride_def), 1);
+                   &b, trav_vars.stack,
+                   nir_iadd(&b, nir_load_var(&b, trav_vars.stack), stack_entry_stride_def), 1);
             }
-             nir_pop_if(b, NULL);
+             nir_pop_if(&b, NULL);
          }
       }
-      nir_pop_if(b, NULL);
+      nir_pop_if(&b, NULL);
       }
-      nir_push_else(b, NULL);
+      nir_push_else(&b, NULL);
       if (!(pCreateInfo->flags & VK_PIPELINE_CREATE_RAY_TRACING_SKIP_TRIANGLES_BIT_KHR)) {
          nir_ssa_def *result = intrinsic_result;
         if (!result) {
            /* If we didn't run the intrinsic because the hardware didn't support it,
             * emulate ray/tri intersection here */
           result = intersect_ray_amd_software_tri(device,
-             b, bvh_node, nir_load_var(b, vars->tmax), nir_load_var(b, trav_vars.origin),
-             nir_load_var(b, trav_vars.dir), nir_load_var(b, trav_vars.inv_dir));
+             &b, bvh_node, nir_load_var(&b, vars.tmax), nir_load_var(&b, trav_vars.origin),
+             nir_load_var(&b, trav_vars.dir), nir_load_var(&b, trav_vars.inv_dir));
         }
-         insert_traversal_triangle_case(device, pCreateInfo, b, result, vars, &trav_vars, bvh_node);
+         insert_traversal_triangle_case(device, pCreateInfo, &b, result, &vars, &trav_vars, bvh_node);
       }
-      nir_pop_if(b, NULL);
+      nir_pop_if(&b, NULL);
 
-      nir_pop_loop(b, NULL);
+      nir_pop_loop(&b, NULL);
    }
-   nir_pop_if(b, NULL);
+   nir_pop_if(&b, NULL);
 
    /* should_return is set if we had a hit but we won't be calling the closest hit shader and hence
    * need to return immediately to the calling shader. */
-   nir_push_if(b, nir_load_var(b, trav_vars.should_return));
+   nir_push_if(&b, nir_load_var(&b, trav_vars.should_return));
    {
-      insert_rt_return(b, vars);
+      insert_rt_return(&b, &vars);
    }
-   nir_push_else(b, NULL);
+   nir_push_else(&b, NULL);
    {
      /* Only load the miss shader if we actually miss, which we determine by not having set
       * a closest hit shader. It is valid to not specify an SBT pointer for miss shaders if none
       * of the rays miss. */
-      nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 0));
+      nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 0));
      {
-         load_sbt_entry(b, vars, nir_load_var(b, vars->miss_index), SBT_MISS, 0);
+         load_sbt_entry(&b, &vars, nir_load_var(&b, vars.miss_index), SBT_MISS, 0);
      }
-      nir_pop_if(b, NULL);
+      nir_pop_if(&b, NULL);
    }
+   nir_pop_if(&b, NULL);
+
+   return b.shader;
+}
+
+
+static void
+insert_traversal(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
+                 nir_builder *b, const struct rt_variables *vars)
+{
+   struct hash_table *var_remap = _mesa_pointer_hash_table_create(NULL);
+   nir_shader *shader = build_traversal_shader(device, pCreateInfo, vars, var_remap);
+   b->shader->info.shared_size += shader->info.shared_size;
+   assert(b->shader->info.shared_size <= 32768);
+
+   /* For now, just inline the traversal shader */
+   nir_push_if(b, nir_ieq_imm(b, nir_load_var(b, vars->idx), 1));
+   nir_store_var(b, vars->main_loop_case_visited, nir_imm_bool(b, true), 1);
+   nir_inline_function_impl(b, nir_shader_get_entrypoint(shader), NULL, var_remap);
    nir_pop_if(b, NULL);
+
+   /* Adopt the instructions from the source shader, since they are merely moved, not cloned. */
+   ralloc_adopt(ralloc_context(b->shader), ralloc_context(shader));
+
+   ralloc_free(var_remap);
 }
 
 static unsigned
@@ -1770,10 +1801,7 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf
 
    nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, false), 1);
 
-   nir_push_if(&b, nir_ieq_imm(&b, nir_load_var(&b, vars.idx), 1));
-   nir_store_var(&b, vars.main_loop_case_visited, nir_imm_bool(&b, true), 1);
    insert_traversal(device, pCreateInfo, &b, &vars);
-   nir_pop_if(&b, NULL);
 
    nir_ssa_def *idx = nir_load_var(&b, vars.idx);
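
Note on the shared-memory stack layout (commentary, not part of the patch): each lane gets a
private traversal stack in LDS, interleaved so that slot N of every lane is contiguous.
stack_base is shared_size + lane * stack_entry_size, and a lane's next slot lives one
stack_entry_stride = stack_entry_size * lanes further on. Because the traversal shader is built
with a workgroup of exactly one wave (8x8 for wave64, 8x4 for wave32), the local invocation
index equals the lane index, which is why lanes can now come straight from rt_wave_size.
A minimal standalone sketch of the address math follows; WAVE_SIZE and MAX_STACK_ENTRY_COUNT
are placeholder values for illustration only, not the driver's definitions:

#include <assert.h>
#include <stdint.h>

#define WAVE_SIZE 64             /* placeholder for device->physical_device->rt_wave_size */
#define MAX_STACK_ENTRY_COUNT 76 /* placeholder; the real constant lives in radv_pipeline_rt.c */

/* Byte offset in LDS of a lane's stack slot `level`, mirroring
 * stack_base = shared_size + lane * stack_entry_size and the
 * stack_entry_stride increments in the traversal loop above. */
static uint32_t stack_slot_addr(uint32_t shared_base, uint32_t lane, uint32_t level)
{
   const uint32_t stack_entry_size = 4; /* one 32-bit BVH node id */
   const uint32_t stack_entry_stride = stack_entry_size * WAVE_SIZE;
   return shared_base + lane * stack_entry_size + level * stack_entry_stride;
}

int main(void)
{
   /* Adjacent lanes start 4 bytes apart; a lane's consecutive slots are one
    * stride apart, so no two lanes ever alias a slot. */
   assert(stack_slot_addr(0, 1, 0) - stack_slot_addr(0, 0, 0) == 4);
   assert(stack_slot_addr(0, 0, 1) - stack_slot_addr(0, 0, 0) == 4 * WAVE_SIZE);

   /* Total LDS consumed, matching shared_size += stack_entry_stride * MAX_STACK_ENTRY_COUNT;
    * insert_traversal() asserts the merged total stays within the 32 KiB LDS limit. */
   uint32_t lds_bytes = 4 * WAVE_SIZE * MAX_STACK_ENTRY_COUNT;
   assert(lds_bytes <= 32768);
   return 0;
}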