spirv: use base type instead of bit size to determine fp_math_ctrl

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39460>
This commit is contained in:
Georg Lehmann 2026-01-22 17:59:30 +01:00 committed by Marge Bot
parent 565f37b98c
commit 46a617884e
3 changed files with 59 additions and 49 deletions

View file

@ -5734,9 +5734,15 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
b->shader->info.fs.sample_interlock_unordered = true;
break;
case SpvExecutionModeSignedZeroInfNanPreserve: {
const struct glsl_type *type = glsl_floatN_t_type(mode->operands[0]);
unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, glsl_get_base_type(type));
*fp_math_ctrl |= nir_fp_preserve_sz_inf_nan;
break;
}
case SpvExecutionModeDenormPreserve:
case SpvExecutionModeDenormFlushToZero:
case SpvExecutionModeSignedZeroInfNanPreserve:
case SpvExecutionModeRoundingModeRTE:
case SpvExecutionModeRoundingModeRTZ: {
unsigned execution_mode = 0;
@ -5757,14 +5763,6 @@ vtn_handle_execution_mode(struct vtn_builder *b, struct vtn_value *entry_point,
default: vtn_fail("Floating point type not supported");
}
break;
case SpvExecutionModeSignedZeroInfNanPreserve:
switch (mode->operands[0]) {
case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_sz_inf_nan; break;
case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_sz_inf_nan; break;
case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_sz_inf_nan; break;
default: vtn_fail("Floating point type not supported");
}
break;
case SpvExecutionModeRoundingModeRTE:
switch (mode->operands[0]) {
case 16: execution_mode = FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16; break;
@ -5925,6 +5923,12 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin
struct vtn_type *type = vtn_get_type(b, mode->operands[0]);
SpvFPFastMathModeMask flags = vtn_constant_uint(b, mode->operands[1]);
enum glsl_base_type base_type = glsl_get_base_type(type->type);
unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, base_type);
if (!fp_math_ctrl)
vtn_fail("Unkown float type for FPFastMathDefault");
SpvFPFastMathModeMask can_fast_math =
SpvFPFastMathModeAllowRecipMask |
SpvFPFastMathModeAllowContractMask |
@ -5933,27 +5937,15 @@ vtn_handle_execution_mode_id(struct vtn_builder *b, struct vtn_value *entry_poin
if ((flags & can_fast_math) != can_fast_math)
b->exact = true;
if (!(flags & SpvFPFastMathModeNotNaNMask)) {
switch (glsl_get_bit_size(type->type)) {
case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_nan; break;
case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_nan; break;
case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_nan; break;
}
}
if (!(flags & SpvFPFastMathModeNotInfMask)) {
switch (glsl_get_bit_size(type->type)) {
case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_inf; break;
case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_inf; break;
case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_inf; break;
}
}
if (!(flags & SpvFPFastMathModeNSZMask)) {
switch (glsl_get_bit_size(type->type)) {
case 16: b->fp_math_ctrl_fp16 |= nir_fp_preserve_signed_zero; break;
case 32: b->fp_math_ctrl_fp32 |= nir_fp_preserve_signed_zero; break;
case 64: b->fp_math_ctrl_fp64 |= nir_fp_preserve_signed_zero; break;
}
}
if (!(flags & SpvFPFastMathModeNotNaNMask))
*fp_math_ctrl |= nir_fp_preserve_nan;
if (!(flags & SpvFPFastMathModeNotInfMask))
*fp_math_ctrl |= nir_fp_preserve_inf;
if (!(flags & SpvFPFastMathModeNSZMask))
*fp_math_ctrl |= nir_fp_preserve_signed_zero;
break;
}

View file

@ -425,6 +425,39 @@ handle_fp_fast_math(struct vtn_builder *b, UNUSED struct vtn_value *val,
b->nb.fp_math_ctrl |= nir_fp_preserve_inf;
}
unsigned *
vtn_fp_math_ctrl_for_base_type(struct vtn_builder *b, enum glsl_base_type base_type)
{
switch (base_type) {
case GLSL_TYPE_FLOAT16: return &b->fp_math_ctrl[0];
case GLSL_TYPE_FLOAT: return &b->fp_math_ctrl[1];
case GLSL_TYPE_DOUBLE: return &b->fp_math_ctrl[2];
case GLSL_TYPE_BFLOAT16: return &b->fp_math_ctrl[3];
case GLSL_TYPE_FLOAT_E4M3FN: return &b->fp_math_ctrl[4];
case GLSL_TYPE_FLOAT_E5M2: return &b->fp_math_ctrl[5];
default: return NULL;
}
}
static unsigned
fp_math_ctrl_for_type(struct vtn_builder *b, struct vtn_type *type)
{
if (!type)
return nir_fp_fast_math;
enum glsl_base_type base_type;
/* Some ALU like modf and frexp return a struct of two values. */
if (glsl_type_is_struct(type->type))
base_type = glsl_get_base_type(type->type->fields.structure[0].type);
else
base_type = glsl_get_base_type(type->type);
unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, base_type);
return fp_math_ctrl ? *fp_math_ctrl : nir_fp_fast_math;
}
void
vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
{
@ -432,23 +465,8 @@ vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
* on the builder, so the generated instructions can take it from it.
* We only care about some of them, check nir_alu_instr for details.
*/
unsigned bit_size;
/* Some ALU like modf and frexp return a struct of two values. */
if (!val->type)
bit_size = 0;
else if (glsl_type_is_struct(val->type->type))
bit_size = glsl_get_bit_size(val->type->type->fields.structure[0].type);
else
bit_size = glsl_get_bit_size(val->type->type);
switch (bit_size) {
case 16: b->nb.fp_math_ctrl = b->fp_math_ctrl_fp16; break;
case 32: b->nb.fp_math_ctrl = b->fp_math_ctrl_fp32; break;
case 64: b->nb.fp_math_ctrl = b->fp_math_ctrl_fp64; break;
default: b->nb.fp_math_ctrl = 0; break;
}
b->nb.fp_math_ctrl = fp_math_ctrl_for_type(b, val->type);
vtn_foreach_decoration(b, val, handle_fp_fast_math, NULL);

View file

@ -710,9 +710,7 @@ struct vtn_builder {
/* false by default, set to true by the ContractionOff execution mode */
bool exact;
unsigned fp_math_ctrl_fp16;
unsigned fp_math_ctrl_fp32;
unsigned fp_math_ctrl_fp64;
unsigned fp_math_ctrl[6];
/* when a physical memory model is choosen */
bool physical_ptrs;
@ -992,6 +990,8 @@ void vtn_handle_integer_dot(struct vtn_builder *b, SpvOp opcode,
void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w,
unsigned count);
unsigned *vtn_fp_math_ctrl_for_base_type(struct vtn_builder *b, enum glsl_base_type base_type);
void vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val);
void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,