nir: Let nir_eval_const_opcode() return a poison mask in case of UB.

This is unused by any callers currently, but will be useful for nir
algebraic pattern testing, and as a way to turn our comments in
nir_opcodes.py into actual C code.  For now, the poison mask is always 0.

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39076>
This commit is contained in:
Emma Anholt 2026-01-06 11:25:35 -08:00 committed by Marge Bot
parent f6008645f6
commit b375da7f2a
7 changed files with 40 additions and 13 deletions

View file

@ -181,10 +181,10 @@ constant_fold_scalar(nir_scalar s, unsigned invocation_id, nir_shader *shader, n
srcs[i] = sources[i]; srcs[i] = sources[i];
nir_const_value dests[NIR_MAX_VEC_COMPONENTS]; nir_const_value dests[NIR_MAX_VEC_COMPONENTS];
if (op_info->output_size == 0) { if (op_info->output_size == 0) {
nir_eval_const_opcode(alu->op, dests, 1, bit_size, srcs, exec_mode); nir_eval_const_opcode(alu->op, dests, NULL, 1, bit_size, srcs, exec_mode);
*dest = dests[0]; *dest = dests[0];
} else { } else {
nir_eval_const_opcode(alu->op, dests, s.def->num_components, bit_size, srcs, exec_mode); nir_eval_const_opcode(alu->op, dests, NULL, s.def->num_components, bit_size, srcs, exec_mode);
*dest = dests[s.comp]; *dest = dests[s.comp];
} }
return true; return true;

View file

@ -35,7 +35,14 @@
extern "C" { extern "C" {
#endif #endif
void nir_eval_const_opcode(nir_op op, nir_const_value *dest, /**
* Evaluates the NIR opcode for the given source constant values.
*
* If @poison is non-NULL, it will contain the nir_component_mask of output
* channels that invoked undefined behavior (define not used here, to avoid
* pulling in all of nir.h).
*/
void nir_eval_const_opcode(nir_op op, nir_const_value *dest, uint16_t *poison,
unsigned num_components, unsigned bit_size, unsigned num_components, unsigned bit_size,
nir_const_value **src, nir_const_value **src,
unsigned float_controls_execution_mode); unsigned float_controls_execution_mode);

View file

@ -475,6 +475,7 @@ struct ${type}${width}_vec {
## components and apply the constant expression one component ## components and apply the constant expression one component
## at a time. ## at a time.
for (unsigned _i = 0; _i < num_components; _i++) { for (unsigned _i = 0; _i < num_components; _i++) {
bool poison = false;
## For each per-component input, create a variable srcN that ## For each per-component input, create a variable srcN that
## contains the value of the current (_i'th) component. ## contains the value of the current (_i'th) component.
% for j in range(op.num_inputs): % for j in range(op.num_inputs):
@ -538,6 +539,9 @@ struct ${type}${width}_vec {
} }
%endif %endif
% endif % endif
if (poison)
poison_mask |= (1 << _i);
} }
% else: % else:
## In the non-per-component case, create a struct dst with ## In the non-per-component case, create a struct dst with
@ -590,13 +594,15 @@ struct ${type}${width}_vec {
</%def> </%def>
% for name, op in sorted(opcodes.items()): % for name, op in sorted(opcodes.items()):
static void static nir_component_mask_t
evaluate_${name}(nir_const_value *_dst_val, evaluate_${name}(nir_const_value *_dst_val,
UNUSED unsigned num_components, UNUSED unsigned num_components,
${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size, ${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
UNUSED nir_const_value **_src, UNUSED nir_const_value **_src,
UNUSED unsigned execution_mode) UNUSED unsigned execution_mode)
{ {
nir_component_mask_t poison_mask = 0;
% if op_bit_sizes(op) is not None: % if op_bit_sizes(op) is not None:
switch (bit_size) { switch (bit_size) {
% for bit_size in op_bit_sizes(op): % for bit_size in op_bit_sizes(op):
@ -612,24 +618,31 @@ evaluate_${name}(nir_const_value *_dst_val,
% else: % else:
${evaluate_op(op, 0, execution_mode)} ${evaluate_op(op, 0, execution_mode)}
% endif % endif
return poison_mask;
} }
% endfor % endfor
void void
nir_eval_const_opcode(nir_op op, nir_const_value *dest, nir_eval_const_opcode(nir_op op, nir_const_value *dest, nir_component_mask_t *out_poison,
unsigned num_components, unsigned bit_width, unsigned num_components, unsigned bit_width,
nir_const_value **src, nir_const_value **src,
unsigned float_controls_execution_mode) unsigned float_controls_execution_mode)
{ {
nir_component_mask_t poison;
switch (op) { switch (op) {
% for name in sorted(opcodes.keys()): % for name in sorted(opcodes.keys()):
case nir_op_${name}: case nir_op_${name}:
evaluate_${name}(dest, num_components, bit_width, src, float_controls_execution_mode); poison = evaluate_${name}(dest, num_components, bit_width, src, float_controls_execution_mode);
return; break;
% endfor % endfor
default: default:
UNREACHABLE("shouldn't get here"); UNREACHABLE("shouldn't get here");
} }
if (out_poison)
*out_poison = poison;
}""" }"""
from mako.template import Template from mako.template import Template

View file

@ -583,7 +583,7 @@ eval_const_unop(nir_op op, unsigned bit_size, nir_const_value src0,
assert(nir_op_infos[op].num_inputs == 1); assert(nir_op_infos[op].num_inputs == 1);
nir_const_value dest; nir_const_value dest;
nir_const_value *src[1] = { &src0 }; nir_const_value *src[1] = { &src0 };
nir_eval_const_opcode(op, &dest, 1, bit_size, src, execution_mode); nir_eval_const_opcode(op, &dest, NULL, 1, bit_size, src, execution_mode);
return dest; return dest;
} }
@ -595,7 +595,7 @@ eval_const_binop(nir_op op, unsigned bit_size,
assert(nir_op_infos[op].num_inputs == 2); assert(nir_op_infos[op].num_inputs == 2);
nir_const_value dest; nir_const_value dest;
nir_const_value *src[2] = { &src0, &src1 }; nir_const_value *src[2] = { &src0, &src1 };
nir_eval_const_opcode(op, &dest, 1, bit_size, src, execution_mode); nir_eval_const_opcode(op, &dest, NULL, 1, bit_size, src, execution_mode);
return dest; return dest;
} }
@ -687,7 +687,7 @@ try_eval_const_alu(nir_const_value *dest, nir_scalar alu_s, const nir_scalar *or
} }
} }
nir_eval_const_opcode(alu->op, dest, 1, bit_size, src_ptrs, execution_mode); nir_eval_const_opcode(alu->op, dest, NULL, 1, bit_size, src_ptrs, execution_mode);
return true; return true;
} }
@ -897,7 +897,7 @@ test_iterations(int32_t iter_int, nir_const_value step,
/* Evaluate the loop exit condition */ /* Evaluate the loop exit condition */
nir_const_value result; nir_const_value result;
nir_eval_const_opcode(cond_op, &result, 1, bit_size, src, execution_mode); nir_eval_const_opcode(cond_op, &result, NULL, 1, bit_size, src, execution_mode);
return invert_cond ? !result.b : result.b; return invert_cond ? !result.b : result.b;
} }

View file

@ -67,6 +67,13 @@ class Opcode(object):
and the result will be equivalent to "dst = <expression>" for and the result will be equivalent to "dst = <expression>" for
per-component instructions and "dst.x = dst.y = ... = <expression>" per-component instructions and "dst.x = dst.y = ... = <expression>"
for non-per-component instructions. for non-per-component instructions.
The expression may set a poison = true flag to indicate that the
calculation invoked deferred undefined behavior (see
https://llvm.org/docs/UndefinedBehavior.html, which is similar to the
SPIR-V 2.2.6 "Validity and Defined Behavior" definition). For
non-per-component opcodes, poison_mask must be set to the undefined
components, instead.
""" """
assert isinstance(name, str) assert isinstance(name, str)
assert isinstance(output_size, int) assert isinstance(output_size, int)

View file

@ -77,7 +77,7 @@ nir_try_constant_fold_alu(nir_builder *b, nir_alu_instr *alu)
memset(dest, 0, sizeof(dest)); memset(dest, 0, sizeof(dest));
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i) for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i)
srcs[i] = src[i]; srcs[i] = src[i];
nir_eval_const_opcode(alu->op, dest, alu->def.num_components, nir_eval_const_opcode(alu->op, dest, NULL, alu->def.num_components,
bit_size, srcs, bit_size, srcs,
b->shader->info.float_controls_execution_mode); b->shader->info.float_controls_execution_mode);

View file

@ -2953,7 +2953,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
nir_const_value *srcs[3] = { nir_const_value *srcs[3] = {
src[0], src[1], src[2], src[0], src[1], src[2],
}; };
nir_eval_const_opcode(op, val->constant->values, nir_eval_const_opcode(op, val->constant->values, NULL,
num_components, bit_size, srcs, num_components, bit_size, srcs,
b->shader->info.float_controls_execution_mode); b->shader->info.float_controls_execution_mode);