spirv: gather some float controls bits per instruction

v2: add static_assert to ensure values fit in bitfield (Alyssa)

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27281>
This commit is contained in:
Iván Briano 2024-02-13 16:35:53 -08:00 committed by Marge Bot
parent 666647acae
commit 750bd9757e
3 changed files with 23 additions and 0 deletions

View file

@ -402,6 +402,25 @@ vtn_nir_alu_op_for_spirv_opcode(struct vtn_builder *b,
}
}
void
vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val)
{
/* Take the NaN/Inf/SZ preserve bits from the execution mode and set them
* on the builder, so the generated instructions can take it from it.
* We only care about some of them, check nir_alu_instr for details.
* We also copy all bit widths, because we can't easily get the correct one
* here.
*/
#define FLOAT_CONTROLS2_BITS (FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP16 | \
FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP32 | \
FLOAT_CONTROLS_SIGNED_ZERO_INF_NAN_PRESERVE_FP64)
static_assert(FLOAT_CONTROLS2_BITS == BITSET_MASK(9),
"enum float_controls and fp_fast_math out of sync!");
b->nb.fp_fast_math = b->shader->info.float_controls_execution_mode &
FLOAT_CONTROLS2_BITS;
#undef FLOAT_CONTROLS2_BITS
}
static void
handle_no_contraction(struct vtn_builder *b, UNUSED struct vtn_value *val,
UNUSED int member, const struct vtn_decoration *dec,
@ -581,6 +600,7 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
}
vtn_handle_no_contraction(b, dest_val);
vtn_handle_fp_fast_math(b, dest_val);
bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
/* Collect the various SSA sources */

View file

@ -697,6 +697,7 @@ bool
vtn_handle_glsl450_instruction(struct vtn_builder *b, SpvOp ext_opcode,
const uint32_t *w, unsigned count)
{
vtn_handle_fp_fast_math(b, vtn_untyped_value(b, w[2]));
switch ((enum GLSLstd450)ext_opcode) {
case GLSLstd450Determinant: {
vtn_push_nir_ssa(b, w[2], build_mat_det(b, vtn_ssa_value(b, w[5])));

View file

@ -971,6 +971,8 @@ void vtn_handle_bitcast(struct vtn_builder *b, const uint32_t *w,
void vtn_handle_no_contraction(struct vtn_builder *b, struct vtn_value *val);
void vtn_handle_fp_fast_math(struct vtn_builder *b, struct vtn_value *val);
void vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
const uint32_t *w, unsigned count);