From 25ee07373c3123e5f10dff74e7d90311b3c7b60f Mon Sep 17 00:00:00 2001 From: Jesse Natalie Date: Mon, 30 Jan 2023 11:10:03 -0800 Subject: [PATCH] nir_lower_fp16_casts: Allow opting out of lowering certain rounding modes Reviewed-by: Giancarlo Devich Reviewed-by: Faith Ekstrand Part-of: --- src/compiler/nir/nir.h | 10 ++- src/compiler/nir/nir_lower_fp16_conv.c | 85 ++++++++++++++-------- src/gallium/auxiliary/gallivm/lp_bld_nir.c | 2 +- src/microsoft/clc/clc_compiler.c | 2 +- src/microsoft/compiler/dxil_nir.h | 1 - 5 files changed, 64 insertions(+), 36 deletions(-) diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index f35b72d88f3..d410f118b12 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -5567,8 +5567,14 @@ bool nir_shader_uses_view_index(nir_shader *shader); bool nir_can_lower_multiview(nir_shader *shader); bool nir_lower_multiview(nir_shader *shader, uint32_t view_mask); - -bool nir_lower_fp16_casts(nir_shader *shader); +typedef enum { + nir_lower_fp16_rtz = (1 << 0), + nir_lower_fp16_rtne = (1 << 1), + nir_lower_fp16_ru = (1 << 2), + nir_lower_fp16_rd = (1 << 3), + nir_lower_fp16_all = 0xf, +} nir_lower_fp16_cast_options; +bool nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options); bool nir_normalize_cubemap_coords(nir_shader *shader); bool nir_shader_supports_implicit_lod(nir_shader *shader); diff --git a/src/compiler/nir/nir_lower_fp16_conv.c b/src/compiler/nir/nir_lower_fp16_conv.c index 2f6862731a3..194dc9b820a 100644 --- a/src/compiler/nir/nir_lower_fp16_conv.c +++ b/src/compiler/nir/nir_lower_fp16_conv.c @@ -44,26 +44,6 @@ * * Version 2.1.0 */ -static bool -lower_fp16_casts_filter(const nir_instr *instr, const void *data) -{ - if (instr->type == nir_instr_type_alu) { - nir_alu_instr *alu = nir_instr_as_alu(instr); - switch (alu->op) { - case nir_op_f2f16: - case nir_op_f2f16_rtne: - case nir_op_f2f16_rtz: - return true; - default: - return false; - } - } else if (instr->type == nir_instr_type_intrinsic) { - nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - return intrin->intrinsic == nir_intrinsic_convert_alu_types && - nir_intrinsic_dest_type(intrin) == nir_type_float16; - } - return false; -} static nir_ssa_def * half_rounded(nir_builder *b, nir_ssa_def *value, nir_ssa_def *guard, nir_ssa_def *sticky, @@ -188,12 +168,12 @@ float_to_half_impl(nir_builder *b, nir_ssa_def *src, nir_rounding_mode mode) return nir_u2u16(b, nir_ior(b, fp16, nir_ushr(b, sign, nir_imm_int(b, 16)))); } -static nir_ssa_def * +static bool lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) { nir_ssa_def *src, *dst; uint8_t *swizzle = NULL; - nir_rounding_mode mode = nir_rounding_mode_rtne; + nir_rounding_mode mode = nir_rounding_mode_undef; if (instr->type == nir_instr_type_alu) { nir_alu_instr *alu = nir_instr_as_alu(instr); @@ -202,21 +182,62 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) dst = &alu->dest.dest.ssa; switch (alu->op) { case nir_op_f2f16: + if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16) + mode = nir_rounding_mode_rtz; + else if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16) + mode = nir_rounding_mode_rtne; + break; case nir_op_f2f16_rtne: + mode = nir_rounding_mode_rtne; break; case nir_op_f2f16_rtz: mode = nir_rounding_mode_rtz; break; - default: unreachable("Should've been filtered"); + default: + return false; } - } else { + } else if (instr->type == nir_instr_type_intrinsic) { nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr); - assert(nir_intrinsic_src_type(intrin) == nir_type_float32); + if (intrin->intrinsic != nir_intrinsic_convert_alu_types || + nir_intrinsic_dest_type(intrin) != nir_type_float16) + return false; src = intrin->src[0].ssa; dst = &intrin->dest.ssa; mode = nir_intrinsic_rounding_mode(intrin); + } else { + return false; } + nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data; + nir_lower_fp16_cast_options req_option = 0; + switch (mode) { + case nir_rounding_mode_rtz: + req_option = nir_lower_fp16_rtz; + break; + case nir_rounding_mode_rtne: + req_option = nir_lower_fp16_rtne; + break; + case nir_rounding_mode_ru: + req_option = nir_lower_fp16_ru; + break; + case nir_rounding_mode_rd: + req_option = nir_lower_fp16_rd; + break; + case nir_rounding_mode_undef: + if (options == nir_lower_fp16_all) { + /* Pick one arbitrarily for lowering */ + mode = nir_rounding_mode_rtne; + req_option = nir_lower_fp16_rtne; + } + /* Otherwise assume the backend can handle f2f16 with undef rounding */ + break; + default: + unreachable("Invalid rounding mode"); + } + if (!(options & req_option)) + return false; + + b->cursor = nir_before_instr(instr); nir_ssa_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL }; for (unsigned i = 0; i < dst->num_components; i++) { @@ -224,14 +245,16 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data) rets[i] = float_to_half_impl(b, comp, mode); } - return nir_vec(b, rets, dst->num_components); + nir_ssa_def *new_val = nir_vec(b, rets, dst->num_components); + nir_ssa_def_rewrite_uses(dst, new_val); + return true; } bool -nir_lower_fp16_casts(nir_shader *shader) +nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options) { - return nir_shader_lower_instructions(shader, - lower_fp16_casts_filter, - lower_fp16_cast_impl, - NULL); + return nir_shader_instructions_pass(shader, + lower_fp16_cast_impl, + nir_metadata_none, + &options); } diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c b/src/gallium/auxiliary/gallivm/lp_bld_nir.c index 3a127c25abd..8602491728d 100644 --- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c +++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c @@ -2775,7 +2775,7 @@ lp_build_opt_nir(struct nir_shader *nir) NIR_PASS_V(nir, nir_lower_frexp); NIR_PASS_V(nir, nir_lower_flrp, 16|32|64, true); - NIR_PASS_V(nir, nir_lower_fp16_casts); + NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all); do { progress = false; NIR_PASS(progress, nir, nir_opt_constant_folding); diff --git a/src/microsoft/clc/clc_compiler.c b/src/microsoft/clc/clc_compiler.c index 9400f393426..01ac14df2c6 100644 --- a/src/microsoft/clc/clc_compiler.c +++ b/src/microsoft/clc/clc_compiler.c @@ -1085,7 +1085,7 @@ clc_spirv_to_dxil(struct clc_libclc *lib, NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil); NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs); NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil); - NIR_PASS_V(nir, nir_lower_fp16_casts); + NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all); NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL); // Convert pack to pack_split diff --git a/src/microsoft/compiler/dxil_nir.h b/src/microsoft/compiler/dxil_nir.h index adeea6d3fad..0c0a85e0854 100644 --- a/src/microsoft/compiler/dxil_nir.h +++ b/src/microsoft/compiler/dxil_nir.h @@ -42,7 +42,6 @@ bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader); bool dxil_nir_lower_deref_ssbo(nir_shader *shader); bool dxil_nir_opt_alu_deref_srcs(nir_shader *shader); bool dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size); -bool dxil_nir_lower_fp16_casts(nir_shader *shader); bool dxil_nir_split_clip_cull_distance(nir_shader *shader); bool dxil_nir_lower_double_math(nir_shader *shader); bool dxil_nir_lower_system_values_to_zero(nir_shader *shader,