diff --git a/src/compiler/spirv/nir_spirv.h b/src/compiler/spirv/nir_spirv.h index 87e9e67c54e..85eeccc2e2c 100644 --- a/src/compiler/spirv/nir_spirv.h +++ b/src/compiler/spirv/nir_spirv.h @@ -122,6 +122,9 @@ struct spirv_to_nir_options { * but continue executing other tests. */ bool skip_os_break_in_debug_build; + + /* Shader index provided by VkPipelineShaderStageNodeCreateInfoAMDX */ + uint32_t shader_index; }; enum spirv_verify_result { diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index c72b31d03fa..ae96d4e907e 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4980,6 +4980,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, case SpvCapabilityFragmentBarycentricKHR: spv_check_supported(fragment_barycentric, cap); break; + + case SpvCapabilityShaderEnqueueAMDX: + spv_check_supported(shader_enqueue, cap); + break; default: vtn_fail("Unhandled capability: %s (%u)", @@ -5428,6 +5432,10 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point, case SpvExecutionModeLocalSizeId: case SpvExecutionModeLocalSizeHintId: case SpvExecutionModeSubgroupsPerWorkgroupId: + case SpvExecutionModeMaxNodeRecursionAMDX: + case SpvExecutionModeStaticNumWorkgroupsAMDX: + case SpvExecutionModeMaxNumWorkgroupsAMDX: + case SpvExecutionModeShaderIndexAMDX: /* Handled later by vtn_handle_execution_mode_id(). */ break; @@ -5483,6 +5491,13 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point, b->shader->info.fs.stencil_back_layout = FRAG_STENCIL_LAYOUT_UNCHANGED; break; + case SpvExecutionModeCoalescingAMDX: + vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE); + b->shader->info.cs.workgroup_count[0] = 1; + b->shader->info.cs.workgroup_count[1] = 1; + b->shader->info.cs.workgroup_count[2] = 1; + break; + default: vtn_fail("Unhandled execution mode: %s (%u)", spirv_executionmode_to_string(mode->exec_mode), @@ -5521,6 +5536,29 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin b->shader->info.num_subgroups = vtn_constant_uint(b, mode->operands[0]); break; + case SpvExecutionModeMaxNodeRecursionAMDX: + vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE); + break; + + case SpvExecutionModeStaticNumWorkgroupsAMDX: + vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE); + b->shader->info.cs.workgroup_count[0] = vtn_constant_uint(b, mode->operands[0]); + b->shader->info.cs.workgroup_count[1] = vtn_constant_uint(b, mode->operands[1]); + b->shader->info.cs.workgroup_count[2] = vtn_constant_uint(b, mode->operands[2]); + assert(b->shader->info.cs.workgroup_count[0]); + assert(b->shader->info.cs.workgroup_count[1]); + assert(b->shader->info.cs.workgroup_count[2]); + break; + + case SpvExecutionModeMaxNumWorkgroupsAMDX: + vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE); + break; + + case SpvExecutionModeShaderIndexAMDX: + vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE); + b->shader->info.cs.shader_index = vtn_constant_uint(b, mode->operands[0]); + break; + default: /* Nothing to do. Literal execution modes already handled by * vtn_handle_execution_mode(). */ @@ -6028,6 +6066,20 @@ vtn_handle_ray_query_intrinsic(struct vtn_builder *b, SpvOp opcode, } } +static void +vtn_handle_initialize_node_payloads(struct vtn_builder *b, SpvOp opcode, + const uint32_t *w, unsigned count) +{ + vtn_assert(opcode == SpvOpInitializeNodePayloadsAMDX); + + nir_def *payloads = vtn_ssa_value(b, w[1])->def; + mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[2])); + nir_def *payload_count = vtn_ssa_value(b, w[3])->def; + nir_def *node_index = vtn_ssa_value(b, w[4])->def; + + nir_initialize_node_payloads(&b->nb, payloads, payload_count, node_index, .execution_scope = scope); +} + static bool vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, const uint32_t *w, unsigned count) @@ -6494,6 +6546,16 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode, &b->nb, vtn_get_nir_ssa(b, w[1]), vtn_get_nir_ssa(b, w[2])); break; + case SpvOpInitializeNodePayloadsAMDX: + vtn_handle_initialize_node_payloads(b, opcode, w, count); + break; + + case SpvOpFinalizeNodePayloadsAMDX: + break; + + case SpvOpFinishWritingNodePayloadAMDX: + break; + default: vtn_fail_with_opcode("Unhandled opcode", opcode); } @@ -6733,6 +6795,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count, b->shader = nir_shader_create(b, stage, nir_options, NULL); b->shader->info.subgroup_size = options->subgroup_size; b->shader->info.float_controls_execution_mode = options->float_controls_execution_mode; + b->shader->info.cs.shader_index = options->shader_index; _mesa_sha1_compute(words, word_count * sizeof(uint32_t), b->shader->info.source_sha1); /* Skip the SPIR-V header, handled at vtn_create_builder */ diff --git a/src/compiler/spirv/vtn_private.h b/src/compiler/spirv/vtn_private.h index db43c9e32f3..830c154a9ef 100644 --- a/src/compiler/spirv/vtn_private.h +++ b/src/compiler/spirv/vtn_private.h @@ -455,6 +455,7 @@ enum vtn_variable_mode { vtn_variable_mode_ray_payload_in, vtn_variable_mode_hit_attrib, vtn_variable_mode_shader_record, + vtn_variable_mode_node_payload, }; struct vtn_pointer { diff --git a/src/compiler/spirv/vtn_variables.c b/src/compiler/spirv/vtn_variables.c index a18d5a1db08..8fb59568da4 100644 --- a/src/compiler/spirv/vtn_variables.c +++ b/src/compiler/spirv/vtn_variables.c @@ -183,6 +183,7 @@ vtn_mode_is_cross_invocation(struct vtn_builder *b, mode == vtn_variable_mode_push_constant || mode == vtn_variable_mode_workgroup || mode == vtn_variable_mode_cross_workgroup || + mode == vtn_variable_mode_node_payload || (cross_invocation_outputs && mode == vtn_variable_mode_output) || (b->shader->info.stage == MESA_SHADER_TASK && mode == vtn_variable_mode_task_payload); } @@ -1198,6 +1199,14 @@ vtn_get_builtin_location(struct vtn_builder *b, *location = SYSTEM_VALUE_BARYCENTRIC_LINEAR_COORD; set_mode_system_value(b, mode); break; + case SpvBuiltInShaderIndexAMDX: + *location = SYSTEM_VALUE_SHADER_INDEX; + set_mode_system_value(b, mode); + break; + case SpvBuiltInCoalescedInputCountAMDX: + *location = SYSTEM_VALUE_COALESCED_INPUT_COUNT; + set_mode_system_value(b, mode); + break; default: vtn_fail("Unsupported builtin: %s (%u)", @@ -1395,6 +1404,27 @@ apply_var_decoration(struct vtn_builder *b, var_data->per_vertex = true; break; + case SpvDecorationNodeMaxPayloadsAMDX: + vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE, + "NodeMaxPayloadsAMDX decoration only allowed in compute shaders"); + break; + + case SpvDecorationNodeSharesPayloadLimitsWithAMDX: + vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE, + "NodeMaxPayloadsAMDX decoration only allowed in compute shaders"); + break; + + case SpvDecorationPayloadNodeNameAMDX: + vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE, + "NodeMaxPayloadsAMDX decoration only allowed in compute shaders"); + var_data->node_name = vtn_string_literal(b, dec->operands, dec->num_operands, NULL); + break; + + case SpvDecorationTrackFinishWritingAMDX: + vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE, + "NodeMaxPayloadsAMDX decoration only allowed in compute shaders"); + break; + default: vtn_fail_with_decoration("Unhandled decoration", dec->decoration); } @@ -1680,6 +1710,14 @@ vtn_storage_class_to_mode(struct vtn_builder *b, mode = vtn_variable_mode_shader_record; nir_mode = nir_var_mem_constant; break; + case SpvStorageClassNodePayloadAMDX: + mode = vtn_variable_mode_node_payload; + nir_mode = nir_var_mem_node_payload_in; + break; + case SpvStorageClassNodeOutputPayloadAMDX: + mode = vtn_variable_mode_node_payload; + nir_mode = nir_var_mem_node_payload; + break; case SpvStorageClassGeneric: mode = vtn_variable_mode_generic; @@ -1724,6 +1762,7 @@ vtn_mode_to_address_format(struct vtn_builder *b, enum vtn_variable_mode mode) return b->options->constant_addr_format; case vtn_variable_mode_accel_struct: + case vtn_variable_mode_node_payload: return nir_address_format_64bit_global; case vtn_variable_mode_task_payload: @@ -1985,6 +2024,7 @@ vtn_create_variable(struct vtn_builder *b, struct vtn_value *val, case vtn_variable_mode_ray_payload: case vtn_variable_mode_ray_payload_in: case vtn_variable_mode_hit_attrib: + case vtn_variable_mode_node_payload: /* For these, we create the variable normally */ var->var = rzalloc(b->shader, nir_variable); var->var->name = ralloc_strdup(var->var, val->name);