nir/opcodes: use u_overflow to fix incorrect checks

Operands of an addition will be promoted to int making the a+b<a
kind of checks ineffective.

Use u_overflow.h helpers to perform the check correctly.
The commit would be simpler if it used __typeof__ like so:

   util_add_check_overflow(__typeof__(src0), src0, src1)

But typeof only became a standard in C23 so this commit instead extends
nir_opcodes a bit to allow opcodes that need the dest_type to get it.

Reviewed-by: Marek Olšák <marek.olsak@amd.com>
Reviewed-by: Dylan Baker <dylan.c.baker@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37331>
This commit is contained in:
Pierre-Eric Pelloux-Prayer 2025-09-11 16:15:32 +02:00
parent 1b23a2ba80
commit cc4b50b023
2 changed files with 52 additions and 20 deletions

View file

@ -66,6 +66,7 @@ template = """\
#include "util/format/format_utils.h"
#include "util/format_r11g11b10f.h"
#include "util/u_math.h"
#include "util/u_overflow.h"
#include "nir_constant_expressions.h"
#include "nir.h"
@ -392,6 +393,26 @@ typedef bool bool8_t;
typedef bool bool16_t;
typedef bool bool32_t;
typedef bool bool64_t;
static inline bool
util_add_check_overflow_int1_t(int1_t a, int1_t b)
{
return (a & 1 && b & 1);
}
static inline bool
util_add_check_overflow_uint1_t(uint1_t a, int1_t b)
{
return (a & 1 && b & 1);
}
static inline bool
util_sub_check_overflow_int1_t(int1_t a, int1_t b)
{
/* int1_t uses 0/-1 convention, so the only
* overflow case is "0 - (-1)".
*/
return a == 0 && b != 0;
}
% for type in ["float", "int", "uint", "bool"]:
% for width in type_sizes(type):
struct ${type}${width}_vec {
@ -477,12 +498,15 @@ struct ${type}${width}_vec {
## Create an appropriately-typed variable dst and assign the
## result of the const_expr to it. If const_expr already contains
## writes to dst, just include const_expr directly.
<%
expr = op.render(output_type + '_t')
%>
% if "dst" in op.const_expr:
${output_type}_t dst;
${op.const_expr}
${expr}
% else:
${output_type}_t dst = ${op.const_expr};
${output_type}_t dst = ${expr};
% endif
## Store the current component of the actual destination to the

View file

@ -34,7 +34,7 @@ class Opcode(object):
"""
def __init__(self, name, output_size, output_type, input_sizes,
input_types, is_conversion, algebraic_properties, const_expr,
description):
description, needs_dest_type):
"""Parameters:
- name is the name of the opcode (prepend nir_op_ for the enum name)
@ -46,6 +46,8 @@ class Opcode(object):
- const_expr is an expression or series of statements that computes the
constant value of the opcode given the constant values of its inputs.
- Optional description of the opcode for documentation.
- needs_dest_type means const_expr depends on the destination type and
needs a formatting step.
Constant expressions are formed from the variables src0, src1, ...,
src(N-1), where N is the number of arguments. The output of the
@ -93,6 +95,12 @@ class Opcode(object):
self.algebraic_properties = algebraic_properties
self.const_expr = const_expr
self.description = description
self.needs_dest_type = needs_dest_type
def render(self, dest_type):
if self.needs_dest_type:
return self.const_expr.format(dest_type=dest_type)
return self.const_expr
# helper variables for strings
tfloat = "float"
@ -157,11 +165,12 @@ selection = "selection "
opcodes = {}
def opcode(name, output_size, output_type, input_sizes, input_types,
is_conversion, algebraic_properties, const_expr, description = ""):
is_conversion, algebraic_properties, const_expr, description = "",
needs_dest_type=False):
assert name not in opcodes
opcodes[name] = Opcode(name, output_size, output_type, input_sizes,
input_types, is_conversion, algebraic_properties,
const_expr, description)
const_expr, description, needs_dest_type)
def unop_convert(name, out_type, in_type, const_expr, description = ""):
opcode(name, 0, out_type, [0], [in_type], False, "", const_expr, description)
@ -545,14 +554,14 @@ for (unsigned bit = 0; bit < bit_size; bit++) {
unop_reduce("fsum", 1, tfloat, tfloat, "{src}", "{src0} + {src1}", "{src}",
description = "Sum of vector components")
def binop_convert(name, out_type, in_type1, alg_props, const_expr, description="", in_type2=None):
def binop_convert(name, out_type, in_type1, alg_props, const_expr, description="", in_type2=None, needs_dest_type=False):
if in_type2 is None:
in_type2 = in_type1
opcode(name, 0, out_type, [0, 0], [in_type1, in_type2],
False, alg_props, const_expr, description)
False, alg_props, const_expr, description, needs_dest_type)
def binop(name, ty, alg_props, const_expr, description = ""):
binop_convert(name, ty, ty, alg_props, const_expr, description)
def binop(name, ty, alg_props, const_expr, description = "", needs_dest_type=False):
binop_convert(name, ty, ty, alg_props, const_expr, description, needs_dest_type=needs_dest_type)
def binop_compare(name, ty, alg_props, const_expr, description = "", ty2=None):
binop_convert(name, tbool1, ty, alg_props, const_expr, description, ty2)
@ -626,17 +635,16 @@ if (nir_is_rounding_mode_rtz(execution_mode, bit_size)) {
""")
binop("iadd", tint, _2src_commutative + associative, "(uint64_t)src0 + (uint64_t)src1")
binop("iadd_sat", tint, _2src_commutative, """
src1 > 0 ?
(src0 + src1 < src0 ? u_intN_max(bit_size) : src0 + src1) :
(src0 < src0 + src1 ? u_intN_min(bit_size) : src0 + src1)
""")
util_add_check_overflow({dest_type}, src0, src1) ?
(src1 < 0 ? u_intN_max(bit_size) : u_uintN_max(bit_size)) : (src0 + src1)
""", "", True)
binop("uadd_sat", tuint, _2src_commutative,
"(src0 + src1) < src0 ? u_uintN_max(sizeof(src0) * 8) : (src0 + src1)")
"util_add_check_overflow({dest_type}, src0, src1) ? u_uintN_max(sizeof(src0) * 8) : (src0 + src1)",
"", True)
binop("isub_sat", tint, "", """
src1 < 0 ?
(src0 - src1 < src0 ? u_intN_max(bit_size) : src0 - src1) :
(src0 < src0 - src1 ? u_intN_min(bit_size) : src0 - src1)
""")
util_sub_check_overflow({dest_type}, src0, src1) ?
(src1 < 0 ? u_intN_max(bit_size) : u_intN_min(bit_size)) : (src0 - src1)
""", "", True)
binop("usub_sat", tuint, "", "src0 < src1 ? 0 : src0 - src1")
opcode("uadd64_32", 2, tuint32, [1, 1, 1], [tuint32, tuint32, tuint32], False, "", """
@ -762,11 +770,11 @@ binop("idiv", tint, "", "src1 == 0 ? 0 : (src0 / src1)")
binop("udiv", tuint, "", "src1 == 0 ? 0 : (src0 / src1)")
binop_convert("uadd_carry", tuint, tuint, _2src_commutative,
"src0 + src1 < src0",
"util_add_check_overflow({dest_type}, src0, src1)",
description = """
Return an integer (1 or 0) representing the carry resulting from the
addition of the two unsigned arguments.
""")
""", needs_dest_type = True)
binop_convert("usub_borrow", tuint, tuint, "", "src0 < src1", description = """
Return an integer (1 or 0) representing the borrow resulting from the