diff --git a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c index 35230b7e8f5..5600571e3c5 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c +++ b/src/amd/vulkan/nir/radv_nir_lower_ray_queries.c @@ -144,7 +144,7 @@ struct ray_query_vars { }; static void -init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_vars *dst, const char *base_name, +init_ray_query_vars(nir_shader *shader, const glsl_type *opaque_type, struct ray_query_vars *dst, const char *base_name, uint32_t max_shared_size) { memset(dst, 0, sizeof(*dst)); @@ -154,7 +154,7 @@ init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_ uint32_t shared_stack_entries = shader->info.ray_queries == 1 ? 16 : 8; uint32_t shared_stack_size = workgroup_size * shared_stack_entries * 4; uint32_t shared_offset = align(shader->info.shared_size, 4); - if (shader->info.stage != MESA_SHADER_COMPUTE || array_length > 1 || + if (shader->info.stage != MESA_SHADER_COMPUTE || glsl_type_is_array(opaque_type) || shared_offset + shared_stack_size > max_shared_size) { dst->stack_entries = MAX_SCRATCH_STACK_ENTRY_COUNT; } else { @@ -165,10 +165,7 @@ init_ray_query_vars(nir_shader *shader, unsigned array_length, struct ray_query_ shader->info.shared_size = shared_offset + shared_stack_size; } - const glsl_type *type = radv_get_ray_query_type(); - if (array_length > 1) - type = glsl_array_type(type, array_length, 0); - + const glsl_type *type = glsl_type_wrap_in_arrays(radv_get_ray_query_type(), opaque_type); dst->var = nir_variable_create(shader, nir_var_shader_temp, type, base_name); } @@ -177,11 +174,7 @@ lower_ray_query(nir_shader *shader, nir_variable *ray_query, struct hash_table * { struct ray_query_vars *vars = ralloc(ht, struct ray_query_vars); - unsigned array_length = 1; - if (glsl_type_is_array(ray_query->type)) - array_length = glsl_get_length(ray_query->type); - - init_ray_query_vars(shader, array_length, vars, ray_query->name == NULL ? "" : ray_query->name, max_shared_size); + init_ray_query_vars(shader, ray_query->type, vars, ray_query->name == NULL ? "" : ray_query->name, max_shared_size); _mesa_hash_table_insert(ht, ray_query, vars); } @@ -562,6 +555,16 @@ lower_rq_terminate(nir_builder *b, nir_intrinsic_instr *instr, nir_deref_instr * rq_store(b, rq, incomplete, nir_imm_false(b)); } +static nir_deref_instr * +radv_lower_opaque_ray_query_deref(nir_builder *b, nir_deref_instr *opaque_deref, nir_variable *var) +{ + if (opaque_deref->deref_type != nir_deref_type_array) + return nir_build_deref_var(b, var); + + nir_deref_instr *outer_deref = radv_lower_opaque_ray_query_deref(b, nir_deref_instr_parent(opaque_deref), var); + return nir_build_deref_array(b, outer_deref, opaque_deref->arr.index.ssa); +} + bool radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device) { @@ -609,9 +612,7 @@ radv_nir_lower_ray_queries(struct nir_shader *shader, struct radv_device *device builder.cursor = nir_before_instr(instr); - nir_deref_instr *rq = nir_build_deref_var(&builder, vars->var); - if (ray_query_deref->deref_type == nir_deref_type_array) - rq = nir_build_deref_array(&builder, rq, ray_query_deref->arr.index.ssa); + nir_deref_instr *rq = radv_lower_opaque_ray_query_deref(&builder, ray_query_deref, vars->var); nir_def *new_dest = NULL;