nak/nir: Use correct rounding for fp64 -> fp16 conversions

For up, down, and round towards zero, the rounding accumulates properly
as long as you use the same rounding mode for both.  For RTNE, however,
we need to insert a two-instruction fixup in order to guarantee correct
rounding.

Reviewed-by: Mel Henning <mhenning@darkrefraction.com>
Reviewed-by: Benjamin Lee <benjamin.lee@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/34126>
This commit is contained in:
Faith Ekstrand 2025-03-12 18:01:58 -05:00 committed by Marge Bot
parent d826f82ffe
commit 2d75e7dced

View file

@ -55,6 +55,7 @@ split_64bit_conversion(nir_builder *b, nir_instr *instr, UNUSED void *_data)
nir_alu_type dst_full_type = nir_op_infos[alu->op].output_type;
assert(nir_alu_type_get_type_size(dst_full_type) == dst_bit_size);
nir_alu_type dst_type = nir_alu_type_get_base_type(dst_full_type);
const nir_rounding_mode rounding_mode = op_rounding_mode(alu->op);
/* We can't cross the 64-bit boundary in one conversion */
if ((src_bit_size <= 32 && dst_bit_size <= 32) ||
@ -87,10 +88,95 @@ split_64bit_conversion(nir_builder *b, nir_instr *instr, UNUSED void *_data)
b->cursor = nir_before_instr(&alu->instr);
nir_def *src = nir_ssa_for_alu_src(b, alu, 0);
nir_def *tmp = nir_type_convert(b, src, src_type, tmp_type,
nir_rounding_mode_undef);
nir_def *tmp;
if (src_full_type == nir_type_float64 && dst_full_type == nir_type_float16) {
/* For fp64->fp16 conversions, we need to be careful with the first
* conversion or else rounding might not accumulate properly.
*/
assert(tmp_type == nir_type_float32);
if (rounding_mode == nir_rounding_mode_rtne ||
rounding_mode == nir_rounding_mode_undef) {
nir_def *src_lo = nir_unpack_64_2x32_split_x(b, src);
nir_def *src_hi = nir_unpack_64_2x32_split_y(b, src);
/* RTNE is tricky to get right through a double conversion. To work
* around this, we do a little fixup of the fp64 value first.
*
* For a 64-bit float, the mantissa bits are as follows:
*
* HHHHHHHHHHHLTFFFFFFFFF FFFDDDDDDDDDDDDDDDDDDDDDDDDDDDDD
* | |
* +------- bottom 32 bits -------+
*
* Where:
* - D are only used for fp64
* - T and F are used for fp64 and fp32
* - H and L are used for fp64, fp32, and fp16
* - L denotes the low bit of the fp16 mantissa
* - T is the tie bit
*
* The RTNE tie-breaking rules for fp64 -> fp16 can then be described
* as follows:
*
* - If any F or D bit is non-zero:
* - If T == 1, round up
* - If T == 0, round down
* - If all F and D bits are zero:
* - If T == 0, it's already fp16, do nothing
* - If T != 0 and L == 0, round down
* - If T != 0 and L != 0, round up
*
* What's important here is that the only way the F or D bits fit
* into the algorithm is if any are zero or none are zero. So we
* will get the same result if we take all of the bits in the low
* dword, or them together, and then or that into the low F bits of
* the high dword. The result of "all F and D bits are zero" will be
* the same. We can also zero the low dword without affecting the
* final result. Doing this accomplishes two useful things:
*
* 1. The resulting fp64 value is exactly representable as fp32 so
* we don't have to care about the rounding of the fp64 -> fp32
* conversion.
*
* 2. The fp32 -> fp16 conversion will round exactly the same as a
* full fp64 -> fp16 conversion on the original data since it now
* takes all of the D bits into account as well as the F bits.
*
* It's also correct for NaN/INF since those are delineated by the
* entire mantissa being either zero or non-zero. For denorms,
* anything that might be a denorm in fp32 or fp64 will have a
* sufficiently negative exponent that it will flush to zero when
* converted to fp16, regardless of what we do here.
*
* There are many operations we could choose for combining the low
* dword bits for ORing into the high dword. We choose umin because
* it nicely translates to a single fixed-latency instruction on
* everything except Volta.
*/
src_hi = nir_ior(b, src_hi, nir_umin_imm(b, src_lo, 1));
src_lo = nir_imm_int(b, 0);
tmp = nir_f2f32(b, nir_pack_64_2x32_split(b, src_lo, src_hi));
} else {
/* For round-up, round-down, and round-towards-zero, the rounding
* accumulates properly as long as we use the same rounding mode for
* both operations.
*/
tmp = nir_convert_alu_types(b, 32, src,
.src_type = nir_type_float64,
.dest_type = tmp_type,
.rounding_mode = rounding_mode,
.saturate = false);
}
} else {
/* This is an up-convert or a convert to integer, in which case we
* always round towards zero.
*/
tmp = nir_type_convert(b, src, src_type, tmp_type,
nir_rounding_mode_undef);
}
nir_def *res = nir_type_convert(b, tmp, tmp_type, dst_full_type,
op_rounding_mode(alu->op));
rounding_mode);
nir_def_replace(&alu->def, res);
return true;