diff --git a/src/amd/vulkan/radv_pipeline_rt.c b/src/amd/vulkan/radv_pipeline_rt.c index 5e0c9ca6fb7..32f936d1e13 100644 --- a/src/amd/vulkan/radv_pipeline_rt.c +++ b/src/amd/vulkan/radv_pipeline_rt.c @@ -611,7 +611,7 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, VkPipelineShaderStageCreateInfo stage = { .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO, .pNext = NULL, - .stage = VK_SHADER_STAGE_COMPUTE_BIT, + .stage = VK_SHADER_STAGE_RAYGEN_BIT_KHR, .module = vk_shader_module_to_handle(&module), .pName = "main", }; @@ -664,13 +664,18 @@ radv_rt_pipeline_create(VkDevice _device, VkPipelineCache _cache, goto shader_fail; } - radv_compute_pipeline_init(&rt_pipeline->base, pipeline_layout); - rt_pipeline->stack_size = compute_rt_stack_size(pCreateInfo, rt_pipeline->groups); rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE] = radv_create_rt_prolog(device); - *pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base); + combine_config(&rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config, + &rt_pipeline->base.base.shaders[MESA_SHADER_RAYGEN]->config); + postprocess_rt_config(&rt_pipeline->base.base.shaders[MESA_SHADER_COMPUTE]->config, + device->physical_device->rt_wave_size); + + radv_compute_pipeline_init(&rt_pipeline->base, pipeline_layout); + + *pPipeline = radv_pipeline_to_handle(&rt_pipeline->base.base); shader_fail: ralloc_free(shader); pipeline_fail: diff --git a/src/amd/vulkan/radv_rt_shader.c b/src/amd/vulkan/radv_rt_shader.c index 9edaa6cac71..03990216fce 100644 --- a/src/amd/vulkan/radv_rt_shader.c +++ b/src/amd/vulkan/radv_rt_shader.c @@ -97,9 +97,6 @@ struct rt_variables { /* global address of the SBT entry used for the shader */ nir_variable *shader_record_ptr; - nir_variable *launch_size; - nir_variable *launch_id; - /* trace_ray arguments */ nir_variable *accel_struct; nir_variable *cull_mask_and_flags; @@ -137,10 +134,6 @@ create_rt_variables(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR vars.shader_record_ptr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_record_ptr"); - const struct glsl_type *uvec3_type = glsl_vector_type(GLSL_TYPE_UINT, 3); - vars.launch_size = nir_variable_create(shader, nir_var_shader_temp, uvec3_type, "launch_size"); - vars.launch_id = nir_variable_create(shader, nir_var_shader_temp, uvec3_type, "launch_id"); - const struct glsl_type *vec3_type = glsl_vector_type(GLSL_TYPE_FLOAT, 3); vars.accel_struct = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "accel_struct"); @@ -187,8 +180,6 @@ map_rt_variables(struct hash_table *var_remap, struct rt_variables *src, _mesa_hash_table_insert(var_remap, src->arg, dst->arg); _mesa_hash_table_insert(var_remap, src->stack_ptr, dst->stack_ptr); _mesa_hash_table_insert(var_remap, src->shader_record_ptr, dst->shader_record_ptr); - _mesa_hash_table_insert(var_remap, src->launch_size, dst->launch_size); - _mesa_hash_table_insert(var_remap, src->launch_id, dst->launch_id); _mesa_hash_table_insert(var_remap, src->accel_struct, dst->accel_struct); _mesa_hash_table_insert(var_remap, src->cull_mask_and_flags, dst->cull_mask_and_flags); @@ -403,14 +394,6 @@ lower_rt_instructions(nir_shader *shader, struct rt_variables *vars, unsigned ca ret = nir_load_var(&b_shader, vars->shader_record_ptr); break; } - case nir_intrinsic_load_ray_launch_id: { - ret = nir_load_var(&b_shader, vars->launch_id); - break; - } - case nir_intrinsic_load_ray_launch_size: { - ret = nir_load_var(&b_shader, vars->launch_size); - break; - } case nir_intrinsic_load_ray_t_min: { ret = nir_load_var(&b_shader, vars->tmin); break; @@ -1429,8 +1412,6 @@ build_traversal_shader(struct radv_device *device, nir_store_var(&b, vars.tmax, nir_load_ray_t_max(&b), 0x1); nir_store_var(&b, vars.arg, nir_load_rt_arg_scratch_offset_amd(&b), 0x1); nir_store_var(&b, vars.stack_ptr, nir_imm_int(&b, 0), 0x1); - nir_store_var(&b, vars.launch_size, nir_load_ray_launch_size(&b), 0x7); - nir_store_var(&b, vars.launch_id, nir_load_ray_launch_id(&b), 0x7); struct rt_traversal_vars trav_vars = init_traversal_vars(&b); @@ -1594,7 +1575,7 @@ nir_shader * create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo, struct radv_ray_tracing_module *groups, const struct radv_pipeline_key *key) { - nir_builder b = radv_meta_init_shader(device, MESA_SHADER_COMPUTE, "rt_combined"); + nir_builder b = radv_meta_init_shader(device, MESA_SHADER_RAYGEN, "rt_combined"); b.shader->info.internal = false; b.shader->info.workgroup_size[0] = 8; b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4; @@ -1604,17 +1585,6 @@ create_rt_shader(struct radv_device *device, const VkRayTracingPipelineCreateInf load_sbt_entry(&b, &vars, nir_imm_int(&b, 0), SBT_RAYGEN, SBT_GENERAL_IDX); nir_store_var(&b, vars.stack_ptr, nir_load_rt_dynamic_callable_stack_base_amd(&b), 0x1); - nir_store_var(&b, vars.launch_id, nir_load_global_invocation_id(&b, 32), 0x7); - nir_ssa_def *launch_size_addr = nir_load_ray_launch_size_addr_amd(&b); - nir_ssa_def *xy = nir_build_load_smem_amd(&b, 2, launch_size_addr, nir_imm_int(&b, 0)); - nir_ssa_def *z = nir_build_load_smem_amd(&b, 1, launch_size_addr, nir_imm_int(&b, 8)); - nir_ssa_def *xyz[3] = { - nir_channel(&b, xy, 0), - nir_channel(&b, xy, 1), - z, - }; - nir_store_var(&b, vars.launch_size, nir_vec(&b, xyz, 3), 0x7); - nir_loop *loop = nir_push_loop(&b); nir_ssa_def *idx = nir_load_var(&b, vars.idx); diff --git a/src/amd/vulkan/radv_shader_args.c b/src/amd/vulkan/radv_shader_args.c index 456dc310e57..5b9ba75e870 100644 --- a/src/amd/vulkan/radv_shader_args.c +++ b/src/amd/vulkan/radv_shader_args.c @@ -638,10 +638,6 @@ radv_declare_shader_args(const struct radv_device *device, const struct radv_pip case MESA_SHADER_TASK: declare_global_input_sgprs(info, &user_sgpr_info, args); - if (info->cs.is_rt_shader) { - ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.sbt_descriptors); - } - if (info->cs.uses_grid_size) { if (args->load_grid_size_from_user_sgpr) ac_add_arg(&args->ac, AC_ARG_SGPR, 3, AC_ARG_INT, &args->ac.num_work_groups); @@ -649,15 +645,13 @@ radv_declare_shader_args(const struct radv_device *device, const struct radv_pip ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.num_work_groups); } - if (info->cs.uses_ray_launch_size) { + if (info->cs.is_rt_shader) { + ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_DESC_PTR, &args->ac.sbt_descriptors); ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.ray_launch_size_addr); - } - - if (info->cs.uses_dynamic_rt_callable_stack) { - ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, - &args->ac.rt_dynamic_callable_stack_base); ac_add_arg(&args->ac, AC_ARG_SGPR, 2, AC_ARG_CONST_PTR, &args->ac.rt_traversal_shader_addr); + ac_add_arg(&args->ac, AC_ARG_SGPR, 1, AC_ARG_INT, + &args->ac.rt_dynamic_callable_stack_base); } if (info->vs.needs_draw_id) { @@ -934,12 +928,12 @@ radv_declare_shader_args(const struct radv_device *device, const struct radv_pip if (args->ac.ray_launch_size_addr.used) { set_loc_shader_ptr(args, AC_UD_CS_RAY_LAUNCH_SIZE_ADDR, &user_sgpr_idx); } - if (args->ac.rt_dynamic_callable_stack_base.used) { - set_loc_shader(args, AC_UD_CS_RAY_DYNAMIC_CALLABLE_STACK_BASE, &user_sgpr_idx, 1); - } if (args->ac.rt_traversal_shader_addr.used) { set_loc_shader_ptr(args, AC_UD_CS_TRAVERSAL_SHADER_ADDR, &user_sgpr_idx); } + if (args->ac.rt_dynamic_callable_stack_base.used) { + set_loc_shader(args, AC_UD_CS_RAY_DYNAMIC_CALLABLE_STACK_BASE, &user_sgpr_idx, 1); + } if (args->ac.draw_id.used) { set_loc_shader(args, AC_UD_CS_TASK_DRAW_ID, &user_sgpr_idx, 1); }