diff --git a/src/intel/compiler/brw/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw/brw_nir_lower_cooperative_matrix.c index 1f23f8a0746..25997c1cc63 100644 --- a/src/intel/compiler/brw/brw_nir_lower_cooperative_matrix.c +++ b/src/intel/compiler/brw/brw_nir_lower_cooperative_matrix.c @@ -463,6 +463,7 @@ static void lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin, struct lower_cmat_state *state) { + b->fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intrin); nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); @@ -488,12 +489,14 @@ lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin, .src_cmat_desc = src_info->desc); nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components)); + b->fp_math_ctrl = nir_fp_fast_math; } static void lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, struct lower_cmat_state *state) { + b->fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intrin); nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); @@ -506,12 +509,14 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_load_deref(b, src_slice)); nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components)); + b->fp_math_ctrl = nir_fp_fast_math; } static void lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin, struct lower_cmat_state *state) { + b->fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intrin); nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); nir_deref_instr *src_a_slice = nir_src_as_deref(intrin->src[1]); nir_deref_instr *src_b_slice = nir_src_as_deref(intrin->src[2]); @@ -543,12 +548,14 @@ lower_cmat_binary_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), nir_component_mask(num_components)); + b->fp_math_ctrl = nir_fp_fast_math; } static void lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin, struct lower_cmat_state *state) { + b->fp_math_ctrl = nir_intrinsic_fp_math_ctrl(intrin); nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]); nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]); nir_def *scalar = intrin->src[2].ssa; @@ -575,6 +582,7 @@ lower_cmat_scalar_op(nir_builder *b, nir_intrinsic_instr *intrin, nir_store_deref(b, dst_slice, nir_vec(b, results, num_components), nir_component_mask(num_components)); + b->fp_math_ctrl = nir_fp_fast_math; } static void