brw: Fix cmat conversion between bfloat16 and non-float32

The HW only supports converting BRW_TYPE_BF values to/from BRW_TYPE_F,
so intermediate conversion is needed.  Move the intermediate conversion
to the implementation of `@convert_cmat_intel` and simplify the
brw_nir_lower_cooperative_matrix pass.  This has two positive effects

- Fixes conversion between BF and integer type cooperative matrices,
  that was still using the old emit_alu1 approach instead of the new
  code for `@convert_cmat_intel`.

- Guarantee the intermediate conversion will result in a valid layout
  for conversions involved USE_B matrices.  If we instead used the
  intrinsic twice in brw_nir_lower_cooperative_matrix.c, a matrix with
  invalid layout would be visible at NIR level and we wouldn't be able
  to keep the current assertion for USE_B case.

Due to the configurations we have exposed, we still don't need to
write a more complex USE_B conversion -- they are all between same
size types (and, consequently, packing factors), so no shuffling of
data is needed to respect the USE_B layout.

Reviewed-by: Matt Turner <mattst88@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36185>
This commit is contained in:
Caio Oliveira 2025-07-16 16:06:02 -07:00 committed by Marge Bot
parent 557ac588e4
commit 2dfd4dcbc5
2 changed files with 28 additions and 42 deletions

View file

@ -4885,11 +4885,15 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
const unsigned elems = src_components * src_packing_factor;
brw_builder bldn = bld.exec_all();
const brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
const brw_reg dst = retype(dest, dst_type);
assert(dst_cmat_desc.use == src_cmat_desc.use);
const bool needs_intermediate =
(src.type == BRW_TYPE_BF && dst.type != BRW_TYPE_F) ||
(dst.type == BRW_TYPE_BF && src.type != BRW_TYPE_F);
switch (src_cmat_desc.use) {
case GLSL_CMAT_USE_B:
assert(dst_element_bits == src_element_bits);
@ -4898,6 +4902,16 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
case GLSL_CMAT_USE_A:
case GLSL_CMAT_USE_ACCUMULATOR: {
const unsigned width = bldn.dispatch_width();
if (needs_intermediate) {
brw_reg tmp = bldn.vgrf(BRW_TYPE_F, elems);
for (unsigned c = 0; c < elems; c++) {
bldn.MOV(suboffset(tmp, c * width),
suboffset(src, c * width));
}
src = tmp;
}
for (unsigned c = 0; c < elems; c++) {
bldn.MOV(suboffset(dst, c * width),
suboffset(src, c * width));

View file

@ -466,51 +466,23 @@ lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin,
const slice_info *dst_info = get_slice_info(state, dst_slice);
const slice_info *src_info = get_slice_info(state, src_slice);
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
/* Cooperative matrices must have the same "shape" to be converted. */
assert(src_info->desc.rows == dst_info->desc.rows);
assert(src_info->desc.cols == dst_info->desc.cols);
assert(src_info->desc.use == dst_info->desc.use);
assert(src_info->desc.scope == dst_info->desc.scope);
enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
bool needs_intermediate =
(src_element_type == GLSL_TYPE_BFLOAT16 && dst_element_type != GLSL_TYPE_FLOAT) ||
(src_element_type != GLSL_TYPE_FLOAT && dst_element_type == GLSL_TYPE_BFLOAT16);
nir_def *result;
nir_def *src = nir_load_deref(b, src_slice);
if (needs_intermediate) {
/* Cooperative matrices must have the same "shape" to be converted. */
assert(src_info->desc.rows == dst_info->desc.rows);
assert(src_info->desc.cols == dst_info->desc.cols);
assert(src_info->desc.use == dst_info->desc.use);
assert(src_info->desc.scope == dst_info->desc.scope);
const unsigned dst_components = glsl_get_vector_elements(dst_info->type);
const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type);
struct glsl_cmat_description float_desc = src_info->desc;
float_desc.element_type = GLSL_TYPE_FLOAT;
slice_info float_info = {};
init_slice_info(state, float_desc, &float_info);
nir_op op1 = get_cmat_conversion_op(src_element_type, GLSL_TYPE_FLOAT);
nir_op op2 = get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_element_type);
nir_def *tmp = emit_packed_alu1(b, state, src_info, &float_info, op1, src);
result = emit_packed_alu1(b, state, &float_info, dst_info, op2, tmp);
} else {
const unsigned dst_components = glsl_get_vector_elements(dst_info->type);
const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type);
result =
nir_convert_cmat_intel(b,
dst_components,
dst_info->packing_factor * dst_bits,
src,
.dst_cmat_desc = dst_info->desc,
.src_cmat_desc = src_info->desc);
}
nir_def *result = nir_convert_cmat_intel(b,
dst_components,
dst_info->packing_factor * dst_bits,
src,
.dst_cmat_desc = dst_info->desc,
.src_cmat_desc = src_info->desc);
nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
}