/* Patch summary (annotation placed ahead of the "diff --git" header, outside
 * every hunk):
 *
 *  - Adds get_cmat_conversion_op(src, dst), which picks the NIR ALU opcode
 *    that converts one element from base type src to base type dst.  It
 *    returns nir_op_bf2f when src is GLSL_TYPE_BFLOAT16 (asserting that dst
 *    is GLSL_TYPE_FLOAT), nir_op_f2bf when dst is GLSL_TYPE_BFLOAT16
 *    (asserting that src is GLSL_TYPE_FLOAT), and otherwise the result of
 *    nir_type_conversion_op() for the two types with nir_rounding_mode_undef.
 *
 *  - Adds lower_cmat_convert(), now the handler for
 *    nir_intrinsic_cmat_convert.  It reads the destination deref from
 *    intrin->src[0] and the source deref from intrin->src[1], applies the
 *    intrinsic's signedness mask to both element types, and stores the
 *    converted slice to the destination.  When bfloat16 is paired with a
 *    type other than float, the conversion is done as two emit_packed_alu1()
 *    steps through a temporary slice description whose element type is
 *    GLSL_TYPE_FLOAT.
 *
 *  - lower_cmat_unary_op() drops its conversion branch: it now asserts that
 *    the source and destination descriptions are equal and always uses
 *    nir_intrinsic_alu_op(intrin) as the opcode.
 */
diff --git a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
index cd75ceb0f1d..3577dc37baf 100644
--- a/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
+++ b/src/intel/compiler/brw_nir_lower_cooperative_matrix.c
@@ -436,6 +436,76 @@ emit_packed_alu1(nir_builder *b,
    return nir_vec(b, results, dst_components);
 }
 
+static nir_op
+get_cmat_conversion_op(enum glsl_base_type src,
+                       enum glsl_base_type dst)
+{
+   if (src == GLSL_TYPE_BFLOAT16) {
+      assert(dst == GLSL_TYPE_FLOAT); /* only bf16 -> float is direct */
+      return nir_op_bf2f;
+
+   } else if (dst == GLSL_TYPE_BFLOAT16) {
+      assert(src == GLSL_TYPE_FLOAT); /* only float -> bf16 is direct */
+      return nir_op_f2bf;
+
+   } else {
+      return nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src),
+                                    nir_get_nir_type_for_glsl_base_type(dst),
+                                    nir_rounding_mode_undef);
+   }
+}
+
+static void
+lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin,
+                   struct lower_cmat_state *state)
+{
+   nir_deref_instr *dst_slice = nir_src_as_deref(intrin->src[0]);
+   nir_deref_instr *src_slice = nir_src_as_deref(intrin->src[1]);
+
+   const slice_info *dst_info = get_slice_info(state, dst_slice);
+   const slice_info *src_info = get_slice_info(state, src_slice);
+
+   const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
+
+   enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
+      src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
+   enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
+      dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
+
+   bool needs_intermediate = /* bf16 on one side, non-float on the other */
+      (src_element_type == GLSL_TYPE_BFLOAT16 && dst_element_type != GLSL_TYPE_FLOAT) ||
+      (src_element_type != GLSL_TYPE_FLOAT && dst_element_type == GLSL_TYPE_BFLOAT16);
+
+   nir_def *result;
+   nir_def *src = nir_load_deref(b, src_slice);
+
+   if (needs_intermediate) {
+      /* Conversion requires matching rows, cols, use and scope ("shape"). */
+      assert(src_info->desc.rows == dst_info->desc.rows);
+      assert(src_info->desc.cols == dst_info->desc.cols);
+      assert(src_info->desc.use == dst_info->desc.use);
+      assert(src_info->desc.scope == dst_info->desc.scope);
+
+      struct glsl_cmat_description float_desc = src_info->desc;
+      float_desc.element_type = GLSL_TYPE_FLOAT; /* keep src shape */
+
+      slice_info float_info = {};
+      init_slice_info(state, float_desc, &float_info);
+
+      nir_op op1 = get_cmat_conversion_op(src_element_type, GLSL_TYPE_FLOAT);
+      nir_op op2 = get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_element_type);
+
+      nir_def *tmp = emit_packed_alu1(b, state, src_info, &float_info, op1, src);
+      result = emit_packed_alu1(b, state, &float_info, dst_info, op2, tmp);
+
+   } else {
+      nir_op op = get_cmat_conversion_op(src_element_type, dst_element_type);
+      result = emit_packed_alu1(b, state, src_info, dst_info, op, src);
+   }
+
+   nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
+}
+
 static void
 lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
                     struct lower_cmat_state *state)
@@ -445,29 +515,10 @@ lower_cmat_unary_op(nir_builder *b, nir_intrinsic_instr *intrin,
 
    const slice_info *dst_info = get_slice_info(state, dst_slice);
    const slice_info *src_info = get_slice_info(state, src_slice);
+   assert(cmat_descriptions_are_equal(src_info->desc, dst_info->desc));
 
-   /* The type of the returned slice may be different from the type of the
-    * input slice if this is a convert operation.
-    */
-
-   nir_op op;
-
-   if (intrin->intrinsic == nir_intrinsic_cmat_unary_op) {
-      op = nir_intrinsic_alu_op(intrin);
-   } else {
-      const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
-
-      enum glsl_base_type src_base_type = glsl_apply_signedness_to_base_type(
-         src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
-      enum glsl_base_type dst_base_type = glsl_apply_signedness_to_base_type(
-         dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
-
-      op = nir_type_conversion_op(nir_get_nir_type_for_glsl_base_type(src_base_type),
-                                  nir_get_nir_type_for_glsl_base_type(dst_base_type),
-                                  nir_rounding_mode_undef);
-   }
-
-   nir_def *result = emit_packed_alu1(b, state, src_info, dst_info, op,
+   nir_def *result = emit_packed_alu1(b, state, src_info, dst_info,
+                                      nir_intrinsic_alu_op(intrin),
                                       nir_load_deref(b, src_slice));
 
    nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
@@ -599,6 +650,9 @@ lower_cmat_instr(nir_builder *b, nir_instr *instr, void *_state)
    }
 
    case nir_intrinsic_cmat_convert:
+      lower_cmat_convert(b, intrin, state);
+      return NIR_LOWER_INSTR_PROGRESS_REPLACE;
+
    case nir_intrinsic_cmat_unary_op:
       lower_cmat_unary_op(b, intrin, state);
       return NIR_LOWER_INSTR_PROGRESS_REPLACE;