mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-01-02 20:20:09 +01:00
brw/cmat: Implement conversion from/to BFloat16
When converting BFloat16 from/to non-Float32 type, use the Float32 conversion as an intermediate step. Take the opportunity to separate the unary_op/convert code-paths. Reviewed-by: Ian Romanick <ian.d.romanick@intel.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34105>
This commit is contained in:
parent
de88184ab6
commit
d4381c0908
1 changed files with 76 additions and 22 deletions
|
|
@ -436,6 +436,76 @@ emit_packed_alu1(nir_builder *b,
|
|||
return nir_vec(b, results, dst_components);
|
||||
}
|
||||
|
||||
static nir_op
|
||||
get_cmat_conversion_op(enum glsl_base_type src,
|
||||
enum glsl_base_type dst)
|
||||
{
|
||||
if (src == GLSL_TYPE_BFLOAT16) {
|
||||
assert(dst == GLSL_TYPE_FLOAT);
|
||||
return nir_op_bf2f;
|
||||
|
||||
} else if (dst == GLSL_TYPE_BFLOAT16) {
|
||||
assert(src == GLSL_TYPE_FLOAT);
|
||||
return nir_op_f2bf;
|
||||
|
||||
} else {
|
||||
return nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src),
|
||||
nir_get_nir_type_for_glsl_base_type(dst),
|
||||
nir_rounding_mode_undef);
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin,
|
||||
struct lower_cmat_state *state)
|
||||
{
|
||||
nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
|
||||
nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
|
||||
|
||||
const slice_info *dst_info = get_slice_info(state, dst_slice);
|
||||
const slice_info *src_info = get_slice_info(state, src_slice);
|
||||
|
||||
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
|
||||
|
||||
enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
|
||||
src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
|
||||
enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
|
||||
dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
|
||||
|
||||
bool needs_intermediate =
|
||||
(src_element_type == GLSL_TYPE_BFLOAT16 && dst_element_type != GLSL_TYPE_FLOAT) ||
|
||||
(src_element_type != GLSL_TYPE_FLOAT && dst_element_type == GLSL_TYPE_BFLOAT16);
|
||||
|
||||
nir_def *result;
|
||||
nir_def *src = nir_load_deref(b, src_slice);
|
||||
|
||||
if (needs_intermediate) {
|
||||
/* Cooperative matrices must have the same "shape" to be converted. */
|
||||
assert(src_info->desc.rows == dst_info->desc.rows);
|
||||
assert(src_info->desc.cols == dst_info->desc.cols);
|
||||
assert(src_info->desc.use == dst_info->desc.use);
|
||||
assert(src_info->desc.scope == dst_info->desc.scope);
|
||||
|
||||
struct glsl_cmat_description float_desc = src_info->desc;
|
||||
float_desc.element_type = GLSL_TYPE_FLOAT;
|
||||
|
||||
slice_info float_info = {};
|
||||
init_slice_info(state, float_desc, &float_info);
|
||||
|
||||
nir_op op1 = get_cmat_conversion_op(src_element_type, GLSL_TYPE_FLOAT);
|
||||
nir_op op2 = get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_element_type);
|
||||
|
||||
nir_def *tmp = emit_packed_alu1(b, state, src_info, &float_info, op1, src);
|
||||
result = emit_packed_alu1(b, state, &float_info, dst_info, op2, tmp);
|
||||
|
||||
} else {
|
||||
nir_op op = get_cmat_conversion_op(src_element_type, dst_element_type);
|
||||
result = emit_packed_alu1(b, state, src_info, dst_info, op, src);
|
||||
}
|
||||
|
||||
nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
|
||||
}
|
||||
|
||||
static void
|
||||
lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
|
||||
struct lower_cmat_state *state)
|
||||
|
|
@ -445,29 +515,10 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
|
||||
const slice_info *dst_info = get_slice_info(state, dst_slice);
|
||||
const slice_info *src_info = get_slice_info(state, src_slice);
|
||||
assert(cmat_descriptions_are_equal(src_info->desc, dst_info->desc));
|
||||
|
||||
/* The type of the returned slice may be different from the type of the
|
||||
* input slice if this is a convert operation.
|
||||
*/
|
||||
|
||||
nir_op op;
|
||||
|
||||
if (intrin->intrinsic == nir_intrinsic_cmat_unary_op) {
|
||||
op = nir_intrinsic_alu_op(intrin);
|
||||
} else {
|
||||
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
|
||||
|
||||
enum glsl_base_type src_base_type = glsl_apply_signedness_to_base_type(
|
||||
src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
|
||||
enum glsl_base_type dst_base_type = glsl_apply_signedness_to_base_type(
|
||||
dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
|
||||
|
||||
op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_base_type),
|
||||
nir_get_nir_type_for_glsl_base_type(dst_base_type),
|
||||
nir_rounding_mode_undef);
|
||||
}
|
||||
|
||||
nir_def *result = emit_packed_alu1(b, state, src_info, dst_info, op,
|
||||
nir_def *result = emit_packed_alu1(b, state, src_info, dst_info,
|
||||
nir_intrinsic_alu_op(intrin),
|
||||
nir_load_deref(b, src_slice));
|
||||
|
||||
nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
|
||||
|
|
@ -599,6 +650,9 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
|
|||
}
|
||||
|
||||
case nir_intrinsic_cmat_convert:
|
||||
lower_cmat_convert(b, intrin, state);
|
||||
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
|
||||
|
||||
case nir_intrinsic_cmat_unary_op:
|
||||
lower_cmat_unary_op(b, intrin, state);
|
||||
return NIR_LOWER_INSTR_PROGRESS_REPLACE;
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue