diff --git a/src/compiler/spirv/vtn_alu.c b/src/compiler/spirv/vtn_alu.c index e57df780535..6898b2bf151 100644 --- a/src/compiler/spirv/vtn_alu.c +++ b/src/compiler/spirv/vtn_alu.c @@ -438,10 +438,14 @@ fp_math_ctrl_for_type(struct vtn_builder *b, struct vtn_type *type) enum glsl_base_type base_type; /* Some ALU like modf and frexp return a struct of two values. */ - if (glsl_type_is_struct(type->type)) + if (glsl_type_is_struct(type->type)) { base_type = glsl_get_base_type(type->type->fields.structure[0].type); - else + } else if (glsl_type_is_cmat(type->type)) { + struct glsl_cmat_description desc = *glsl_get_cmat_description(type->type); + base_type = desc.element_type; + } else { base_type = glsl_get_base_type(type->type); + } unsigned *fp_math_ctrl = vtn_fp_math_ctrl_for_base_type(b, base_type); @@ -742,13 +746,14 @@ vtn_handle_alu(struct vtn_builder *b, SpvOp opcode, struct vtn_value *dest_val = vtn_untyped_value(b, w[2]); const struct glsl_type *dest_type = vtn_get_type(b, w[1])->type; + vtn_handle_fp_fast_math(b, dest_val, vtn_untyped_value(b, w[3])); + if (glsl_type_is_cmat(dest_type)) { vtn_handle_cooperative_alu(b, dest_val, dest_type, opcode, w, count); + b->nb.fp_math_ctrl = nir_fp_fast_math; return; } - vtn_handle_fp_fast_math(b, dest_val, vtn_untyped_value(b, w[3])); - bool mediump_16bit = vtn_alu_op_mediump_16bit(b, opcode, dest_val); /* Collect the various SSA sources */ diff --git a/src/compiler/spirv/vtn_cmat.c b/src/compiler/spirv/vtn_cmat.c index 6b1a87ffc9e..ee1179dc1c2 100644 --- a/src/compiler/spirv/vtn_cmat.c +++ b/src/compiler/spirv/vtn_cmat.c @@ -322,6 +322,7 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val, bool swap = false; unsigned extra_fp_math_ctrl = nir_fp_fast_math; nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &extra_fp_math_ctrl); + b->nb.fp_math_ctrl |= extra_fp_math_ctrl; nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_unary"); nir_cmat_unary_op(&b->nb, &dst->def, &src->def, @@ -347,6 +348,7 @@ vtn_handle_cooperative_alu(struct vtn_builder *b, struct vtn_value *dest_val, nir_deref_instr *mat_b = vtn_get_cmat_deref(b, w[4]); nir_op op = vtn_nir_alu_op_for_spirv_opcode(b, opcode, &swap, &extra_fp_math_ctrl); + b->nb.fp_math_ctrl |= extra_fp_math_ctrl; nir_deref_instr *dst = vtn_create_cmat_temporary(b, dst_type->type, "cmat_binary"); nir_cmat_binary_op(&b->nb, &dst->def, &mat_a->def, &mat_b->def,