From 2dfd4dcbc575a30edebf45cdae3e644976a4b3ad Mon Sep 17 00:00:00 2001
From: Caio Oliveira
Date: Wed, 16 Jul 2025 16:06:02 -0700
Subject: [PATCH] brw: Fix cmat conversion between bfloat16 and non-float32

The HW only supports converting BRW_TYPE_BF values to/from BRW_TYPE_F,
so an intermediate conversion is needed.  Move the intermediate
conversion to the implementation of `@convert_cmat_intel` and simplify
the brw_nir_lower_cooperative_matrix pass.

This has two positive effects:

- Fixes conversion between BF and integer type cooperative matrices,
  which was still using the old emit_alu1 approach instead of the new
  code for `@convert_cmat_intel`.

- Guarantees the intermediate conversion will result in a valid layout
  for conversions involving USE_B matrices.  If we instead used the
  intrinsic twice in brw_nir_lower_cooperative_matrix.c, a matrix with
  an invalid layout would be visible at NIR level and we wouldn't be
  able to keep the current assertion for the USE_B case.

  Due to the configurations we have exposed, we still don't need to
  write a more complex USE_B conversion -- they are all between
  same-size types (and, consequently, packing factors), so no shuffling
  of data is needed to respect the USE_B layout.

Reviewed-by: Matt Turner
Part-of:
---
 src/intel/compiler/brw_from_nir.cpp           | 16 +++++-
 .../brw_nir_lower_cooperative_matrix.c        | 54 +++++--------
 2 files changed, 28 insertions(+), 42 deletions(-)

diff --git a/src/intel/compiler/brw_from_nir.cpp b/src/intel/compiler/brw_from_nir.cpp
index 21e2c4065c6..58be1b25f34 100644
--- a/src/intel/compiler/brw_from_nir.cpp
+++ b/src/intel/compiler/brw_from_nir.cpp
@@ -4885,11 +4885,15 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
       const unsigned elems = src_components * src_packing_factor;
       brw_builder bldn = bld.exec_all();
 
-      const brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
+      brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
       const brw_reg dst = retype(dest, dst_type);
 
       assert(dst_cmat_desc.use == src_cmat_desc.use);
 
+      const bool needs_intermediate =
+         (src.type == BRW_TYPE_BF && dst.type != BRW_TYPE_F) ||
+         (dst.type == BRW_TYPE_BF && src.type != BRW_TYPE_F);
+
       switch (src_cmat_desc.use) {
       case GLSL_CMAT_USE_B:
         assert(dst_element_bits == src_element_bits);
@@ -4898,6 +4902,16 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
       case GLSL_CMAT_USE_A:
       case GLSL_CMAT_USE_ACCUMULATOR: {
          const unsigned width = bldn.dispatch_width();
+
+         if (needs_intermediate) {
+            brw_reg tmp = bldn.vgrf(BRW_TYPE_F, elems);
+            for (unsigned c = 0; c < elems; c++) {
+               bldn.MOV(suboffset(tmp, c * width),
+                        suboffset(src, c * width));
+            }
+            src = tmp;
+         }
+
         for (unsigned c = 0; c < elems; c++) {
            bldn.MOV(suboffset(dst, c * width),
                     suboffset(src, c * width));
diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
index 2ac5baeece9..f37c87cbb28 100644
--- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
+++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
@@ -466,51 +466,23 @@ lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin,
    const slice_info *dst_info = get_slice_info(state, dst_slice);
    const slice_info *src_info = get_slice_info(state, src_slice);
 
-   const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
+   /* Cooperative matrices must have the same "shape" to be converted. */
+   assert(src_info->desc.rows == dst_info->desc.rows);
+   assert(src_info->desc.cols == dst_info->desc.cols);
+   assert(src_info->desc.use == dst_info->desc.use);
+   assert(src_info->desc.scope == dst_info->desc.scope);
 
-   enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
-      src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
-   enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
-      dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
-
-   bool needs_intermediate =
-      (src_element_type == GLSL_TYPE_BFLOAT16 && dst_element_type != GLSL_TYPE_FLOAT) ||
-      (src_element_type != GLSL_TYPE_FLOAT && dst_element_type == GLSL_TYPE_BFLOAT16);
-
-   nir_def *result;
    nir_def *src = nir_load_deref(b, src_slice);
 
-   if (needs_intermediate) {
-      /* Cooperative matrices must have the same "shape" to be converted. */
-      assert(src_info->desc.rows == dst_info->desc.rows);
-      assert(src_info->desc.cols == dst_info->desc.cols);
-      assert(src_info->desc.use == dst_info->desc.use);
-      assert(src_info->desc.scope == dst_info->desc.scope);
+   const unsigned dst_components = glsl_get_vector_elements(dst_info->type);
+   const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type);
 
-      struct glsl_cmat_description float_desc = src_info->desc;
-      float_desc.element_type = GLSL_TYPE_FLOAT;
-
-      slice_info float_info = {};
-      init_slice_info(state, float_desc, &float_info);
-
-      nir_op op1 = get_cmat_conversion_op(src_element_type, GLSL_TYPE_FLOAT);
-      nir_op op2 = get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_element_type);
-
-      nir_def *tmp = emit_packed_alu1(b, state, src_info, &float_info, op1, src);
-      result = emit_packed_alu1(b, state, &float_info, dst_info, op2, tmp);
-
-   } else {
-      const unsigned dst_components = glsl_get_vector_elements(dst_info->type);
-      const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type);
-
-      result =
-         nir_convert_cmat_intel(b,
-                                dst_components,
-                                dst_info->packing_factor * dst_bits,
-                                src,
-                                .dst_cmat_desc = dst_info->desc,
-                                .src_cmat_desc = src_info->desc);
-   }
+   nir_def *result = nir_convert_cmat_intel(b,
+                                            dst_components,
+                                            dst_info->packing_factor * dst_bits,
+                                            src,
+                                            .dst_cmat_desc = dst_info->desc,
+                                            .src_cmat_desc = src_info->desc);
 
    nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
 }
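
For reference only (not part of the patch): a standalone sketch of why the
BF <-> integer path needs the float hop. bfloat16 is the upper 16 bits of an
IEEE-754 binary32 value, so widening BF to F is just a 16-bit shift, and the
integer conversion then operates on the float, mirroring the two MOVs the
patch emits when needs_intermediate is true. The helper name bf16_to_f32 and
the sample value are hypothetical, not Mesa code.

/* Standalone illustration only -- not Mesa code. */
#include <stdint.h>
#include <stdio.h>
#include <string.h>

/* bfloat16 is the high half of an IEEE-754 binary32 value. */
static float bf16_to_f32(uint16_t bf)
{
   uint32_t bits = (uint32_t)bf << 16;
   float f;
   memcpy(&f, &bits, sizeof(f));
   return f;
}

int main(void)
{
   uint16_t bf = 0x40A0;            /* bfloat16 encoding of 5.0 */
   float f = bf16_to_f32(bf);       /* step 1: BF -> F */
   int32_t i = (int32_t)f;          /* step 2: F -> int */
   printf("bf16 0x%04X -> %f -> %d\n", bf, f, i);
   return 0;
}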