mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2025-12-24 08:50:13 +01:00
brw: Fix cmat conversion between bfloat16 and non-float32
The HW only supports converting BRW_TYPE_BF values to/from BRW_TYPE_F, so intermediate conversion is needed. Move the intermediate conversion to the implementation of `@convert_cmat_intel` and simplify the brw_nir_lower_cooperative_matrix pass. This has two positive effects - Fixes conversion between BF and integer type cooperative matrices, that was still using the old emit_alu1 approach instead of the new code for `@convert_cmat_intel`. - Guarantee the intermediate conversion will result in a valid layout for conversions involved USE_B matrices. If we instead used the intrinsic twice in brw_nir_lower_cooperative_matrix.c, a matrix with invalid layout would be visible at NIR level and we wouldn't be able to keep the current assertion for USE_B case. Due to the configurations we have exposed, we still don't need to write a more complex USE_B conversion -- they are all between same size types (and, consequently, packing factors), so no shuffling of data is needed to respect the USE_B layout. Reviewed-by: Matt Turner <mattst88@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36185>
This commit is contained in:
parent
557ac588e4
commit
2dfd4dcbc5
2 changed files with 28 additions and 42 deletions
|
|
@ -4885,11 +4885,15 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
|
|||
const unsigned elems = src_components * src_packing_factor;
|
||||
|
||||
brw_builder bldn = bld.exec_all();
|
||||
const brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
|
||||
brw_reg src = retype(get_nir_src(ntb, instr->src[0], 0), src_type);
|
||||
const brw_reg dst = retype(dest, dst_type);
|
||||
|
||||
assert(dst_cmat_desc.use == src_cmat_desc.use);
|
||||
|
||||
const bool needs_intermediate =
|
||||
(src.type == BRW_TYPE_BF && dst.type != BRW_TYPE_F) ||
|
||||
(dst.type == BRW_TYPE_BF && src.type != BRW_TYPE_F);
|
||||
|
||||
switch (src_cmat_desc.use) {
|
||||
case GLSL_CMAT_USE_B:
|
||||
assert(dst_element_bits == src_element_bits);
|
||||
|
|
@ -4898,6 +4902,16 @@ brw_from_nir_emit_cs_intrinsic(nir_to_brw_state &ntb,
|
|||
case GLSL_CMAT_USE_A:
|
||||
case GLSL_CMAT_USE_ACCUMULATOR: {
|
||||
const unsigned width = bldn.dispatch_width();
|
||||
|
||||
if (needs_intermediate) {
|
||||
brw_reg tmp = bldn.vgrf(BRW_TYPE_F, elems);
|
||||
for (unsigned c = 0; c < elems; c++) {
|
||||
bldn.MOV(suboffset(tmp, c * width),
|
||||
suboffset(src, c * width));
|
||||
}
|
||||
src = tmp;
|
||||
}
|
||||
|
||||
for (unsigned c = 0; c < elems; c++) {
|
||||
bldn.MOV(suboffset(dst, c * width),
|
||||
suboffset(src, c * width));
|
||||
|
|
|
|||
|
|
@ -466,51 +466,23 @@ lower_cmat_convert(nir_builder *b, nir_intrinsic_instr *intrin,
|
|||
const slice_info *dst_info = get_slice_info(state, dst_slice);
|
||||
const slice_info *src_info = get_slice_info(state, src_slice);
|
||||
|
||||
const nir_cmat_signed cmat_signed_mask = nir_intrinsic_cmat_signed_mask(intrin);
|
||||
/* Cooperative matrices must have the same "shape" to be converted. */
|
||||
assert(src_info->desc.rows == dst_info->desc.rows);
|
||||
assert(src_info->desc.cols == dst_info->desc.cols);
|
||||
assert(src_info->desc.use == dst_info->desc.use);
|
||||
assert(src_info->desc.scope == dst_info->desc.scope);
|
||||
|
||||
enum glsl_base_type src_element_type = glsl_apply_signedness_to_base_type(
|
||||
src_info->desc.element_type, cmat_signed_mask & NIR_CMAT_A_SIGNED);
|
||||
enum glsl_base_type dst_element_type = glsl_apply_signedness_to_base_type(
|
||||
dst_info->desc.element_type, cmat_signed_mask & NIR_CMAT_RESULT_SIGNED);
|
||||
|
||||
bool needs_intermediate =
|
||||
(src_element_type == GLSL_TYPE_BFLOAT16 && dst_element_type != GLSL_TYPE_FLOAT) ||
|
||||
(src_element_type != GLSL_TYPE_FLOAT && dst_element_type == GLSL_TYPE_BFLOAT16);
|
||||
|
||||
nir_def *result;
|
||||
nir_def *src = nir_load_deref(b, src_slice);
|
||||
|
||||
if (needs_intermediate) {
|
||||
/* Cooperative matrices must have the same "shape" to be converted. */
|
||||
assert(src_info->desc.rows == dst_info->desc.rows);
|
||||
assert(src_info->desc.cols == dst_info->desc.cols);
|
||||
assert(src_info->desc.use == dst_info->desc.use);
|
||||
assert(src_info->desc.scope == dst_info->desc.scope);
|
||||
const unsigned dst_components = glsl_get_vector_elements(dst_info->type);
|
||||
const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type);
|
||||
|
||||
struct glsl_cmat_description float_desc = src_info->desc;
|
||||
float_desc.element_type = GLSL_TYPE_FLOAT;
|
||||
|
||||
slice_info float_info = {};
|
||||
init_slice_info(state, float_desc, &float_info);
|
||||
|
||||
nir_op op1 = get_cmat_conversion_op(src_element_type, GLSL_TYPE_FLOAT);
|
||||
nir_op op2 = get_cmat_conversion_op(GLSL_TYPE_FLOAT, dst_element_type);
|
||||
|
||||
nir_def *tmp = emit_packed_alu1(b, state, src_info, &float_info, op1, src);
|
||||
result = emit_packed_alu1(b, state, &float_info, dst_info, op2, tmp);
|
||||
|
||||
} else {
|
||||
const unsigned dst_components = glsl_get_vector_elements(dst_info->type);
|
||||
const unsigned dst_bits = glsl_base_type_bit_size(dst_info->desc.element_type);
|
||||
|
||||
result =
|
||||
nir_convert_cmat_intel(b,
|
||||
dst_components,
|
||||
dst_info->packing_factor * dst_bits,
|
||||
src,
|
||||
.dst_cmat_desc = dst_info->desc,
|
||||
.src_cmat_desc = src_info->desc);
|
||||
}
|
||||
nir_def *result = nir_convert_cmat_intel(b,
|
||||
dst_components,
|
||||
dst_info->packing_factor * dst_bits,
|
||||
src,
|
||||
.dst_cmat_desc = dst_info->desc,
|
||||
.src_cmat_desc = src_info->desc);
|
||||
|
||||
nir_store_deref(b, dst_slice, result, nir_component_mask(result->num_components));
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue