diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp index 7bd7c531bef..6b6b2f9e4eb 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp +++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp @@ -562,10 +562,7 @@ fill_varyings(struct d3d12_context *ctx, const nir_shader *s, } const struct glsl_type *type = var->type; - if ((s->info.stage == MESA_SHADER_GEOMETRY || - s->info.stage == MESA_SHADER_TESS_CTRL) && - (modes & nir_var_shader_in) && - glsl_type_is_array(type)) + if (nir_is_arrayed_io(var, s->info.stage)) type = glsl_get_array_element(type); info.slots[slot].types[var->data.location_frac] = type; @@ -729,7 +726,6 @@ validate_tess_ctrl_shader_variant(struct d3d12_selection_context *sel_ctx) if (tcs != NULL && !tcs->is_variant) return; - d3d12_shader_selector *vs = ctx->gfx_stages[PIPE_SHADER_VERTEX]; d3d12_shader_selector *tes = ctx->gfx_stages[PIPE_SHADER_TESS_EVAL]; struct d3d12_tcs_variant_key key = {0}; @@ -737,11 +733,12 @@ validate_tess_ctrl_shader_variant(struct d3d12_selection_context *sel_ctx) /* Fill the variant key */ if (variant_needed) { - if (vs->initial_output_vars == nullptr) { - vs->initial_output_vars = fill_varyings(sel_ctx->ctx, vs->initial, nir_var_shader_out, - vs->initial->info.outputs_written, false); + if (tes->initial_input_vars == nullptr) { + tes->initial_input_vars = fill_varyings(sel_ctx->ctx, tes->initial, nir_var_shader_in, + tes->initial->info.inputs_read & ~(VARYING_BIT_TESS_LEVEL_INNER | VARYING_BIT_TESS_LEVEL_OUTER), + false); } - key.varyings = vs->initial_output_vars; + key.varyings = tes->initial_input_vars; key.vertices_out = ctx->patch_vertices; } @@ -1519,6 +1516,7 @@ d3d12_create_shader_impl(struct d3d12_context *ctx, /* Keep this initial shader as the blue print for possible variants */ sel->initial = nir; sel->initial_output_vars = nullptr; + sel->initial_input_vars = nullptr; sel->gs_key.varyings = nullptr; sel->tcs_key.varyings = nullptr; diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.h b/src/gallium/drivers/d3d12/d3d12_compiler.h index 9e6cd4a9b1a..4ab2a5a2ee4 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.h +++ b/src/gallium/drivers/d3d12/d3d12_compiler.h @@ -269,6 +269,7 @@ struct d3d12_shader_selector { enum pipe_shader_type stage; const nir_shader *initial; struct d3d12_varying_info *initial_output_vars; + struct d3d12_varying_info *initial_input_vars; struct d3d12_shader *first; struct d3d12_shader *current;