diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index b51b93e7310..c1e23230e89 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -98,6 +98,9 @@ struct rt_variables { /* global address of the SBT entry used for the shader */ nir_variable *shader_record_ptr; + nir_variable *launch_size; + nir_variable *launch_id; + /* trace_ray arguments */ nir_variable *accel_struct; nir_variable *flags; @@ -157,6 +160,10 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR vars.shader_record_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr"); + const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3); + vars.launch_size = nir_variable_create(shader, nir_var_shader_temp, uvec3_type, "launch_size"); + vars.launch_id = nir_variable_create(shader, nir_var_shader_temp, uvec3_type, "launch_id"); + const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); vars.accel_struct = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct"); @@ -204,6 +211,8 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, _mesa_hash_table_insert(var_remap, src->arg, dst->arg); _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr); _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr); + _mesa_hash_table_insert(var_remap, src->launch_size, dst->launch_size); + _mesa_hash_table_insert(var_remap, src->launch_id, dst->launch_id); _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct); _mesa_hash_table_insert(var_remap, src->flags, dst->flags); @@ -422,23 +431,11 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca break; } case nir_intrinsic_load_ray_launch_id: { - ret = nir_load_global_invocation_id(&b_shader, 32); + ret = nir_load_var(&b_shader, vars->launch_id); break; } case nir_intrinsic_load_ray_launch_size: { - nir_ssa_def *launch_size_addr = nir_load_ray_launch_size_addr_amd(&b_shader); - - nir_ssa_def *xy = nir_build_load_smem_amd(&b_shader, 2, launch_size_addr, - nir_imm_int(&b_shader, 0)); - nir_ssa_def *z = nir_build_load_smem_amd(&b_shader, 1, launch_size_addr, - nir_imm_int(&b_shader, 8)); - - nir_ssa_def *xyz[3] = { - nir_channel(&b_shader, xy, 0), - nir_channel(&b_shader, xy, 1), - z, - }; - ret = nir_vec(&b_shader, xyz, 3); + ret = nir_load_var(&b_shader, vars->launch_size); break; } case nir_intrinsic_load_ray_t_min: { @@ -1328,6 +1325,8 @@ build_traversal_shader(struct radv_device *device, nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1); nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1); nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1); + nir_store_var(&b, vars.launch_size, nir_load_ray_launch_size(&b), 0x7); + nir_store_var(&b, vars.launch_id, nir_load_ray_launch_id(&b), 0x7); struct rt_traversal_vars trav_vars = init_traversal_vars(&b); @@ -1590,6 +1589,17 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf else nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1); + nir_store_var(&b, vars.launch_id, nir_load_global_invocation_id(&b, 32), 0x7); + nir_ssa_def *launch_size_addr = nir_load_ray_launch_size_addr_amd(&b); + nir_ssa_def *xy = nir_build_load_smem_amd(&b, 2, launch_size_addr, nir_imm_int(&b, 0)); + nir_ssa_def *z = nir_build_load_smem_amd(&b, 1, launch_size_addr, nir_imm_int(&b, 8)); + nir_ssa_def *xyz[3] = { + nir_channel(&b, xy, 0), + nir_channel(&b, xy, 1), + z, + }; + nir_store_var(&b, vars.launch_size, nir_vec(&b, xyz, 3), 0x7); + nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)]; for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++) hit_attribs[i] = nir_local_variable_create(nir_shader_get_entrypoint(b.shader),