diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index 972fed1c893..fd830b0bcbc 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -6778,6 +6778,27 @@ vtn_create_builder(const uint32_t *words, size_t word_count, return NULL; } +/* See glsl_type_add_to_function_params and vtn_ssa_value_add_to_call_params */ +static void +vtn_emit_kernel_entry_point_wrapper_struct_param(struct nir_builder *b, + nir_deref_instr *deref, + nir_call_instr *call, + unsigned *idx) +{ + if (glsl_type_is_vector_or_scalar(deref->type)) { + call->params[(*idx)++] = nir_src_for_ssa(nir_load_deref(b, deref)); + } else { + unsigned elems = glsl_get_length(deref->type); + for (unsigned i = 0; i < elems; i++) { + nir_deref_instr *child_deref = glsl_type_is_struct(deref->type) + ? nir_build_deref_struct(b, deref, i) + : nir_build_deref_array_imm(b, deref, i); + vtn_emit_kernel_entry_point_wrapper_struct_param(b, child_deref, call, + idx); + } + } +} + static nir_function * vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, nir_function *entry_point) @@ -6796,7 +6817,8 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, nir_call_instr *call = nir_call_instr_create(b->nb.shader, entry_point); - for (unsigned i = 0; i < entry_point->num_params; ++i) { + unsigned call_idx = 0; + for (unsigned i = 0; i < b->entry_point->func->type->length; ++i) { struct vtn_type *param_type = b->entry_point->func->type->params[i]; b->shader->info.cs.has_variable_shared_mem |= @@ -6837,17 +6859,30 @@ vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b, nir_variable *copy_var = nir_local_variable_create(impl, in_var->type, "copy_in"); nir_copy_var(&b->nb, copy_var, in_var); - call->params[i] = + call->params[call_idx++] = nir_src_for_ssa(&nir_build_deref_var(&b->nb, copy_var)->def); } else if (param_type->base_type == vtn_base_type_image || param_type->base_type == vtn_base_type_sampler) { /* Don't load the var, just pass a deref of it */ - call->params[i] = nir_src_for_ssa(&nir_build_deref_var(&b->nb, in_var)->def); + call->params[call_idx++] = + nir_src_for_ssa(&nir_build_deref_var(&b->nb, in_var)->def); + } else if (param_type->base_type == vtn_base_type_struct) { + /* We decompose struct and array parameters in vtn, so we'll need to + * handle it here explicitly. + * We have to keep the arguments on the actual entry point intact, + * because the runtimes rely on it to match the SPIR-V. + */ + nir_deref_instr *deref = nir_build_deref_var(&b->nb, in_var); + vtn_emit_kernel_entry_point_wrapper_struct_param(&b->nb, deref, call, + &call_idx); } else { - call->params[i] = nir_src_for_ssa(nir_load_var(&b->nb, in_var)); + call->params[call_idx++] = + nir_src_for_ssa(nir_load_var(&b->nb, in_var)); } } + assert(call_idx == entry_point->num_params); + nir_builder_instr_insert(&b->nb, &call->instr); return main_entry_point;