nir_lower_fp16_casts: Allow opting out of lowering certain rounding modes

Reviewed-by: Giancarlo Devich <gdevich@microsoft.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21029>
This commit is contained in:
Jesse Natalie 2023-01-30 11:10:03 -08:00 committed by Marge Bot
parent c0c2b60f1d
commit 25ee07373c
5 changed files with 64 additions and 36 deletions

View file

@@ -5567,8 +5567,14 @@ bool nir_shader_uses_view_index(nir_shader *shader);
bool nir_can_lower_multiview(nir_shader *shader);
bool nir_lower_multiview(nir_shader *shader, uint32_t view_mask);
bool nir_lower_fp16_casts(nir_shader *shader);
/* Bitmask selecting which fp32 -> fp16 conversion rounding modes
 * nir_lower_fp16_casts() should lower to open-coded arithmetic; modes not
 * selected are left for the backend to handle natively. One flag per
 * nir_rounding_mode handled by the pass (rtz = toward zero, rtne = to
 * nearest even, ru = toward +inf, rd = toward -inf). */
typedef enum {
nir_lower_fp16_rtz = (1 << 0),
nir_lower_fp16_rtne = (1 << 1),
nir_lower_fp16_ru = (1 << 2),
nir_lower_fp16_rd = (1 << 3),
nir_lower_fp16_all = 0xf, /* all four flags above */
} nir_lower_fp16_cast_options;
bool nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options);
bool nir_normalize_cubemap_coords(nir_shader *shader);
bool nir_shader_supports_implicit_lod(nir_shader *shader);

View file

@@ -44,26 +44,6 @@
*
* Version 2.1.0
*/
/* Instruction filter for nir_shader_lower_instructions(): selects the
 * fp32 -> fp16 conversions this pass lowers, i.e. the f2f16 family of ALU
 * opcodes and convert_alu_types intrinsics with a float16 destination. */
static bool
lower_fp16_casts_filter(const nir_instr *instr, const void *data)
{
if (instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
return intrin->intrinsic == nir_intrinsic_convert_alu_types &&
nir_intrinsic_dest_type(intrin) == nir_type_float16;
}

if (instr->type != nir_instr_type_alu)
return false;

nir_alu_instr *alu = nir_instr_as_alu(instr);
return alu->op == nir_op_f2f16 ||
alu->op == nir_op_f2f16_rtne ||
alu->op == nir_op_f2f16_rtz;
}
static nir_ssa_def *
half_rounded(nir_builder *b, nir_ssa_def *value, nir_ssa_def *guard, nir_ssa_def *sticky,
@@ -188,12 +168,12 @@ float_to_half_impl(nir_builder *b, nir_ssa_def *src, nir_rounding_mode mode)
return nir_u2u16(b, nir_ior(b, fp16, nir_ushr(b, sign, nir_imm_int(b, 16))));
}
static nir_ssa_def *
static bool
lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
{
nir_ssa_def *src, *dst;
uint8_t *swizzle = NULL;
nir_rounding_mode mode = nir_rounding_mode_rtne;
nir_rounding_mode mode = nir_rounding_mode_undef;
if (instr->type == nir_instr_type_alu) {
nir_alu_instr *alu = nir_instr_as_alu(instr);
@@ -202,21 +182,62 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
dst = &alu->dest.dest.ssa;
switch (alu->op) {
case nir_op_f2f16:
if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16)
mode = nir_rounding_mode_rtz;
else if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16)
mode = nir_rounding_mode_rtne;
break;
case nir_op_f2f16_rtne:
mode = nir_rounding_mode_rtne;
break;
case nir_op_f2f16_rtz:
mode = nir_rounding_mode_rtz;
break;
default: unreachable("Should've been filtered");
default:
return false;
}
} else {
} else if (instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
assert(nir_intrinsic_src_type(intrin) == nir_type_float32);
if (intrin->intrinsic != nir_intrinsic_convert_alu_types ||
nir_intrinsic_dest_type(intrin) != nir_type_float16)
return false;
src = intrin->src[0].ssa;
dst = &intrin->dest.ssa;
mode = nir_intrinsic_rounding_mode(intrin);
} else {
return false;
}
nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
nir_lower_fp16_cast_options req_option = 0;
switch (mode) {
case nir_rounding_mode_rtz:
req_option = nir_lower_fp16_rtz;
break;
case nir_rounding_mode_rtne:
req_option = nir_lower_fp16_rtne;
break;
case nir_rounding_mode_ru:
req_option = nir_lower_fp16_ru;
break;
case nir_rounding_mode_rd:
req_option = nir_lower_fp16_rd;
break;
case nir_rounding_mode_undef:
if (options == nir_lower_fp16_all) {
/* Pick one arbitrarily for lowering */
mode = nir_rounding_mode_rtne;
req_option = nir_lower_fp16_rtne;
}
/* Otherwise assume the backend can handle f2f16 with undef rounding */
break;
default:
unreachable("Invalid rounding mode");
}
if (!(options & req_option))
return false;
b->cursor = nir_before_instr(instr);
nir_ssa_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };
for (unsigned i = 0; i < dst->num_components; i++) {
@@ -224,14 +245,16 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
rets[i] = float_to_half_impl(b, comp, mode);
}
return nir_vec(b, rets, dst->num_components);
nir_ssa_def *new_val = nir_vec(b, rets, dst->num_components);
nir_ssa_def_rewrite_uses(dst, new_val);
return true;
}
bool
nir_lower_fp16_casts(nir_shader *shader)
nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options)
{
return nir_shader_lower_instructions(shader,
lower_fp16_casts_filter,
lower_fp16_cast_impl,
NULL);
return nir_shader_instructions_pass(shader,
lower_fp16_cast_impl,
nir_metadata_none,
&options);
}

View file

@@ -2775,7 +2775,7 @@ lp_build_opt_nir(struct nir_shader *nir)
NIR_PASS_V(nir, nir_lower_frexp);
NIR_PASS_V(nir, nir_lower_flrp, 16|32|64, true);
NIR_PASS_V(nir, nir_lower_fp16_casts);
NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
do {
progress = false;
NIR_PASS(progress, nir, nir_opt_constant_folding);

View file

@@ -1085,7 +1085,7 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
NIR_PASS_V(nir, nir_lower_fp16_casts);
NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);
// Convert pack to pack_split

View file

@@ -42,7 +42,6 @@ bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader);
bool dxil_nir_lower_deref_ssbo(nir_shader *shader);
bool dxil_nir_opt_alu_deref_srcs(nir_shader *shader);
bool dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size);
bool dxil_nir_lower_fp16_casts(nir_shader *shader);
bool dxil_nir_split_clip_cull_distance(nir_shader *shader);
bool dxil_nir_lower_double_math(nir_shader *shader);
bool dxil_nir_lower_system_values_to_zero(nir_shader *shader,