diff --git a/src/compiler/shader_info.h b/src/compiler/shader_info.h index 92d65f92133..237d7d6d82b 100644 --- a/src/compiler/shader_info.h +++ b/src/compiler/shader_info.h @@ -56,6 +56,7 @@ struct spirv_supported_capabilities { bool device_group; bool draw_parameters; bool float_controls; + bool float_controls2; /* passed to spv_check_supported() when a module declares SpvCapabilityFloatControls2 */ bool float16_atomic_add; bool float16_atomic_min_max; bool float16; diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index f2c9ac300d9..bb8829af1a3 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -4902,6 +4902,10 @@ vtn_handle_preamble_instruction(struct vtn_builder *b, SpvOp opcode, spv_check_supported(float_controls, cap); break; + case SpvCapabilityFloatControls2: + spv_check_supported(float_controls2, cap); + break; + case SpvCapabilityPhysicalStorageBufferAddresses: spv_check_supported(physical_storage_buffer_address, cap); break; @@ -5532,6 +5536,7 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point, case SpvExecutionModeLocalSizeId: case SpvExecutionModeLocalSizeHintId: case SpvExecutionModeSubgroupsPerWorkgroupId: + case SpvExecutionModeFPFastMathDefault: /* the actual handling lives in vtn_handle_execution_mode_id() */ case SpvExecutionModeMaxNodeRecursionAMDX: case SpvExecutionModeStaticNumWorkgroupsAMDX: case SpvExecutionModeMaxNumWorkgroupsAMDX: @@ -5646,6 +5651,44 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin b->shader->info.num_subgroups = vtn_constant_uint(b, mode->operands[0]); break; + case SpvExecutionModeFPFastMathDefault: { /* operands[0] is looked up as a type, operands[1] as a constant holding the fast-math flag mask */ + struct vtn_type *type = vtn_get_type(b, mode->operands[0]); + SpvFPFastMathModeMask flags = vtn_constant_uint(b, mode->operands[1]); + + SpvFPFastMathModeMask can_fast_math = + SpvFPFastMathModeAllowRecipMask | + SpvFPFastMathModeAllowContractMask | + SpvFPFastMathModeAllowReassocMask | + SpvFPFastMathModeAllowTransformMask; + if ((flags & can_fast_math) != can_fast_math) + b->exact = true; /* reached when at least one of the four Allow* bits is missing; NOTE(review): set on the builder regardless of which type this default targets - confirm intended */ + + unsigned execution_mode = 0; /* collects FLOAT_CONTROLS_*_PRESERVE_* bits for the bit size of the target type only */ + if (!(flags & 
SpvFPFastMathModeNotNaNMask)) { + switch (glsl_get_bit_size(type->type)) { /* a cleared flag means the property must be preserved; NOTE(review): bit sizes other than 16/32/64 add no bits and are not diagnosed - confirm intended */ + case 16: execution_mode |= FLOAT_CONTROLS_NAN_PRESERVE_FP16; break; + case 32: execution_mode |= FLOAT_CONTROLS_NAN_PRESERVE_FP32; break; + case 64: execution_mode |= FLOAT_CONTROLS_NAN_PRESERVE_FP64; break; + } + } + if (!(flags & SpvFPFastMathModeNotInfMask)) { + switch (glsl_get_bit_size(type->type)) { + case 16: execution_mode |= FLOAT_CONTROLS_INF_PRESERVE_FP16; break; + case 32: execution_mode |= FLOAT_CONTROLS_INF_PRESERVE_FP32; break; + case 64: execution_mode |= FLOAT_CONTROLS_INF_PRESERVE_FP64; break; + } + } + if (!(flags & SpvFPFastMathModeNSZMask)) { + switch (glsl_get_bit_size(type->type)) { + case 16: execution_mode |= FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16; break; + case 32: execution_mode |= FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32; break; + case 64: execution_mode |= FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64; break; + } + } + b->shader->info.float_controls_execution_mode |= execution_mode; /* OR-ed in, so bits that were already set are kept */ + break; + } + case SpvExecutionModeMaxNodeRecursionAMDX: vtn_assert(b->shader->info.stage == MESA_SHADER_COMPUTE); break; diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index 0f039626d7c..13e357952b7 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -402,6 +402,43 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b, } } +static void +handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val, + UNUSED int member, const struct vtn_decoration *dec, + UNUSED void *_void) +{ /* Decoration callback (passed to vtn_foreach_decoration): applies an FPFastMathMode decoration to b->nb.exact and b->nb.fp_fast_math. val, member and _void are unused. */ + vtn_assert(dec->scope == VTN_DEC_DECORATION); + if (dec->decoration != SpvDecorationFPFastMathMode) + return; /* every other decoration is ignored by this callback */ + + SpvFPFastMathModeMask can_fast_math = + SpvFPFastMathModeAllowRecipMask | + SpvFPFastMathModeAllowContractMask | + SpvFPFastMathModeAllowReassocMask | + SpvFPFastMathModeAllowTransformMask; + + if ((dec->operands[0] & can_fast_math) != can_fast_math) + b->nb.exact = true; /* not all four Allow* bits are set; this callback only ever sets exact, it never clears it */ + + /* Decoration overrides defaults */ + b->nb.fp_fast_math 
= 0; + if (!(dec->operands[0] & SpvFPFastMathModeNSZMask)) + b->nb.fp_fast_math |= + FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP16 | + FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP32 | + FLOAT_CONTROLS_SIGNED_ZERO_PRESERVE_FP64; /* all three widths are set together; the bit size of the decorated value is not consulted here */ + if (!(dec->operands[0] & SpvFPFastMathModeNotNaNMask)) + b->nb.fp_fast_math |= + FLOAT_CONTROLS_NAN_PRESERVE_FP16 | + FLOAT_CONTROLS_NAN_PRESERVE_FP32 | + FLOAT_CONTROLS_NAN_PRESERVE_FP64; + if (!(dec->operands[0] & SpvFPFastMathModeNotInfMask)) + b->nb.fp_fast_math |= + FLOAT_CONTROLS_INF_PRESERVE_FP16 | + FLOAT_CONTROLS_INF_PRESERVE_FP32 | + FLOAT_CONTROLS_INF_PRESERVE_FP64; +} + void vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val) { @@ -418,6 +455,7 @@ vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val) "enum float_controls and fp_fast_math out of sync!"); b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS2_BITS; + vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL); /* called after the execution-mode defaults are loaded above; presumably invokes handle_fp_fast_math for each decoration on val, which replaces those defaults when an FPFastMathMode decoration is present */ #undef FLOAT_CONTROLS2_BITS }