nir: Let nir_eval_const_opcode() return a poison mask in case of UB.

This is unused by any callers currently, but will be useful for nir
algebraic pattern testing, and as a way to turn our comments in
nir_opcodes.py into actual C code.  For now, the poison mask is always zero.

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/39076>
This commit is contained in:
Emma Anholt 2026-01-06 11:25:35 -08:00 committed by Marge Bot
parent f6008645f6
commit b375da7f2a
7 changed files with 40 additions and 13 deletions

View file

@ -181,10 +181,10 @@ constant_fold_scalar(nir_scalar s, unsigned invocation_id, nir_shader *shader, n
srcs[i] = sources[i];
nir_const_value dests[NIR_MAX_VEC_COMPONENTS];
if (op_info->output_size == 0) {
nir_eval_const_opcode(alu->op, dests, 1, bit_size, srcs, exec_mode);
nir_eval_const_opcode(alu->op, dests, NULL, 1, bit_size, srcs, exec_mode);
*dest = dests[0];
} else {
nir_eval_const_opcode(alu->op, dests, s.def->num_components, bit_size, srcs, exec_mode);
nir_eval_const_opcode(alu->op, dests, NULL, s.def->num_components, bit_size, srcs, exec_mode);
*dest = dests[s.comp];
}
return true;

View file

@ -35,7 +35,14 @@
extern "C" {
#endif
void nir_eval_const_opcode(nir_op op, nir_const_value *dest,
/**
* Evaluates the NIR opcode for the given source constant values.
*
* If @poison is non-NULL, it will contain the nir_component_mask_t of output
* channels that invoked undefined behavior (that type is not used in this
* prototype, to avoid pulling in all of nir.h).
*/
void nir_eval_const_opcode(nir_op op, nir_const_value *dest, uint16_t *poison,
unsigned num_components, unsigned bit_size,
nir_const_value **src,
unsigned float_controls_execution_mode);

View file

@ -475,6 +475,7 @@ struct ${type}${width}_vec {
## components and apply the constant expression one component
## at a time.
for (unsigned _i = 0; _i < num_components; _i++) {
bool poison = false;
## For each per-component input, create a variable srcN that
## contains the value of the current (_i'th) component.
% for j in range(op.num_inputs):
@ -538,6 +539,9 @@ struct ${type}${width}_vec {
}
%endif
% endif
if (poison)
poison_mask |= (1 << _i);
}
% else:
## In the non-per-component case, create a struct dst with
@ -590,13 +594,15 @@ struct ${type}${width}_vec {
</%def>
% for name, op in sorted(opcodes.items()):
static void
static nir_component_mask_t
evaluate_${name}(nir_const_value *_dst_val,
UNUSED unsigned num_components,
${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
UNUSED nir_const_value **_src,
UNUSED unsigned execution_mode)
{
nir_component_mask_t poison_mask = 0;
% if op_bit_sizes(op) is not None:
switch (bit_size) {
% for bit_size in op_bit_sizes(op):
@ -612,24 +618,31 @@ evaluate_${name}(nir_const_value *_dst_val,
% else:
${evaluate_op(op, 0, execution_mode)}
% endif
return poison_mask;
}
% endfor
void
nir_eval_const_opcode(nir_op op, nir_const_value *dest,
nir_eval_const_opcode(nir_op op, nir_const_value *dest, nir_component_mask_t *out_poison,
unsigned num_components, unsigned bit_width,
nir_const_value **src,
unsigned float_controls_execution_mode)
{
nir_component_mask_t poison;
switch (op) {
% for name in sorted(opcodes.keys()):
case nir_op_${name}:
evaluate_${name}(dest, num_components, bit_width, src, float_controls_execution_mode);
return;
poison = evaluate_${name}(dest, num_components, bit_width, src, float_controls_execution_mode);
break;
% endfor
default:
UNREACHABLE("shouldn't get here");
}
if (out_poison)
*out_poison = poison;
}"""
from mako.template import Template

View file

@ -583,7 +583,7 @@ eval_const_unop(nir_op op, unsigned bit_size, nir_const_value src0,
assert(nir_op_infos[op].num_inputs == 1);
nir_const_value dest;
nir_const_value *src[1] = { &src0 };
nir_eval_const_opcode(op, &dest, 1, bit_size, src, execution_mode);
nir_eval_const_opcode(op, &dest, NULL, 1, bit_size, src, execution_mode);
return dest;
}
@ -595,7 +595,7 @@ eval_const_binop(nir_op op, unsigned bit_size,
assert(nir_op_infos[op].num_inputs == 2);
nir_const_value dest;
nir_const_value *src[2] = { &src0, &src1 };
nir_eval_const_opcode(op, &dest, 1, bit_size, src, execution_mode);
nir_eval_const_opcode(op, &dest, NULL, 1, bit_size, src, execution_mode);
return dest;
}
@ -687,7 +687,7 @@ try_eval_const_alu(nir_const_value *dest, nir_scalar alu_s, const nir_scalar *or
}
}
nir_eval_const_opcode(alu->op, dest, 1, bit_size, src_ptrs, execution_mode);
nir_eval_const_opcode(alu->op, dest, NULL, 1, bit_size, src_ptrs, execution_mode);
return true;
}
@ -897,7 +897,7 @@ test_iterations(int32_t iter_int, nir_const_value step,
/* Evaluate the loop exit condition */
nir_const_value result;
nir_eval_const_opcode(cond_op, &result, 1, bit_size, src, execution_mode);
nir_eval_const_opcode(cond_op, &result, NULL, 1, bit_size, src, execution_mode);
return invert_cond ? !result.b : result.b;
}

View file

@ -67,6 +67,13 @@ class Opcode(object):
and the result will be equivalent to "dst = <expression>" for
per-component instructions and "dst.x = dst.y = ... = <expression>"
for non-per-component instructions.
The expression may set a poison = true flag to indicate that the
calculation invoked deferred undefined behavior (see
https://llvm.org/docs/UndefinedBehavior.html, which is similar to the
SPIRV 2.2.6 "Validity and Defined Behavior" definition). For
non-per-component opcodes, poison_mask must instead be set to the mask of
undefined components.
"""
assert isinstance(name, str)
assert isinstance(output_size, int)

View file

@ -77,7 +77,7 @@ nir_try_constant_fold_alu(nir_builder *b, nir_alu_instr *alu)
memset(dest, 0, sizeof(dest));
for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; ++i)
srcs[i] = src[i];
nir_eval_const_opcode(alu->op, dest, alu->def.num_components,
nir_eval_const_opcode(alu->op, dest, NULL, alu->def.num_components,
bit_size, srcs,
b->shader->info.float_controls_execution_mode);

View file

@ -2953,7 +2953,7 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
nir_const_value *srcs[3] = {
src[0], src[1], src[2],
};
nir_eval_const_opcode(op, val->constant->values,
nir_eval_const_opcode(op, val->constant->values, NULL,
num_components, bit_size, srcs,
b->shader->info.float_controls_execution_mode);