brw/cmat: Implement conversion from/to BFloat16

When converting between BFloat16 and a non-Float32 type, use
a Float32 conversion as an intermediate step.  Take the
opportunity to separate the unary_op and convert code paths.

Reviewed-by: Ian Romanick <ian.d.romanick@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
Caio Oliveira 2025-04-14 18:38:33 -07:00 committed by Marge Bot
parent de88184ab6
commit d4381c0908

View file

@ -436,6 +436,76 @@ emit_packed_alu1(nir_builder *b,
return nir_vec(b, results, dst_components);
}
static nir_op
get_cmat_conversion_op(enum glsl_base_type src,
enum glsl_base_type dst)
{
if (src == GLSL_TYPE_BFLOAT16) {
assert(dst == GLSL_TYPE_FLOAT);
return nir_op_bf2f;
} else if (dst == GLSL_TYPE_BFLOAT16) {
assert(src == GLSL_TYPE_FLOAT);
return nir_op_f2bf;
} else {
return nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src),
nir_get_nir_type_for_glsl_base_type(dst),
nir_rounding_mode_undef);
}
}
/**
 * Lower a cmat_convert intrinsic: load the source slice (src[1]), convert
 * its elements to the destination element type and store the result into
 * the destination slice (src[0]).
 *
 * The signedness bits carried by the intrinsic are applied to the element
 * types of both slices before the conversion opcode is chosen.  When
 * BFloat16 sits on one side and the other side is not Float32, the
 * conversion is emitted as two steps going through a Float32 slice.
 */
static void
lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin,
                   struct lower_cmat_state *state)
{
   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
   const slice_info *dst_info = get_slice_info(state, dst_slice);
   const slice_info *src_info = get_slice_info(state, src_slice);

   const nir_cmat_signed sign_mask = nir_intrinsic_cmat_signed_mask(intrin);

   const enum glsl_base_type src_type =
      glsl_apply_signedness_to_base_type(src_info->desc.element_type,
                                         sign_mask & NIR_CMAT_A_SIGNED);
   const enum glsl_base_type dst_type =
      glsl_apply_signedness_to_base_type(dst_info->desc.element_type,
                                         sign_mask & NIR_CMAT_RESULT_SIGNED);

   /* get_cmat_conversion_op() only pairs BFloat16 with Float32, so any
    * other combination that involves BFloat16 has to go through Float32.
    */
   const bool from_bf16 = src_type == GLSL_TYPE_BFLOAT16;
   const bool to_bf16 = dst_type == GLSL_TYPE_BFLOAT16;
   const bool via_float = (from_bf16 && dst_type != GLSL_TYPE_FLOAT) ||
                          (to_bf16 && src_type != GLSL_TYPE_FLOAT);

   nir_def *value = nir_load_deref(b, src_slice);

   if (via_float) {
      /* The temporary Float32 slice reuses the source description, which is
       * only valid if source and destination agree on everything except the
       * element type.
       */
      assert(src_info->desc.rows == dst_info->desc.rows);
      assert(src_info->desc.cols == dst_info->desc.cols);
      assert(src_info->desc.use == dst_info->desc.use);
      assert(src_info->desc.scope == dst_info->desc.scope);

      struct glsl_cmat_description float_desc = src_info->desc;
      float_desc.element_type = GLSL_TYPE_FLOAT;

      slice_info float_info = {};
      init_slice_info(state, float_desc, &float_info);

      const nir_op to_float =
         get_cmat_conversion_op(src_type, GLSL_TYPE_FLOAT);
      const nir_op from_float =
         get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_type);

      /* source type -> Float32 -> destination type */
      value = emit_packed_alu1(b, state, src_info, &float_info,
                               to_float, value);
      value = emit_packed_alu1(b, state, &float_info, dst_info,
                               from_float, value);
   } else {
      const nir_op direct = get_cmat_conversion_op(src_type, dst_type);

      value = emit_packed_alu1(b, state, src_info, dst_info, direct, value);
   }

   nir_store_deref(b, dst_slice, value,
                   nir_component_mask(value->num_components));
}
static void
lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
struct lower_cmat_state *state)
@ -445,29 +515,10 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
const slice_info *dst_info = get_slice_info(state, dst_slice);
const slice_info *src_info = get_slice_info(state, src_slice);
assert(cmat_descriptions_are_equal(src_info->desc, dst_info->desc));
/* The type of the returned slice may be different from the type of the
* input slice if this is a convert operation.
*/
nir_op op;
if (intrin->intrinsic == nir_intrinsic_cmat_unary_op) {
op = nir_intrinsic_alu_op(intrin);
} else {
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
enum glsl_base_type src_base_type = glsl_apply_signedness_to_base_type(
src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
enum glsl_base_type dst_base_type = glsl_apply_signedness_to_base_type(
dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_base_type),
nir_get_nir_type_for_glsl_base_type(dst_base_type),
nir_rounding_mode_undef);
}
nir_def *result = emit_packed_alu1(b, state, src_info, dst_info, op,
nir_def *result = emit_packed_alu1(b, state, src_info, dst_info,
nir_intrinsic_alu_op(intrin),
nir_load_deref(b, src_slice));
nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
@ -599,6 +650,9 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
}
case nir_intrinsic_cmat_convert:
lower_cmat_convert(b, intrin, state);
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
case nir_intrinsic_cmat_unary_op:
lower_cmat_unary_op(b, intrin, state);
return NIR_LOWER_INSTR_PROGRESS_REPLACE;