nir/opcodes: remove valid_fp_math_ctrl bits from some opcodes

This is mostly about conversions.

Conversions from float to int don't care about signed zero
and in the case of plain f2u/f2i, nan and inf are always
undefined too.

Conversions for int to float can't create nan, so they don't
need preserve_nan.

b2f only cares about preserve_sz, and nothing else.

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39966>
This commit is contained in:
Georg Lehmann 2026-02-18 16:04:58 +01:00 committed by Marge Bot
parent 62f3be87c4
commit 5e544ecd08

View file

@ -204,9 +204,9 @@ def unop(name, ty, const_expr, description = "", algebraic_properties = ""):
description)
def unop_horiz(name, output_size, output_type, input_size, input_type,
const_expr, description = ""):
const_expr, description = "", valid_fp_math_ctrl = None):
opcode(name, output_size, output_type, [input_size], [input_type],
False, "", const_expr, description)
False, "", const_expr, description, valid_fp_math_ctrl = valid_fp_math_ctrl)
def unop_reduce(name, output_size, output_type, input_type, prereduce_expr,
reduce_expr, final_expr, description = ""):
@ -228,8 +228,8 @@ def unop_reduce(name, output_size, output_type, input_type, prereduce_expr,
final(reduce_(reduce_(src0, src1), reduce_(src2, src3))),
description)
def unop_numeric_convert(name, out_type, in_type, const_expr, description = ""):
opcode(name, 0, out_type, [0], [in_type], True, "", const_expr, description)
def unop_numeric_convert(name, out_type, in_type, const_expr, description = "", valid_fp_math_ctrl = None):
opcode(name, 0, out_type, [0], [in_type], True, "", const_expr, description, valid_fp_math_ctrl = valid_fp_math_ctrl)
unop("mov", tuint, "src0")
@ -376,14 +376,22 @@ for src_t in [tint, tuint, tfloat, tbool]:
"""
unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
dst_bit_size),
dst_t + str(dst_bit_size), src_t, conv_expr)
dst_t + str(dst_bit_size), src_t, conv_expr,
valid_fp_math_ctrl = exact)
else:
valid_fp_math_ctrl = None
if dst_t == tfloat and src_t == tbool:
valid_fp_math_ctrl = preserve_sz
elif dst_t == tfloat and src_t in [tint, tuint]:
valid_fp_math_ctrl = preserve_sz + preserve_inf + exact
conv_expr = "src0 != 0" if dst_t == tbool else "src0"
unop_numeric_convert("{0}2{1}{2}".format(src_t[0], dst_t[0],
dst_bit_size),
dst_t + str(dst_bit_size), src_t, conv_expr)
dst_t + str(dst_bit_size), src_t, conv_expr,
valid_fp_math_ctrl = valid_fp_math_ctrl)
def unop_numeric_convert_mp(base, src_t, dst_t):
def unop_numeric_convert_mp(base, src_t, dst_t, valid_fp_math_ctrl = None):
op_like = base + "16"
unop_numeric_convert(base + "mp", src_t, dst_t, opcodes[op_like].const_expr,
description = """
@ -391,17 +399,17 @@ Special opcode that is the same as :nir:alu-op:`{}` except that it is safe to
remove it if the result is immediately converted back to 32 bits again. This is
generated as part of the precision lowering pass. ``mp`` stands for medium
precision.
""".format(op_like))
""".format(op_like), valid_fp_math_ctrl = valid_fp_math_ctrl)
unop_numeric_convert_mp("f2f", tfloat16, tfloat32)
unop_numeric_convert_mp("i2i", tint16, tint32)
# u2ump isn't defined, because the behavior is equal to i2imp
unop_numeric_convert_mp("f2i", tint16, tfloat32)
unop_numeric_convert_mp("f2u", tuint16, tfloat32)
unop_numeric_convert_mp("i2f", tfloat16, tint32)
unop_numeric_convert_mp("u2f", tfloat16, tuint32)
unop_numeric_convert_mp("f2i", tint16, tfloat32, exact)
unop_numeric_convert_mp("f2u", tuint16, tfloat32, exact)
unop_numeric_convert_mp("i2f", tfloat16, tint32, preserve_sz + preserve_inf + exact)
unop_numeric_convert_mp("u2f", tfloat16, tuint32, preserve_sz + preserve_inf + exact)
unop_numeric_convert("f2i32_rtne", tint32, tfloat32, "(int32_t)_mesa_roundevenf(src0)")
unop_numeric_convert("f2i32_rtne", tint32, tfloat32, "(int32_t)_mesa_roundevenf(src0)", valid_fp_math_ctrl = exact)
# Note: 64-bit integers are intentionally not supported. Casting u_uintN_max
# (and related signed values) to double is precisely representable for upto
@ -410,11 +418,13 @@ unop_numeric_convert("f2i32_rtne", tint32, tfloat32, "(int32_t)_mesa_roundevenf(
for bits in (8, 16, 32):
unop_numeric_convert(f"f2u{bits}_sat", f"uint{bits}", tfloat,
f"(uint{bits}_t)fmin(fmax(src0, 0.0), (double)u_uintN_max({bits}))",
"Convert float to uint with clamping to uint range. NaN becomes zero.")
"Convert float to uint with clamping to uint range. NaN becomes zero.",
valid_fp_math_ctrl = preserve_inf + preserve_nan + exact)
unop_numeric_convert(f"f2i{bits}_sat", f"int{bits}", tfloat,
f"(int{bits}_t) isnan(src0) ? 0.0 : fmin(fmax(src0, (double)u_intN_min({bits})), (double)u_intN_max({bits}))",
"Convert float to int with clamping to int range. NaN becomes zero.")
"Convert float to int with clamping to int range. NaN becomes zero.",
valid_fp_math_ctrl = preserve_inf + preserve_nan + exact)
# Unary floating-point rounding operations.
@ -434,49 +444,49 @@ unop("fsin", tfloat, "bit_size == 64 ? sin(src0) : sinf(src0)")
unop("fcos", tfloat, "bit_size == 64 ? cos(src0) : cosf(src0)")
# dfrexp
unop_convert("frexp_exp", tint32, tfloat, "frexp(src0, &dst);")
unop_convert("frexp_exp", tint32, tfloat, "frexp(src0, &dst);", valid_fp_math_ctrl = preserve_inf + preserve_nan + exact)
unop_convert("frexp_sig", tfloat, tfloat, "int n; dst = frexp(src0, &n);")
# Floating point pack and unpack operations.
def pack_2x16(fmt, in_type):
def pack_2x16(fmt, in_type, valid_fp_math_ctrl = None):
unop_horiz("pack_" + fmt + "_2x16", 1, tuint32, 2, in_type, """
dst.x = (uint32_t) pack_fmt_1x16(src0.x);
dst.x |= ((uint32_t) pack_fmt_1x16(src0.y)) << 16;
""".replace("fmt", fmt))
""".replace("fmt", fmt), valid_fp_math_ctrl = valid_fp_math_ctrl)
def pack_4x8(fmt):
def pack_4x8(fmt, valid_fp_math_ctrl = None):
unop_horiz("pack_" + fmt + "_4x8", 1, tuint32, 4, tfloat32, """
dst.x = (uint32_t) pack_fmt_1x8(src0.x);
dst.x |= ((uint32_t) pack_fmt_1x8(src0.y)) << 8;
dst.x |= ((uint32_t) pack_fmt_1x8(src0.z)) << 16;
dst.x |= ((uint32_t) pack_fmt_1x8(src0.w)) << 24;
""".replace("fmt", fmt))
""".replace("fmt", fmt), valid_fp_math_ctrl = valid_fp_math_ctrl)
def unpack_2x16(fmt):
def unpack_2x16(fmt, valid_fp_math_ctrl = None):
unop_horiz("unpack_" + fmt + "_2x16", 2, tfloat32, 1, tuint32, """
dst.x = unpack_fmt_1x16((uint16_t)(src0.x & 0xffff));
dst.y = unpack_fmt_1x16((uint16_t)(src0.x >> 16));
""".replace("fmt", fmt))
""".replace("fmt", fmt), valid_fp_math_ctrl = valid_fp_math_ctrl)
def unpack_4x8(fmt):
def unpack_4x8(fmt, valid_fp_math_ctrl = None):
unop_horiz("unpack_" + fmt + "_4x8", 4, tfloat32, 1, tuint32, """
dst.x = unpack_fmt_1x8((uint8_t)(src0.x & 0xff));
dst.y = unpack_fmt_1x8((uint8_t)((src0.x >> 8) & 0xff));
dst.z = unpack_fmt_1x8((uint8_t)((src0.x >> 16) & 0xff));
dst.w = unpack_fmt_1x8((uint8_t)(src0.x >> 24));
""".replace("fmt", fmt))
""".replace("fmt", fmt), valid_fp_math_ctrl = valid_fp_math_ctrl)
pack_2x16("snorm", tfloat)
pack_4x8("snorm")
pack_2x16("unorm", tfloat)
pack_4x8("unorm")
pack_2x16("snorm", tfloat, preserve_inf + preserve_nan + exact)
pack_4x8("snorm", preserve_inf + preserve_nan + exact)
pack_2x16("unorm", tfloat, preserve_inf + preserve_nan + exact)
pack_4x8("unorm", preserve_inf + preserve_nan + exact)
pack_2x16("half", tfloat32)
unpack_2x16("snorm")
unpack_4x8("snorm")
unpack_2x16("unorm")
unpack_4x8("unorm")
unpack_2x16("snorm", preserve_sz + exact)
unpack_4x8("snorm", preserve_sz + exact)
unpack_2x16("unorm", preserve_sz + exact)
unpack_4x8("unorm", preserve_sz + exact)
unop_horiz("pack_uint_2x16", 1, tuint32, 2, tuint32, """
dst.x = _mesa_unsigned_to_unsigned(src0.x, 16);