spirv: Implement SPV_AMDX_shader_enqueue

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24512>
This commit is contained in:
Konstantin Seurer 2023-08-01 14:37:28 +02:00 committed by Marge Bot
parent 289df72d10
commit 2489f7d84f
4 changed files with 107 additions and 0 deletions

View file

@ -122,6 +122,9 @@ struct spirv_to_nir_options {
* but continue executing other tests.
*/
bool skip_os_break_in_debug_build;
/* Shader index provided by VkPipelineShaderStageNodeCreateInfoAMDX */
uint32_t shader_index;
};
enum spirv_verify_result {

View file

@ -4980,6 +4980,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode,
case SpvCapabilityFragmentBarycentricKHR:
spv_check_supported(fragment_barycentric, cap);
break;
case SpvCapabilityShaderEnqueueAMDX:
spv_check_supported(shader_enqueue, cap);
break;
default:
vtn_fail("Unhandled capability: %s (%u)",
@ -5428,6 +5432,10 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
case SpvExecutionModeLocalSizeId:
case SpvExecutionModeLocalSizeHintId:
case SpvExecutionModeSubgroupsPerWorkgroupId:
case SpvExecutionModeMaxNodeRecursionAMDX:
case SpvExecutionModeStaticNumWorkgroupsAMDX:
case SpvExecutionModeMaxNumWorkgroupsAMDX:
case SpvExecutionModeShaderIndexAMDX:
/* Handled later by vtn_handle_execution_mode_id(). */
break;
@ -5483,6 +5491,13 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
b->shader->info.fs.stencil_back_layout = FRAG_STENCIL_LAYOUT_UNCHANGED;
break;
case SpvExecutionModeCoalescingAMDX:
vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE);
b->shader->info.cs.workgroup_count[0] = 1;
b->shader->info.cs.workgroup_count[1] = 1;
b->shader->info.cs.workgroup_count[2] = 1;
break;
default:
vtn_fail("Unhandled execution mode: %s (%u)",
spirv_executionmode_to_string(mode->exec_mode),
@ -5521,6 +5536,29 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin
b->shader->info.num_subgroups = vtn_constant_uint(b, mode->operands[0]);
break;
case SpvExecutionModeMaxNodeRecursionAMDX:
vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE);
break;
case SpvExecutionModeStaticNumWorkgroupsAMDX:
vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE);
b->shader->info.cs.workgroup_count[0] = vtn_constant_uint(b, mode->operands[0]);
b->shader->info.cs.workgroup_count[1] = vtn_constant_uint(b, mode->operands[1]);
b->shader->info.cs.workgroup_count[2] = vtn_constant_uint(b, mode->operands[2]);
assert(b->shader->info.cs.workgroup_count[0]);
assert(b->shader->info.cs.workgroup_count[1]);
assert(b->shader->info.cs.workgroup_count[2]);
break;
case SpvExecutionModeMaxNumWorkgroupsAMDX:
vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE);
break;
case SpvExecutionModeShaderIndexAMDX:
vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE);
b->shader->info.cs.shader_index = vtn_constant_uint(b, mode->operands[0]);
break;
default:
/* Nothing to do. Literal execution modes already handled by
* vtn_handle_execution_mode(). */
@ -6028,6 +6066,20 @@ vtn_handle_ray_query_intrinsic(struct vtn_builder *b, SpvOp opcode,
}
}
static void
vtn_handle_initialize_node_payloads(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
{
vtn_assert(opcode == SpvOpInitializeNodePayloadsAMDX);
nir_def *payloads = vtn_ssa_value(b, w[1])->def;
mesa_scope scope = vtn_translate_scope(b, vtn_constant_uint(b, w[2]));
nir_def *payload_count = vtn_ssa_value(b, w[3])->def;
nir_def *node_index = vtn_ssa_value(b, w[4])->def;
nir_initialize_node_payloads(&b->nb, payloads, payload_count, node_index, .execution_scope = scope);
}
static bool
vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count)
@ -6494,6 +6546,16 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
&b->nb, vtn_get_nir_ssa(b, w[1]), vtn_get_nir_ssa(b, w[2]));
break;
case SpvOpInitializeNodePayloadsAMDX:
vtn_handle_initialize_node_payloads(b, opcode, w, count);
break;
case SpvOpFinalizeNodePayloadsAMDX:
break;
case SpvOpFinishWritingNodePayloadAMDX:
break;
default:
vtn_fail_with_opcode("Unhandled opcode", opcode);
}
@ -6733,6 +6795,7 @@ spirv_to_nir(const uint32_t *words, size_t word_count,
b->shader = nir_shader_create(b, stage, nir_options, NULL);
b->shader->info.subgroup_size = options->subgroup_size;
b->shader->info.float_controls_execution_mode = options->float_controls_execution_mode;
b->shader->info.cs.shader_index = options->shader_index;
_mesa_sha1_compute(words, word_count * sizeof(uint32_t), b->shader->info.source_sha1);
/* Skip the SPIR-V header, handled at vtn_create_builder */

View file

@ -455,6 +455,7 @@ enum vtn_variable_mode {
vtn_variable_mode_ray_payload_in,
vtn_variable_mode_hit_attrib,
vtn_variable_mode_shader_record,
vtn_variable_mode_node_payload,
};
struct vtn_pointer {

View file

@ -183,6 +183,7 @@ vtn_mode_is_cross_invocation(struct vtn_builder *b,
mode == vtn_variable_mode_push_constant ||
mode == vtn_variable_mode_workgroup ||
mode == vtn_variable_mode_cross_workgroup ||
mode == vtn_variable_mode_node_payload ||
(cross_invocation_outputs && mode == vtn_variable_mode_output) ||
(b->shader->info.stage == MESA_SHADER_TASK && mode == vtn_variable_mode_task_payload);
}
@ -1198,6 +1199,14 @@ vtn_get_builtin_location(struct vtn_builder *b,
*location = SYSTEM_VALUE_BARYCENTRIC_LINEAR_COORD;
set_mode_system_value(b, mode);
break;
case SpvBuiltInShaderIndexAMDX:
*location = SYSTEM_VALUE_SHADER_INDEX;
set_mode_system_value(b, mode);
break;
case SpvBuiltInCoalescedInputCountAMDX:
*location = SYSTEM_VALUE_COALESCED_INPUT_COUNT;
set_mode_system_value(b, mode);
break;
default:
vtn_fail("Unsupported builtin: %s (%u)",
@ -1395,6 +1404,27 @@ apply_var_decoration(struct vtn_builder *b,
var_data->per_vertex = true;
break;
case SpvDecorationNodeMaxPayloadsAMDX:
vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE,
"NodeMaxPayloadsAMDX decoration only allowed in compute shaders");
break;
case SpvDecorationNodeSharesPayloadLimitsWithAMDX:
vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE,
"NodeMaxPayloadsAMDX decoration only allowed in compute shaders");
break;
case SpvDecorationPayloadNodeNameAMDX:
vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE,
"NodeMaxPayloadsAMDX decoration only allowed in compute shaders");
var_data->node_name = vtn_string_literal(b, dec->operands, dec->num_operands, NULL);
break;
case SpvDecorationTrackFinishWritingAMDX:
vtn_fail_if(b->shader->info.stage != MESA_SHADER_COMPUTE,
"NodeMaxPayloadsAMDX decoration only allowed in compute shaders");
break;
default:
vtn_fail_with_decoration("Unhandled decoration", dec->decoration);
}
@ -1680,6 +1710,14 @@ vtn_storage_class_to_mode(struct vtn_builder *b,
mode = vtn_variable_mode_shader_record;
nir_mode = nir_var_mem_constant;
break;
case SpvStorageClassNodePayloadAMDX:
mode = vtn_variable_mode_node_payload;
nir_mode = nir_var_mem_node_payload_in;
break;
case SpvStorageClassNodeOutputPayloadAMDX:
mode = vtn_variable_mode_node_payload;
nir_mode = nir_var_mem_node_payload;
break;
case SpvStorageClassGeneric:
mode = vtn_variable_mode_generic;
@ -1724,6 +1762,7 @@ vtn_mode_to_address_format(struct vtn_builder *b, enum vtn_variable_mode mode)
return b->options->constant_addr_format;
case vtn_variable_mode_accel_struct:
case vtn_variable_mode_node_payload:
return nir_address_format_64bit_global;
case vtn_variable_mode_task_payload:
@ -1985,6 +2024,7 @@ vtn_create_variable(struct vtn_builder *b, struct vtn_value *val,
case vtn_variable_mode_ray_payload:
case vtn_variable_mode_ray_payload_in:
case vtn_variable_mode_hit_attrib:
case vtn_variable_mode_node_payload:
/* For these, we create the variable normally */
var->var = rzalloc(b->shader, nir_variable);
var->var->name = ralloc_strdup(var->var, val->name);