diff --git a/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c index d229fe3ba48..38e14dd4015 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c +++ b/src/amd/vulkan/nir/radv_nir_lower_hit_attrib_derefs.c @@ -6,6 +6,7 @@ #include "nir.h" #include "nir_builder.h" +#include "nir_deref.h" #include "radv_constants.h" #include "radv_nir.h" @@ -29,10 +30,12 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) if (!nir_deref_mode_is(deref, args->mode)) return false; - assert(deref->deref_type == nir_deref_type_var); - b->cursor = nir_after_instr(instr); + nir_variable *var = nir_deref_instr_get_variable(deref); + uint32_t location = args->base_offset + var->data.driver_location + + nir_deref_instr_get_const_offset(deref, glsl_get_natural_size_align_bytes); + if (intrin->intrinsic == nir_intrinsic_load_deref) { uint32_t num_components = intrin->def.num_components; uint32_t bit_size = intrin->def.bit_size; @@ -40,7 +43,7 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) nir_def *components[NIR_MAX_VEC_COMPONENTS]; for (uint32_t comp = 0; comp < num_components; comp++) { - uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8); + uint32_t offset = location + comp * DIV_ROUND_UP(bit_size, 8); uint32_t base = offset / 4; uint32_t comp_offset = offset % 4; @@ -68,7 +71,7 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) uint32_t bit_size = value->bit_size; for (uint32_t comp = 0; comp < num_components; comp++) { - uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8); + uint32_t offset = location + comp * DIV_ROUND_UP(bit_size, 8); uint32_t base = offset / 4; uint32_t comp_offset = offset % 4; @@ -102,17 +105,35 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data) return true; } +static bool +radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, void *data) +{ + if (instr->intrinsic != nir_intrinsic_trace_ray) + return false; + + nir_deref_instr *payload = nir_src_as_deref(instr->src[10]); + assert(payload->deref_type == nir_deref_type_var); + + b->cursor = nir_before_instr(&instr->instr); + nir_def *offset = nir_imm_int(b, payload->var->data.driver_location); + + nir_src_rewrite(&instr->src[10], offset); + + return true; +} + static bool radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset) { bool progress = false; - progress |= nir_split_struct_vars(shader, mode); progress |= nir_lower_indirect_derefs(shader, mode, UINT32_MAX); - progress |= nir_split_array_vars(shader, mode); progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes); + if (shader->info.stage == MESA_SHADER_RAYGEN && mode == nir_var_function_temp) + progress |= nir_shader_intrinsics_pass(shader, radv_lower_payload_arg_to_offset, nir_metadata_control_flow, NULL); + struct lower_hit_attrib_deref_args args = { .mode = mode, .base_offset = base_offset, diff --git a/src/amd/vulkan/nir/radv_nir_rt_shader.c b/src/amd/vulkan/nir/radv_nir_rt_shader.c index eccc6acc474..ced8732fee2 100644 --- a/src/amd/vulkan/nir/radv_nir_rt_shader.c +++ b/src/amd/vulkan/nir/radv_nir_rt_shader.c @@ -835,23 +835,6 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni ralloc_free(var_remap); } -static bool -radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, void *data) -{ - if (instr->intrinsic != nir_intrinsic_trace_ray) - return false; - - nir_deref_instr *payload = nir_src_as_deref(instr->src[10]); - assert(payload->deref_type == nir_deref_type_var); - - b->cursor = nir_before_instr(&instr->instr); - nir_def *offset = nir_imm_int(b, payload->var->data.driver_location); - - nir_src_rewrite(&instr->src[10], offset); - - return true; -} - void radv_nir_lower_rt_io(nir_shader *nir, bool monolithic, uint32_t payload_offset) { @@ -863,17 +846,6 @@ radv_nir_lower_rt_io(nir_shader *nir, bool monolithic, uint32_t payload_offset) NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset); } else { - if (nir->info.stage == MESA_SHADER_RAYGEN) { - /* Use nir_lower_vars_to_explicit_types to assign the payload locations. We call - * nir_lower_vars_to_explicit_types later after splitting the payloads. - */ - uint32_t scratch_size = nir->scratch_size; - nir_lower_vars_to_explicit_types(nir, nir_var_function_temp, glsl_get_natural_size_align_bytes); - nir->scratch_size = scratch_size; - - nir_shader_intrinsics_pass(nir, radv_lower_payload_arg_to_offset, nir_metadata_control_flow, NULL); - } - NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, payload_offset); } }