diff --git a/src/gallium/drivers/d3d12/d3d12_compiler.cpp b/src/gallium/drivers/d3d12/d3d12_compiler.cpp index 377a38bc623..77d6af84489 100644 --- a/src/gallium/drivers/d3d12/d3d12_compiler.cpp +++ b/src/gallium/drivers/d3d12/d3d12_compiler.cpp @@ -131,40 +131,6 @@ compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel, shader->has_default_ubo0 = num_uniforms_before_lower_to_ubo > 0 && nir->info.num_ubos > num_ubos_before_lower_to_ubo; - NIR_PASS_V(nir, dxil_nir_lower_subgroup_id); - NIR_PASS_V(nir, dxil_nir_lower_num_subgroups); - - nir_lower_subgroups_options subgroup_options = {}; - subgroup_options.ballot_bit_size = 32; - subgroup_options.ballot_components = 4; - subgroup_options.lower_subgroup_masks = true; - subgroup_options.lower_to_scalar = true; - subgroup_options.lower_relative_shuffle = true; - subgroup_options.lower_inverse_ballot = true; - if (nir->info.stage != MESA_SHADER_FRAGMENT && nir->info.stage != MESA_SHADER_COMPUTE) - subgroup_options.lower_quad = true; - NIR_PASS_V(nir, nir_lower_subgroups, &subgroup_options); - NIR_PASS_V(nir, nir_lower_bit_size, [](const nir_instr *instr, void *) -> unsigned { - if (instr->type != nir_instr_type_intrinsic) - return 0; - nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); - switch (intr->intrinsic) { - case nir_intrinsic_quad_swap_horizontal: - case nir_intrinsic_quad_swap_vertical: - case nir_intrinsic_quad_swap_diagonal: - case nir_intrinsic_reduce: - case nir_intrinsic_inclusive_scan: - case nir_intrinsic_exclusive_scan: - return intr->def.bit_size == 1 ? 32 : 0; - default: - return 0; - } - }, NULL); - - // Ensure subgroup scans on bools are gone - NIR_PASS_V(nir, nir_opt_dce); - NIR_PASS_V(nir, dxil_nir_lower_unsupported_subgroup_scan); - if (key->last_vertex_processing_stage) { if (key->invert_depth) NIR_PASS_V(nir, d3d12_nir_invert_depth, key->invert_depth, key->halfz); @@ -172,12 +138,11 @@ compile_nir(struct d3d12_context *ctx, struct d3d12_shader_selector *sel, NIR_PASS_V(nir, nir_lower_clip_halfz); NIR_PASS_V(nir, d3d12_lower_yflip); } - NIR_PASS_V(nir, d3d12_lower_load_draw_params); - NIR_PASS_V(nir, d3d12_lower_load_patch_vertices_in); + NIR_PASS_V(nir, d3d12_lower_state_vars, shader); + const struct dxil_nir_lower_loads_stores_options loads_stores_options = {}; NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil, &loads_stores_options); - NIR_PASS_V(nir, dxil_nir_lower_double_math); if (key->stage == PIPE_SHADER_FRAGMENT && key->fs.multisample_disabled) NIR_PASS_V(nir, d3d12_disable_multisampling); @@ -1520,6 +1485,44 @@ d3d12_create_shader_impl(struct d3d12_context *ctx, * lower integer cube maps to be handled like 2D textures arrays*/ NIR_PASS_V(nir, dxil_nir_lower_int_cubemaps, true); + NIR_PASS_V(nir, dxil_nir_lower_subgroup_id); + NIR_PASS_V(nir, dxil_nir_lower_num_subgroups); + + nir_lower_subgroups_options subgroup_options = {}; + subgroup_options.ballot_bit_size = 32; + subgroup_options.ballot_components = 4; + subgroup_options.lower_subgroup_masks = true; + subgroup_options.lower_to_scalar = true; + subgroup_options.lower_relative_shuffle = true; + subgroup_options.lower_inverse_ballot = true; + if (nir->info.stage != MESA_SHADER_FRAGMENT && nir->info.stage != MESA_SHADER_COMPUTE) + subgroup_options.lower_quad = true; + NIR_PASS_V(nir, nir_lower_subgroups, &subgroup_options); + NIR_PASS_V(nir, nir_lower_bit_size, [](const nir_instr *instr, void *) -> unsigned { + if (instr->type != nir_instr_type_intrinsic) + return 0; + nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr); + switch (intr->intrinsic) { + case nir_intrinsic_quad_swap_horizontal: + case nir_intrinsic_quad_swap_vertical: + case nir_intrinsic_quad_swap_diagonal: + case nir_intrinsic_reduce: + case nir_intrinsic_inclusive_scan: + case nir_intrinsic_exclusive_scan: + return intr->def.bit_size == 1 ? 32 : 0; + default: + return 0; + } + }, NULL); + + // Ensure subgroup scans on bools are gone + NIR_PASS_V(nir, nir_opt_dce); + NIR_PASS_V(nir, dxil_nir_lower_unsupported_subgroup_scan); + + NIR_PASS_V(nir, d3d12_lower_load_draw_params); + NIR_PASS_V(nir, d3d12_lower_load_patch_vertices_in); + NIR_PASS_V(nir, dxil_nir_lower_double_math); + /* Keep this initial shader as the blue print for possible variants */ sel->initial = nir; sel->initial_output_vars = nullptr; diff --git a/src/gallium/drivers/d3d12/d3d12_draw.cpp b/src/gallium/drivers/d3d12/d3d12_draw.cpp index 866f9f81a86..12d034ee821 100644 --- a/src/gallium/drivers/d3d12/d3d12_draw.cpp +++ b/src/gallium/drivers/d3d12/d3d12_draw.cpp @@ -1292,7 +1292,7 @@ update_dispatch_indirect_with_sysvals(struct d3d12_context *ctx, ctx->compute_state == nullptr) return false; - if (!BITSET_TEST(ctx->compute_state->current->nir->info.system_values_read, SYSTEM_VALUE_NUM_WORKGROUPS)) + if (!BITSET_TEST(ctx->compute_state->initial->info.system_values_read, SYSTEM_VALUE_NUM_WORKGROUPS)) return false; if (ctx->current_predication)