nir: track some float controls bits per instruction

With float_controls2, shaders can decide on the behavior of
NaN/Inf/SignedZero preservation by decorating specific instructions, on
top of having a default for the whole program.
Add a field to nir_alu_instr to track these bits, and propagate it to new
instructions everywhere the exact bit is already being propagated.

v2: use fewer bits for fp_fast_math in nir_alu_instr (Alyssa)

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27281>
This commit is contained in:
Iván Briano 2024-02-13 16:24:56 -08:00 committed by Marge Bot
parent 829ea35714
commit 666647acae
11 changed files with 46 additions and 0 deletions

View file

@ -1605,6 +1605,14 @@ typedef struct nir_alu_instr {
*/ */
bool no_unsigned_wrap : 1; bool no_unsigned_wrap : 1;
/**
* The float controls bits that float_controls2 cares about. That is,
* NAN/INF/SIGNED_ZERO_PRESERVE only. Allow{Contract,Reassoc,Transform} are
* still handled through the exact bit, and the other float controls bits
* (rounding mode and denorm handling) remain in the execution mode only.
*/
uint32_t fp_fast_math : 9;
/** Destination */ /** Destination */
nir_def def; nir_def def;

View file

@ -66,6 +66,7 @@ nir_builder_alu_instr_finish_and_insert(nir_builder *build, nir_alu_instr *instr
const nir_op_info *op_info = &nir_op_infos[instr->op]; const nir_op_info *op_info = &nir_op_infos[instr->op];
instr->exact = build->exact; instr->exact = build->exact;
instr->fp_fast_math = build->fp_fast_math;
/* Guess the number of components the destination temporary should have /* Guess the number of components the destination temporary should have
* based on our input sizes, if it's not fixed for the op. * based on our input sizes, if it's not fixed for the op.
@ -324,6 +325,7 @@ nir_vec_scalars(nir_builder *build, nir_scalar *comp, unsigned num_components)
instr->src[i].swizzle[0] = comp[i].comp; instr->src[i].swizzle[0] = comp[i].comp;
} }
instr->exact = build->exact; instr->exact = build->exact;
instr->fp_fast_math = build->fp_fast_math;
/* Note: not reusing nir_builder_alu_instr_finish_and_insert() because it /* Note: not reusing nir_builder_alu_instr_finish_and_insert() because it
* can't re-guess the num_components when num_components == 1 (nir_op_mov). * can't re-guess the num_components when num_components == 1 (nir_op_mov).

View file

@ -44,6 +44,9 @@ typedef struct nir_builder {
* and header phis are not updated). */ * and header phis are not updated). */
bool update_divergence; bool update_divergence;
/* Float_controls2 bits. See nir_alu_instr for details. */
uint32_t fp_fast_math;
nir_shader *shader; nir_shader *shader;
nir_function_impl *impl; nir_function_impl *impl;
} nir_builder; } nir_builder;
@ -611,6 +614,7 @@ nir_mov_alu(nir_builder *build, nir_alu_src src, unsigned num_components)
nir_def_init(&mov->instr, &mov->def, num_components, nir_def_init(&mov->instr, &mov->def, num_components,
nir_src_bit_size(src.src)); nir_src_bit_size(src.src));
mov->exact = build->exact; mov->exact = build->exact;
mov->fp_fast_math = build->fp_fast_math;
mov->src[0] = src; mov->src[0] = src;
nir_builder_instr_insert(build, &mov->instr); nir_builder_instr_insert(build, &mov->instr);

View file

@ -216,6 +216,7 @@ clone_alu(clone_state *state, const nir_alu_instr *alu)
{ {
nir_alu_instr *nalu = nir_alu_instr_create(state->ns, alu->op); nir_alu_instr *nalu = nir_alu_instr_create(state->ns, alu->op);
nalu->exact = alu->exact; nalu->exact = alu->exact;
nalu->fp_fast_math = alu->fp_fast_math;
nalu->no_signed_wrap = alu->no_signed_wrap; nalu->no_signed_wrap = alu->no_signed_wrap;
nalu->no_unsigned_wrap = alu->no_unsigned_wrap; nalu->no_unsigned_wrap = alu->no_unsigned_wrap;

View file

@ -773,6 +773,8 @@ nir_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr,
*/ */
if (instr->type == nir_instr_type_alu && nir_instr_as_alu(instr)->exact) if (instr->type == nir_instr_type_alu && nir_instr_as_alu(instr)->exact)
nir_instr_as_alu(match)->exact = true; nir_instr_as_alu(match)->exact = true;
if (instr->type == nir_instr_type_alu)
nir_instr_as_alu(match)->fp_fast_math = nir_instr_as_alu(instr)->fp_fast_math;
nir_def_rewrite_uses(def, new_def); nir_def_rewrite_uses(def, new_def);

View file

@ -51,6 +51,7 @@ lower_alu_instr(nir_builder *b, nir_instr *instr_, UNUSED void *cb_data)
b->cursor = nir_before_instr(&instr->instr); b->cursor = nir_before_instr(&instr->instr);
b->exact = instr->exact; b->exact = instr->exact;
b->fp_fast_math = instr->fp_fast_math;
switch (instr->op) { switch (instr->op) {
case nir_op_bitfield_reverse: case nir_op_bitfield_reverse:

View file

@ -111,6 +111,7 @@ lower_reduction(nir_alu_instr *alu, nir_op chan_op, nir_op merge_op,
chan->src[1].swizzle[0] = chan->src[1].swizzle[channel]; chan->src[1].swizzle[0] = chan->src[1].swizzle[channel];
} }
chan->exact = alu->exact; chan->exact = alu->exact;
chan->fp_fast_math = alu->fp_fast_math;
nir_builder_instr_insert(builder, &chan->instr); nir_builder_instr_insert(builder, &chan->instr);
@ -169,6 +170,7 @@ lower_fdot(nir_alu_instr *alu, nir_builder *builder)
if (i != 0) if (i != 0)
instr->src[2].src = nir_src_for_ssa(prev); instr->src[2].src = nir_src_for_ssa(prev);
instr->exact = builder->exact; instr->exact = builder->exact;
instr->fp_fast_math = builder->fp_fast_math;
nir_builder_instr_insert(builder, &instr->instr); nir_builder_instr_insert(builder, &instr->instr);
@ -187,6 +189,7 @@ lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
unsigned i, chan; unsigned i, chan;
b->exact = alu->exact; b->exact = alu->exact;
b->fp_fast_math = alu->fp_fast_math;
unsigned num_components = alu->def.num_components; unsigned num_components = alu->def.num_components;
unsigned target_width = 1; unsigned target_width = 1;
@ -406,6 +409,7 @@ lower_alu_instr_width(nir_builder *b, nir_instr *instr, void *_data)
nir_alu_ssa_dest_init(lower, components, alu->def.bit_size); nir_alu_ssa_dest_init(lower, components, alu->def.bit_size);
lower->exact = alu->exact; lower->exact = alu->exact;
lower->fp_fast_math = alu->fp_fast_math;
for (i = 0; i < components; i++) { for (i = 0; i < components; i++) {
vec->src[chan + i].src = nir_src_for_ssa(&lower->def); vec->src[chan + i].src = nir_src_for_ssa(&lower->def);

View file

@ -53,12 +53,15 @@ replace_with_strict_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
nir_def *const neg_a = nir_fneg(bld, a); nir_def *const neg_a = nir_fneg(bld, a);
nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact; nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact;
nir_instr_as_alu(neg_a->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const inner_ffma = nir_ffma(bld, neg_a, c, a); nir_def *const inner_ffma = nir_ffma(bld, neg_a, c, a);
nir_instr_as_alu(inner_ffma->parent_instr)->exact = alu->exact; nir_instr_as_alu(inner_ffma->parent_instr)->exact = alu->exact;
nir_instr_as_alu(inner_ffma->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const outer_ffma = nir_ffma(bld, b, c, inner_ffma); nir_def *const outer_ffma = nir_ffma(bld, b, c, inner_ffma);
nir_instr_as_alu(outer_ffma->parent_instr)->exact = alu->exact; nir_instr_as_alu(outer_ffma->parent_instr)->exact = alu->exact;
nir_instr_as_alu(outer_ffma->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def_rewrite_uses(&alu->def, outer_ffma); nir_def_rewrite_uses(&alu->def, outer_ffma);
@ -82,16 +85,20 @@ replace_with_single_ffma(struct nir_builder *bld, struct u_vector *dead_flrp,
nir_def *const neg_c = nir_fneg(bld, c); nir_def *const neg_c = nir_fneg(bld, c);
nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(neg_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const one_minus_c = nir_def *const one_minus_c =
nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c); nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(one_minus_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const b_times_c = nir_fmul(bld, b, c); nir_def *const b_times_c = nir_fmul(bld, b, c);
nir_instr_as_alu(b_times_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(b_times_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(b_times_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const final_ffma = nir_ffma(bld, a, one_minus_c, b_times_c); nir_def *const final_ffma = nir_ffma(bld, a, one_minus_c, b_times_c);
nir_instr_as_alu(final_ffma->parent_instr)->exact = alu->exact; nir_instr_as_alu(final_ffma->parent_instr)->exact = alu->exact;
nir_instr_as_alu(final_ffma->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def_rewrite_uses(&alu->def, final_ffma); nir_def_rewrite_uses(&alu->def, final_ffma);
@ -115,19 +122,24 @@ replace_with_strict(struct nir_builder *bld, struct u_vector *dead_flrp,
nir_def *const neg_c = nir_fneg(bld, c); nir_def *const neg_c = nir_fneg(bld, c);
nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(neg_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const one_minus_c = nir_def *const one_minus_c =
nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c); nir_fadd(bld, nir_imm_floatN_t(bld, 1.0f, c->bit_size), neg_c);
nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(one_minus_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(one_minus_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const first_product = nir_fmul(bld, a, one_minus_c); nir_def *const first_product = nir_fmul(bld, a, one_minus_c);
nir_instr_as_alu(first_product->parent_instr)->exact = alu->exact; nir_instr_as_alu(first_product->parent_instr)->exact = alu->exact;
nir_instr_as_alu(first_product->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const second_product = nir_fmul(bld, b, c); nir_def *const second_product = nir_fmul(bld, b, c);
nir_instr_as_alu(second_product->parent_instr)->exact = alu->exact; nir_instr_as_alu(second_product->parent_instr)->exact = alu->exact;
nir_instr_as_alu(second_product->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const sum = nir_fadd(bld, first_product, second_product); nir_def *const sum = nir_fadd(bld, first_product, second_product);
nir_instr_as_alu(sum->parent_instr)->exact = alu->exact; nir_instr_as_alu(sum->parent_instr)->exact = alu->exact;
nir_instr_as_alu(sum->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def_rewrite_uses(&alu->def, sum); nir_def_rewrite_uses(&alu->def, sum);
@ -151,15 +163,19 @@ replace_with_fast(struct nir_builder *bld, struct u_vector *dead_flrp,
nir_def *const neg_a = nir_fneg(bld, a); nir_def *const neg_a = nir_fneg(bld, a);
nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact; nir_instr_as_alu(neg_a->parent_instr)->exact = alu->exact;
nir_instr_as_alu(neg_a->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const b_minus_a = nir_fadd(bld, b, neg_a); nir_def *const b_minus_a = nir_fadd(bld, b, neg_a);
nir_instr_as_alu(b_minus_a->parent_instr)->exact = alu->exact; nir_instr_as_alu(b_minus_a->parent_instr)->exact = alu->exact;
nir_instr_as_alu(b_minus_a->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const product = nir_fmul(bld, c, b_minus_a); nir_def *const product = nir_fmul(bld, c, b_minus_a);
nir_instr_as_alu(product->parent_instr)->exact = alu->exact; nir_instr_as_alu(product->parent_instr)->exact = alu->exact;
nir_instr_as_alu(product->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const sum = nir_fadd(bld, a, product); nir_def *const sum = nir_fadd(bld, a, product);
nir_instr_as_alu(sum->parent_instr)->exact = alu->exact; nir_instr_as_alu(sum->parent_instr)->exact = alu->exact;
nir_instr_as_alu(sum->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def_rewrite_uses(&alu->def, sum); nir_def_rewrite_uses(&alu->def, sum);
@ -186,12 +202,14 @@ replace_with_expanded_ffma_and_add(struct nir_builder *bld,
nir_def *const b_times_c = nir_fmul(bld, b, c); nir_def *const b_times_c = nir_fmul(bld, b, c);
nir_instr_as_alu(b_times_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(b_times_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(b_times_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *inner_sum; nir_def *inner_sum;
if (subtract_c) { if (subtract_c) {
nir_def *const neg_c = nir_fneg(bld, c); nir_def *const neg_c = nir_fneg(bld, c);
nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact; nir_instr_as_alu(neg_c->parent_instr)->exact = alu->exact;
nir_instr_as_alu(neg_c->parent_instr)->fp_fast_math = alu->fp_fast_math;
inner_sum = nir_fadd(bld, a, neg_c); inner_sum = nir_fadd(bld, a, neg_c);
} else { } else {
@ -199,9 +217,11 @@ replace_with_expanded_ffma_and_add(struct nir_builder *bld,
} }
nir_instr_as_alu(inner_sum->parent_instr)->exact = alu->exact; nir_instr_as_alu(inner_sum->parent_instr)->exact = alu->exact;
nir_instr_as_alu(inner_sum->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def *const outer_sum = nir_fadd(bld, inner_sum, b_times_c); nir_def *const outer_sum = nir_fadd(bld, inner_sum, b_times_c);
nir_instr_as_alu(outer_sum->parent_instr)->exact = alu->exact; nir_instr_as_alu(outer_sum->parent_instr)->exact = alu->exact;
nir_instr_as_alu(outer_sum->parent_instr)->fp_fast_math = alu->fp_fast_math;
nir_def_rewrite_uses(&alu->def, outer_sum); nir_def_rewrite_uses(&alu->def, outer_sum);

View file

@ -849,6 +849,7 @@ clone_alu_and_replace_src_defs(nir_builder *b, const nir_alu_instr *alu,
{ {
nir_alu_instr *nalu = nir_alu_instr_create(b->shader, alu->op); nir_alu_instr *nalu = nir_alu_instr_create(b->shader, alu->op);
nalu->exact = alu->exact; nalu->exact = alu->exact;
nalu->fp_fast_math = alu->fp_fast_math;
nir_def_init(&nalu->instr, &nalu->def, nir_def_init(&nalu->instr, &nalu->def,
alu->def.num_components, alu->def.num_components,

View file

@ -460,6 +460,7 @@ construct_value(nir_builder *build,
* replacement should be exact. * replacement should be exact.
*/ */
alu->exact = state->has_exact_alu || expr->exact; alu->exact = state->has_exact_alu || expr->exact;
alu->fp_fast_math = nir_instr_as_alu(instr)->fp_fast_math;
for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) { for (unsigned i = 0; i < nir_op_infos[op].num_inputs; i++) {
/* If the source is an explicitly sized source, then we need to reset /* If the source is an explicitly sized source, then we need to reset

View file

@ -722,6 +722,7 @@ write_alu(write_ctx *ctx, const nir_alu_instr *alu)
} }
write_def(ctx, &alu->def, header, alu->instr.type); write_def(ctx, &alu->def, header, alu->instr.type);
blob_write_uint32(ctx->blob, alu->fp_fast_math);
if (header.alu.packed_src_ssa_16bit) { if (header.alu.packed_src_ssa_16bit) {
for (unsigned i = 0; i < num_srcs; i++) { for (unsigned i = 0; i < num_srcs; i++) {
@ -773,6 +774,7 @@ read_alu(read_ctx *ctx, union packed_instr header)
alu->no_unsigned_wrap = header.alu.no_unsigned_wrap; alu->no_unsigned_wrap = header.alu.no_unsigned_wrap;
read_def(ctx, &alu->def, &alu->instr, header); read_def(ctx, &alu->def, &alu->instr, header);
alu->fp_fast_math = blob_read_uint32(ctx->blob);
if (header.alu.packed_src_ssa_16bit) { if (header.alu.packed_src_ssa_16bit) {
for (unsigned i = 0; i < num_srcs; i++) { for (unsigned i = 0; i < num_srcs; i++) {