diff --git a/src/microsoft/compiler/nir_to_dxil.c b/src/microsoft/compiler/nir_to_dxil.c index 8d7e797c543..f99cf691205 100644 --- a/src/microsoft/compiler/nir_to_dxil.c +++ b/src/microsoft/compiler/nir_to_dxil.c @@ -1800,7 +1800,7 @@ static const struct dxil_mdnode * emit_wave_size(struct ntd_context *ctx) { const nir_shader *s = ctx->shader; - const struct dxil_mdnode *wave_size_node = dxil_get_metadata_int32(&ctx->mod, s->info.subgroup_size); + const struct dxil_mdnode *wave_size_node = dxil_get_metadata_int32(&ctx->mod, s->info.min_subgroup_size); return dxil_get_metadata_node(&ctx->mod, &wave_size_node, 1); } @@ -1809,7 +1809,7 @@ emit_wave_size_range(struct ntd_context *ctx) { const nir_shader *s = ctx->shader; const struct dxil_mdnode *wave_size_nodes[3]; - wave_size_nodes[0] = dxil_get_metadata_int32(&ctx->mod, s->info.subgroup_size); + wave_size_nodes[0] = dxil_get_metadata_int32(&ctx->mod, s->info.min_subgroup_size); wave_size_nodes[1] = wave_size_nodes[0]; wave_size_nodes[2] = wave_size_nodes[0]; return dxil_get_metadata_node(&ctx->mod, wave_size_nodes, ARRAY_SIZE(wave_size_nodes)); @@ -2040,7 +2040,9 @@ emit_metadata(struct ntd_context *ctx) if (!emit_tag(ctx, DXIL_SHADER_TAG_NUM_THREADS, emit_threads(ctx))) return false; if (ctx->mod.minor_version >= 6 && - ctx->shader->info.subgroup_size >= SUBGROUP_SIZE_REQUIRE_4) { + ctx->shader->info.min_subgroup_size == ctx->shader->info.max_subgroup_size && + ctx->shader->info.min_subgroup_size == ctx->shader->info.api_subgroup_size && + ctx->shader->info.min_subgroup_size > 1) { if (ctx->mod.minor_version < 8) { if (!emit_tag(ctx, DXIL_SHADER_TAG_WAVE_SIZE, emit_wave_size(ctx))) return false; @@ -6370,12 +6372,7 @@ void dxil_fill_validation_state(struct ntd_context *ctx, struct dxil_psv_runtime_info_2 *psv2 = &psv3->psv2; struct dxil_psv_runtime_info_1 *psv1 = &psv2->psv1; struct dxil_psv_runtime_info_0 *psv0 = &psv1->psv0; - if (ctx->shader->info.subgroup_size >= SUBGROUP_SIZE_REQUIRE_4) { - psv0->max_expected_wave_lane_count = ctx->shader->info.subgroup_size; - psv0->min_expected_wave_lane_count = ctx->shader->info.subgroup_size; - } else { - psv0->max_expected_wave_lane_count = UINT_MAX; - } + psv0->max_expected_wave_lane_count = UINT_MAX; psv1->shader_stage = (uint8_t)ctx->mod.shader_kind; psv1->uses_view_id = (uint8_t)ctx->mod.feats.view_id; psv1->sig_input_elements = (uint8_t)ctx->mod.num_sig_inputs; @@ -6395,6 +6392,12 @@ void dxil_fill_validation_state(struct ntd_context *ctx, psv2->num_threads_x = MAX2(ctx->shader->info.workgroup_size[0], 1); psv2->num_threads_y = MAX2(ctx->shader->info.workgroup_size[1], 1); psv2->num_threads_z = MAX2(ctx->shader->info.workgroup_size[2], 1); + if (ctx->shader->info.min_subgroup_size == ctx->shader->info.max_subgroup_size && + ctx->shader->info.min_subgroup_size == ctx->shader->info.api_subgroup_size && + ctx->shader->info.min_subgroup_size > 1) { + psv0->max_expected_wave_lane_count = ctx->shader->info.min_subgroup_size; + psv0->min_expected_wave_lane_count = ctx->shader->info.min_subgroup_size; + } break; case DXIL_GEOMETRY_SHADER: psv1->max_vertex_count = ctx->shader->info.gs.vertices_out; diff --git a/src/microsoft/vulkan/dzn_pipeline.c b/src/microsoft/vulkan/dzn_pipeline.c index 638d3883a67..2316e461b05 100644 --- a/src/microsoft/vulkan/dzn_pipeline.c +++ b/src/microsoft/vulkan/dzn_pipeline.c @@ -858,14 +858,13 @@ dzn_graphics_pipeline_compile_shaders(struct dzn_device *device, _mesa_sha1_update(&pipeline_hash_ctx, &pipeline->use_gs_for_polygon_mode_point, sizeof(pipeline->use_gs_for_polygon_mode_point)); u_foreach_bit(stage, active_stage_mask) { - const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *subgroup_size = + const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *subgroup_size_info = (const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *) vk_find_struct_const(stages[stage].info->pNext, PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO); - enum gl_subgroup_size subgroup_enum = subgroup_size && subgroup_size->requiredSubgroupSize >= 8 ? - subgroup_size->requiredSubgroupSize : SUBGROUP_SIZE_FULL_SUBGROUPS; + uint8_t subgroup_size = subgroup_size_info ? subgroup_size_info->requiredSubgroupSize : 0; vk_pipeline_hash_shader_stage(pipeline->base.flags, stages[stage].info, NULL, stages[stage].spirv_hash); - _mesa_sha1_update(&pipeline_hash_ctx, &subgroup_enum, sizeof(subgroup_enum)); + _mesa_sha1_update(&pipeline_hash_ctx, &subgroup_size, sizeof(subgroup_size)); _mesa_sha1_update(&pipeline_hash_ctx, stages[stage].spirv_hash, sizeof(stages[stage].spirv_hash)); _mesa_sha1_update(&pipeline_hash_ctx, layout->stages[stage].hash, sizeof(layout->stages[stage].hash)); }