diff --git a/src/compiler/Makefile.sources b/src/compiler/Makefile.sources index 82a69e6c8af..73f0cb2a85a 100644 --- a/src/compiler/Makefile.sources +++ b/src/compiler/Makefile.sources @@ -209,6 +209,7 @@ NIR_FILES = \ nir/nir_builder.h \ nir/nir_builtin_builder.c \ nir/nir_builtin_builder.h \ + nir/nir_conversion_builder.h \ nir/nir_clone.c \ nir/nir_constant_expressions.h \ nir/nir_control_flow.c \ diff --git a/src/compiler/nir/meson.build b/src/compiler/nir/meson.build index 1600a02b159..657d4b44cd4 100644 --- a/src/compiler/nir/meson.build +++ b/src/compiler/nir/meson.build @@ -89,6 +89,7 @@ files_libnir = files( 'nir_builder.h', 'nir_builtin_builder.c', 'nir_builtin_builder.h', + 'nir_conversion_builder.h', 'nir_clone.c', 'nir_constant_expressions.h', 'nir_control_flow.c', diff --git a/src/compiler/nir/nir_conversion_builder.h b/src/compiler/nir/nir_conversion_builder.h new file mode 100644 index 00000000000..78e41bfb690 --- /dev/null +++ b/src/compiler/nir/nir_conversion_builder.h @@ -0,0 +1,482 @@ +/* + * Copyright © 2020 Collabora Ltd. + * + * Permission is hereby granted, free of charge, to any person obtaining a + * copy of this software and associated documentation files (the "Software"), + * to deal in the Software without restriction, including without limitation + * the rights to use, copy, modify, merge, publish, distribute, sublicense, + * and/or sell copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice (including the next + * paragraph) shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL + * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS + * IN THE SOFTWARE. + */ + +#ifndef NIR_CONVERSION_BUILDER_H +#define NIR_CONVERSION_BUILDER_H + +#include "util/u_math.h" +#include "nir_builder.h" +#include "nir_builtin_builder.h" + +#ifdef __cplusplus +extern "C" { +#endif + /** + * Applies the rounding mode to a float source by emitting nir_fceil (ru), + * nir_ffloor (rd) or nir_fround_even (rtne). The rtz and undef modes are + * not handled here and end in unreachable(). + */ +static inline nir_ssa_def * +nir_round_float_to_int(nir_builder *b, nir_ssa_def *src, + nir_rounding_mode round) +{ + switch (round) { + case nir_rounding_mode_ru: + return nir_fceil(b, src); + + case nir_rounding_mode_rd: + return nir_ffloor(b, src); + + case nir_rounding_mode_rtne: + return nir_fround_even(b, src); + + case nir_rounding_mode_undef: + case nir_rounding_mode_rtz: + break; + } + unreachable("unexpected rounding mode"); +} + /** + * Converts a float down to dest_bit_size with directed rounding (ru, rd or + * rtz). The source is returned unchanged when dest_bit_size is larger than + * its own bit size. The rtne and undef modes are not handled here and end + * in unreachable(). + */ +static inline nir_ssa_def * +nir_round_float_to_float(nir_builder *b, nir_ssa_def *src, + unsigned dest_bit_size, + nir_rounding_mode round) +{ + unsigned src_bit_size = src->bit_size; + if (dest_bit_size > src_bit_size) + return src; /* No rounding is needed for an up-convert */ + + nir_op low_conv = nir_type_conversion_op(nir_type_float | src_bit_size, + nir_type_float | dest_bit_size, + nir_rounding_mode_undef); + nir_op high_conv = nir_type_conversion_op(nir_type_float | dest_bit_size, + nir_type_float | src_bit_size, + nir_rounding_mode_undef); + + switch (round) { + case nir_rounding_mode_ru: { + /* If lower-precision conversion results in a lower value, push it + * up one ULP.
*/ + nir_ssa_def *lower_prec = + nir_build_alu(b, low_conv, src, NULL, NULL, NULL); + nir_ssa_def *roundtrip = + nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL); + nir_ssa_def *cmp = nir_flt(b, roundtrip, src); + nir_ssa_def *inf = nir_imm_floatN_t(b, INFINITY, dest_bit_size); + return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, inf), lower_prec); + } + case nir_rounding_mode_rd: { + /* If lower-precision conversion results in a higher value, push it + * down one ULP. */ + nir_ssa_def *lower_prec = + nir_build_alu(b, low_conv, src, NULL, NULL, NULL); + nir_ssa_def *roundtrip = + nir_build_alu(b, high_conv, lower_prec, NULL, NULL, NULL); + nir_ssa_def *cmp = nir_flt(b, src, roundtrip); + nir_ssa_def *neg_inf = nir_imm_floatN_t(b, -INFINITY, dest_bit_size); + return nir_bcsel(b, cmp, nir_nextafter(b, lower_prec, neg_inf), lower_prec); + } + case nir_rounding_mode_rtz: /* ru for negative values, rd otherwise */ + return nir_bcsel(b, nir_flt(b, src, nir_imm_zero(b, 1, src->bit_size)), + nir_round_float_to_float(b, src, dest_bit_size, + nir_rounding_mode_ru), + nir_round_float_to_float(b, src, dest_bit_size, + nir_rounding_mode_rd)); + case nir_rounding_mode_rtne: + case nir_rounding_mode_undef: + break; + } + unreachable("unexpected rounding mode"); +} + /** + * Rounds an integer, while it is still an integer, so that the low bits + * which lie more than mantissa_bits below its most significant set bit are + * cleared according to the rounding mode (ru, rd or rtz). The result keeps + * the source bit size. Sources narrower than the mantissa are returned + * unchanged; otherwise rtne and undef end in unreachable(). + */ +static inline nir_ssa_def * +nir_round_int_to_float(nir_builder *b, nir_ssa_def *src, + nir_alu_type src_type, + unsigned dest_bit_size, + nir_rounding_mode round) +{ + /* We only care whether or not it's signed */ + src_type = nir_alu_type_get_base_type(src_type); + + unsigned mantissa_bits; + switch (dest_bit_size) { + case 16: + mantissa_bits = 10; + break; + case 32: + mantissa_bits = 23; + break; + case 64: + mantissa_bits = 52; + break; + default: unreachable("Unsupported bit size"); + } + + if (src->bit_size < mantissa_bits) + return src; /* Every source bit fits in the mantissa */ + + if (src_type == nir_type_int) { /* signed: round the magnitude as unsigned, then restore the sign */ + nir_ssa_def *sign = + nir_i2b1(b, nir_ishr(b, src, nir_imm_int(b, src->bit_size - 1))); + nir_ssa_def *abs = nir_iabs(b, src); + nir_ssa_def *positive_rounded =
nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, round); + nir_ssa_def *max_positive = + nir_imm_intN_t(b, (1ull << (src->bit_size - 1)) - 1, src->bit_size); + switch (round) { + case nir_rounding_mode_rtz: + return nir_bcsel(b, sign, nir_ineg(b, positive_rounded), + positive_rounded); + break; + case nir_rounding_mode_ru: /* rounding a negative value up rounds its magnitude down */ + return nir_bcsel(b, sign, + nir_ineg(b, nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_rd)), + nir_umin(b, positive_rounded, max_positive)); + break; + case nir_rounding_mode_rd: /* rounding a negative value down rounds its magnitude up */ + return nir_bcsel(b, sign, + nir_ineg(b, + nir_umin(b, max_positive, + nir_round_int_to_float(b, abs, nir_type_uint, dest_bit_size, nir_rounding_mode_ru))), + positive_rounded); + case nir_rounding_mode_rtne: + case nir_rounding_mode_undef: + break; + } + unreachable("unexpected rounding mode"); + } else { /* unsigned: mask off the bits lying more than mantissa_bits below the top set bit */ + nir_ssa_def *mantissa_bit_size = nir_imm_int(b, mantissa_bits); + nir_ssa_def *msb = nir_imax(b, nir_ufind_msb(b, src), mantissa_bit_size); + nir_ssa_def *bits_to_lose = nir_isub(b, msb, mantissa_bit_size); + nir_ssa_def *one = nir_imm_intN_t(b, 1, src->bit_size); + nir_ssa_def *adjust = nir_ishl(b, one, bits_to_lose); + nir_ssa_def *mask = nir_inot(b, nir_isub(b, adjust, one)); + nir_ssa_def *truncated = nir_iand(b, src, mask); + switch (round) { + case nir_rounding_mode_rtz: + case nir_rounding_mode_rd: + return truncated; + break; + case nir_rounding_mode_ru: /* already exact values stay; others step up by one unit of the kept precision */ + return nir_bcsel(b, nir_ieq(b, src, truncated), + src, nir_uadd_sat(b, truncated, adjust)); + case nir_rounding_mode_rtne: + case nir_rounding_mode_undef: + break; + } + unreachable("unexpected rounding mode"); + } +} + +/** Returns true if the representable range of a contains the representable + * range of b.
+ */ +static inline bool +nir_alu_type_range_contains_type_range(nir_alu_type a, nir_alu_type b) +{ + /* Split types from bit sizes */ + nir_alu_type a_base_type = nir_alu_type_get_base_type(a); + nir_alu_type b_base_type = nir_alu_type_get_base_type(b); + unsigned a_bit_size = nir_alu_type_get_type_size(a); + unsigned b_bit_size = nir_alu_type_get_type_size(b); + + /* This requires sized types */ + assert(a_bit_size > 0 && b_bit_size > 0); + /* Same base type with at least as many bits */ + if (a_base_type == b_base_type && a_bit_size >= b_bit_size) + return true; + /* A strictly wider signed int holds every value of an unsigned int */ + if (a_base_type == nir_type_int && b_base_type == nir_type_uint && + a_bit_size > b_bit_size) + return true; + + /* 16-bit floats fit in signed integers of 32 bits or more */ + if (a_base_type == nir_type_int && a_bit_size >= 32 && + b == nir_type_float16) + return true; + + /* All signed or unsigned ints can fit in a 32-bit float or above. Any + * 8-bit non-float type can fit in a float16. + */ + if (a_base_type == nir_type_float && b_base_type != nir_type_float && + (a_bit_size >= 32 || b_bit_size == 8)) + return true; + + return false; +} + +/** + * Clamp the source value into the widest representable range of the + * destination type with cmp + bcsel. The source is returned unchanged when + * the destination range already contains the source range. src_type may be + * unsized; dest_type must be sized.
+ */ +static inline nir_ssa_def * +nir_clamp_to_type_range(nir_builder *b, + nir_ssa_def *src, nir_alu_type src_type, + nir_alu_type dest_type) +{ + assert(nir_alu_type_get_type_size(src_type) == 0 || + nir_alu_type_get_type_size(src_type) == src->bit_size); + src_type |= src->bit_size; /* src_type is a sized type from here on */ + if (nir_alu_type_range_contains_type_range(dest_type, src_type)) + return src; + + /* Split types from bit sizes */ + nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); + nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); + unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); + assert(dest_bit_size != 0); + + /* limits of the destination type, expressed in the source type; a limit + * that stays NULL gets no comparison and no select further down + */ + nir_ssa_def *low = NULL, *high = NULL; + switch (dest_base_type) { + case nir_type_int: { + int64_t ilow, ihigh; + if (dest_bit_size == 64) { /* 1ll << 63 would overflow */ + ilow = INT64_MIN; + ihigh = INT64_MAX; + } else { + ilow = -(1ll << (dest_bit_size - 1)); + ihigh = (1ll << (dest_bit_size - 1)) - 1; + } + + if (src_base_type == nir_type_int) { + low = nir_imm_intN_t(b, ilow, src->bit_size); + high = nir_imm_intN_t(b, ihigh, src->bit_size); + } else if (src_base_type == nir_type_uint) { /* an unsigned source only gets an upper limit */ + assert(src->bit_size >= dest_bit_size); + high = nir_imm_intN_t(b, ihigh, src->bit_size); + } else { + low = nir_imm_floatN_t(b, ilow, src->bit_size); + high = nir_imm_floatN_t(b, ihigh, src->bit_size); + } + break; + } + case nir_type_uint: { + uint64_t uhigh = dest_bit_size == 64 ?
+ ~0ull : (1ull << dest_bit_size) - 1; + if (src_base_type != nir_type_float) { + low = nir_imm_intN_t(b, 0, src->bit_size); + if (src_base_type == nir_type_uint || src->bit_size > dest_bit_size) + high = nir_imm_intN_t(b, uhigh, src->bit_size); + } else { + low = nir_imm_floatN_t(b, 0.0f, src->bit_size); + high = nir_imm_floatN_t(b, uhigh, src->bit_size); + } + break; + } + case nir_type_float: { + double flow, fhigh; + switch (dest_bit_size) { + case 16: /* presumably the largest finite 16-bit float -- confirm */ + flow = -65504.0f; + fhigh = 65504.0f; + break; + case 32: + flow = -FLT_MAX; + fhigh = FLT_MAX; + break; + case 64: + flow = -DBL_MAX; + fhigh = DBL_MAX; + break; + default: + unreachable("Unhandled bit size"); + } + + switch (src_base_type) { + case nir_type_int: { + int64_t src_ilow, src_ihigh; + if (src->bit_size == 64) { + src_ilow = INT64_MIN; + src_ihigh = INT64_MAX; + } else { + src_ilow = -(1ll << (src->bit_size - 1)); + src_ihigh = (1ll << (src->bit_size - 1)) - 1; + } /* Only add a limit that the integer source range goes beyond */ + if (src_ilow < flow) + low = nir_imm_intN_t(b, flow, src->bit_size); + if (src_ihigh > fhigh) + high = nir_imm_intN_t(b, fhigh, src->bit_size); + break; + } + case nir_type_uint: { + uint64_t src_uhigh = src->bit_size == 64 ? + ~0ull : (1ull << src->bit_size) - 1; + if (src_uhigh > fhigh) + high = nir_imm_intN_t(b, fhigh, src->bit_size); + break; + } + case nir_type_float: + low = nir_imm_floatN_t(b, flow, src->bit_size); + high = nir_imm_floatN_t(b, fhigh, src->bit_size); + break; + default: + unreachable("Clamping from unknown type"); + } + break; + } + default: + unreachable("clamping to unknown type"); + break; + } + /* Compare with the source type's own comparison op; a condition is only + * built for a limit that was set above + */ + nir_ssa_def *low_cond = NULL, *high_cond = NULL; + switch (src_base_type) { + case nir_type_int: + low_cond = low ? nir_ilt(b, src, low) : NULL; + high_cond = high ? nir_ilt(b, high, src) : NULL; + break; + case nir_type_uint: + low_cond = low ? nir_ult(b, src, low) : NULL; + high_cond = high ? nir_ult(b, high, src) : NULL; + break; + case nir_type_float: + low_cond = low ?
nir_flt(b, src, low) : NULL; + high_cond = high ? nir_flt(b, high, src) : NULL; + break; + default: + unreachable("clamping from unknown type"); + } + /* Replace the source by a limit wherever it lies outside that limit */ + nir_ssa_def *res = src; + if (low_cond) + res = nir_bcsel(b, low_cond, low, res); + if (high_cond) + res = nir_bcsel(b, high_cond, high, res); + + return res; +} + /** + * Returns nir_rounding_mode_undef when the requested rounding mode has no + * effect on a plain conversion from src_type to dest_type, or is what such + * a conversion does anyway; otherwise returns the mode unchanged. Both + * types must be sized. + */ +static inline nir_rounding_mode +nir_simplify_conversion_rounding(nir_alu_type src_type, + nir_alu_type dest_type, + nir_rounding_mode rounding) +{ + nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); + nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); + unsigned src_bit_size = nir_alu_type_get_type_size(src_type); + unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); + assert(src_bit_size > 0 && dest_bit_size > 0); + + if (rounding == nir_rounding_mode_undef) + return rounding; + + /* Pure integer conversion doesn't have any rounding */ + if (src_base_type != nir_type_float && + dest_base_type != nir_type_float) + return nir_rounding_mode_undef; + + /* Float up-casts (and same-size casts) don't round */ + if (src_base_type == nir_type_float && + dest_base_type == nir_type_float && + dest_bit_size >= src_bit_size) + return nir_rounding_mode_undef; + + /* Regular float to int conversions are RTZ */ + if (src_base_type == nir_type_float && + dest_base_type != nir_type_float && + rounding == nir_rounding_mode_rtz) + return nir_rounding_mode_undef; + + /* The CL spec requires regular conversions to float to be RTNE */ + if (dest_base_type == nir_type_float && + rounding == nir_rounding_mode_rtne) + return nir_rounding_mode_undef; + + /* Couldn't simplify */ + return rounding; +} + /** + * Converts src from src_type to dest_type with the given rounding mode. If + * clamp is set, the source is first clamped to the destination range. + * src_type may be unsized, in which case it takes the bit size of src; + * dest_type must be sized. + */ +static inline nir_ssa_def * +nir_convert_with_rounding(nir_builder *b, + nir_ssa_def *src, nir_alu_type src_type, + nir_alu_type dest_type, + nir_rounding_mode round, + bool clamp) +{ + /* Some stuff wants sized types */ + assert(nir_alu_type_get_type_size(src_type) == 0 || + nir_alu_type_get_type_size(src_type) == src->bit_size); + src_type |=
src->bit_size; + + /* Split types from bit sizes */ + nir_alu_type src_base_type = nir_alu_type_get_base_type(src_type); + nir_alu_type dest_base_type = nir_alu_type_get_base_type(dest_type); + unsigned dest_bit_size = nir_alu_type_get_type_size(dest_type); + + /* Try to simplify the conversion if we can */ + clamp = clamp && + !nir_alu_type_range_contains_type_range(dest_type, src_type); + round = nir_simplify_conversion_rounding(src_type, dest_type, round); + + /* + * If we don't care about rounding and clamping, we can just use NIR's + * built-in ops. There is also a special case for SPIR-V in shaders, where + * f32/f64 -> f16 conversions can have one of two rounding modes applied, + * which NIR has built-in opcodes for. + * + * NOTE(review): the check below only matches a float32 source, not + * float64, and rtne to a float destination has already been simplified + * to undef above, so only rtz reaches the second branch. + * + * For the rest, we have our own implementation of rounding and clamping. + */ + bool trivial_convert; + if (!clamp && round == nir_rounding_mode_undef) { + trivial_convert = true; + } else if (!clamp && src_type == nir_type_float32 && + dest_type == nir_type_float16 && + (round == nir_rounding_mode_rtne || + round == nir_rounding_mode_rtz)) { + trivial_convert = true; + } else { + trivial_convert = false; + } + if (trivial_convert) { + nir_op op = nir_type_conversion_op(src_type, dest_type, round); + return nir_build_alu(b, op, src, NULL, NULL, NULL); + } + + nir_ssa_def *dest = src; + + /* clamp the source into the destination range, before converting */ + if (clamp) + dest = nir_clamp_to_type_range(b, dest, src_type, dest_type); + + /* round with selected rounding mode; trivial_convert is always false + * here because the trivial case returned above + */ + if (!trivial_convert && round != nir_rounding_mode_undef) { + if (src_base_type == nir_type_float) { + if (dest_base_type == nir_type_float) { + dest = nir_round_float_to_float(b, dest, dest_bit_size, round); + } else { + dest = nir_round_float_to_int(b, dest, round); + } + } else { + dest = nir_round_int_to_float(b, dest, src_type, dest_bit_size, round); + } + + round = nir_rounding_mode_undef; /* rounding was applied by hand above */ + } + + /* now we can convert the value */ + nir_op op = nir_type_conversion_op(src_type, dest_type, round);
+ return nir_build_alu(b, op, dest, NULL, NULL, NULL); +} + +#ifdef __cplusplus +} +#endif + +#endif /* NIR_CONVERSION_BUILDER_H */