microsoft: switch to new subgroup size info

Reviewed-by: Jesse Natalie <jenatali@microsoft.com>
Acked-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37258>
This commit is contained in:
Georg Lehmann 2025-09-09 19:28:02 +02:00 committed by Marge Bot
parent 8d7b1498cc
commit 89adefec64
2 changed files with 15 additions and 13 deletions

View file

@ -1800,7 +1800,7 @@ static const struct dxil_mdnode *
emit_wave_size(struct ntd_context *ctx)
{
const nir_shader *s = ctx->shader;
const struct dxil_mdnode *wave_size_node = dxil_get_metadata_int32(&ctx->mod, s->info.subgroup_size);
const struct dxil_mdnode *wave_size_node = dxil_get_metadata_int32(&ctx->mod, s->info.min_subgroup_size);
return dxil_get_metadata_node(&ctx->mod, &wave_size_node, 1);
}
@ -1809,7 +1809,7 @@ emit_wave_size_range(struct ntd_context *ctx)
{
const nir_shader *s = ctx->shader;
const struct dxil_mdnode *wave_size_nodes[3];
wave_size_nodes[0] = dxil_get_metadata_int32(&ctx->mod, s->info.subgroup_size);
wave_size_nodes[0] = dxil_get_metadata_int32(&ctx->mod, s->info.min_subgroup_size);
wave_size_nodes[1] = wave_size_nodes[0];
wave_size_nodes[2] = wave_size_nodes[0];
return dxil_get_metadata_node(&ctx->mod, wave_size_nodes, ARRAY_SIZE(wave_size_nodes));
@ -2040,7 +2040,9 @@ emit_metadata(struct ntd_context *ctx)
if (!emit_tag(ctx, DXIL_SHADER_TAG_NUM_THREADS, emit_threads(ctx)))
return false;
if (ctx->mod.minor_version >= 6 &&
ctx->shader->info.subgroup_size >= SUBGROUP_SIZE_REQUIRE_4) {
ctx->shader->info.min_subgroup_size == ctx->shader->info.max_subgroup_size &&
ctx->shader->info.min_subgroup_size == ctx->shader->info.api_subgroup_size &&
ctx->shader->info.min_subgroup_size > 1) {
if (ctx->mod.minor_version < 8) {
if (!emit_tag(ctx, DXIL_SHADER_TAG_WAVE_SIZE, emit_wave_size(ctx)))
return false;
@ -6370,12 +6372,7 @@ void dxil_fill_validation_state(struct ntd_context *ctx,
struct dxil_psv_runtime_info_2 *psv2 = &psv3->psv2;
struct dxil_psv_runtime_info_1 *psv1 = &psv2->psv1;
struct dxil_psv_runtime_info_0 *psv0 = &psv1->psv0;
if (ctx->shader->info.subgroup_size >= SUBGROUP_SIZE_REQUIRE_4) {
psv0->max_expected_wave_lane_count = ctx->shader->info.subgroup_size;
psv0->min_expected_wave_lane_count = ctx->shader->info.subgroup_size;
} else {
psv0->max_expected_wave_lane_count = UINT_MAX;
}
psv0->max_expected_wave_lane_count = UINT_MAX;
psv1->shader_stage = (uint8_t)ctx->mod.shader_kind;
psv1->uses_view_id = (uint8_t)ctx->mod.feats.view_id;
psv1->sig_input_elements = (uint8_t)ctx->mod.num_sig_inputs;
@ -6395,6 +6392,12 @@ void dxil_fill_validation_state(struct ntd_context *ctx,
psv2->num_threads_x = MAX2(ctx->shader->info.workgroup_size[0], 1);
psv2->num_threads_y = MAX2(ctx->shader->info.workgroup_size[1], 1);
psv2->num_threads_z = MAX2(ctx->shader->info.workgroup_size[2], 1);
if (ctx->shader->info.min_subgroup_size == ctx->shader->info.max_subgroup_size &&
ctx->shader->info.min_subgroup_size == ctx->shader->info.api_subgroup_size &&
ctx->shader->info.min_subgroup_size > 1) {
psv0->max_expected_wave_lane_count = ctx->shader->info.min_subgroup_size;
psv0->min_expected_wave_lane_count = ctx->shader->info.min_subgroup_size;
}
break;
case DXIL_GEOMETRY_SHADER:
psv1->max_vertex_count = ctx->shader->info.gs.vertices_out;

View file

@ -858,14 +858,13 @@ dzn_graphics_pipeline_compile_shaders(struct dzn_device *device,
_mesa_sha1_update(&pipeline_hash_ctx, &pipeline->use_gs_for_polygon_mode_point, sizeof(pipeline->use_gs_for_polygon_mode_point));
u_foreach_bit(stage, active_stage_mask) {
const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *subgroup_size =
const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *subgroup_size_info =
(const VkPipelineShaderStageRequiredSubgroupSizeCreateInfo *)
vk_find_struct_const(stages[stage].info->pNext, PIPELINE_SHADER_STAGE_REQUIRED_SUBGROUP_SIZE_CREATE_INFO);
enum gl_subgroup_size subgroup_enum = subgroup_size && subgroup_size->requiredSubgroupSize >= 8 ?
subgroup_size->requiredSubgroupSize : SUBGROUP_SIZE_FULL_SUBGROUPS;
uint8_t subgroup_size = subgroup_size_info ? subgroup_size_info->requiredSubgroupSize : 0;
vk_pipeline_hash_shader_stage(pipeline->base.flags, stages[stage].info, NULL, stages[stage].spirv_hash);
_mesa_sha1_update(&pipeline_hash_ctx, &subgroup_enum, sizeof(subgroup_enum));
_mesa_sha1_update(&pipeline_hash_ctx, &subgroup_size, sizeof(subgroup_size));
_mesa_sha1_update(&pipeline_hash_ctx, stages[stage].spirv_hash, sizeof(stages[stage].spirv_hash));
_mesa_sha1_update(&pipeline_hash_ctx, layout->stages[stage].hash, sizeof(layout->stages[stage].hash));
}