spirv: set fp_math_ctrl for cmat alu

Reviewed-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/40630>
This commit is contained in:
Georg Lehmann 2026-03-25 15:38:40 +01:00 committed by Marge Bot
parent 35ca85176c
commit b8b1ce9667
2 changed files with 11 additions and 4 deletions

View file

@ -438,10 +438,14 @@ fp_math_ctrl_for_type(struct vtn_builder *b, struct vtn_type *type)
enum glsl_base_type base_type;
/* Some ALU like modf and frexp return a struct of two values. */
if (glsl_type_is_struct(type->type))
if (glsl_type_is_struct(type->type)) {
base_type = glsl_get_base_type(type->type->fields.structure[0].type);
else
} else if (glsl_type_is_cmat(type->type)) {
struct glsl_cmat_description desc = *glsl_get_cmat_description(type->type);
base_type = desc.element_type;
} else {
base_type = glsl_get_base_type(type->type);
}
unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, base_type);
@ -742,13 +746,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
struct vtn_value *dest_val = vtn_untyped_value(b, w[2]);
const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type;
vtn_handle_fp_fast_math(b, dest_val, vtn_untyped_value(b, w[3]));
if (glsl_type_is_cmat(dest_type)) {
vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count);
b->nb.fp_math_ctrl = nir_fp_fast_math;
return;
}
vtn_handle_fp_fast_math(b, dest_val, vtn_untyped_value(b, w[3]));
bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val);
/* Collect the various SSA sources */

View file

@ -322,6 +322,7 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
bool swap = false;
unsigned extra_fp_math_ctrl = nir_fp_fast_math;
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &extra_fp_math_ctrl);
b->nb.fp_math_ctrl |= extra_fp_math_ctrl;
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary");
nir_cmat_unary_op(&b->nb, &dst->def, &src->def,
@ -347,6 +348,7 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val,
nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]);
nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &extra_fp_math_ctrl);
b->nb.fp_math_ctrl |= extra_fp_math_ctrl;
nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary");
nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,