mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-01-06 02:20:11 +01:00
nir_lower_fp16_casts: Allow opting out of lowering certain rounding modes
Reviewed-by: Giancarlo Devich <gdevich@microsoft.com>
Reviewed-by: Faith Ekstrand <faith.ekstrand@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/21029>
This commit is contained in:
parent
c0c2b60f1d
commit
25ee07373c
5 changed files with 64 additions and 36 deletions
|
|
@ -5567,8 +5567,14 @@ bool nir_shader_uses_view_index(nir_shader *shader);
|
|||
bool nir_can_lower_multiview(nir_shader *shader);
|
||||
bool nir_lower_multiview(nir_shader *shader, uint32_t view_mask);
|
||||
|
||||
|
||||
bool nir_lower_fp16_casts(nir_shader *shader);
|
||||
/**
 * Bitmask passed to nir_lower_fp16_casts() to choose which float16 casts
 * get lowered, keyed by the rounding mode of the cast.  A cast is only
 * lowered when the bit matching its rounding mode is set; otherwise the
 * instruction is left untouched.
 */
typedef enum {
|
||||
/* Lower casts whose rounding mode is nir_rounding_mode_rtz. */
nir_lower_fp16_rtz = (1 << 0),
|
||||
/* Lower casts whose rounding mode is nir_rounding_mode_rtne. */
nir_lower_fp16_rtne = (1 << 1),
|
||||
/* Lower casts whose rounding mode is nir_rounding_mode_ru. */
nir_lower_fp16_ru = (1 << 2),
|
||||
/* Lower casts whose rounding mode is nir_rounding_mode_rd. */
nir_lower_fp16_rd = (1 << 3),
|
||||
/* All four bits above.  Casts with nir_rounding_mode_undef are lowered
 * (as rtne) only when the options are exactly this value.
 */
nir_lower_fp16_all = 0xf,
|
||||
} nir_lower_fp16_cast_options;
|
||||
bool nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options);
|
||||
bool nir_normalize_cubemap_coords(nir_shader *shader);
|
||||
|
||||
bool nir_shader_supports_implicit_lod(nir_shader *shader);
|
||||
|
|
|
|||
|
|
@ -44,26 +44,6 @@
|
|||
*
|
||||
* Version 2.1.0
|
||||
*/
|
||||
/**
 * Instruction filter: returns true for instructions that convert to
 * float16 and false for everything else.
 *
 * Matches the ALU ops f2f16, f2f16_rtne and f2f16_rtz, and
 * convert_alu_types intrinsics whose destination type is float16.
 * The data argument is not used.
 */
static bool
|
||||
lower_fp16_casts_filter(const nir_instr *instr, const void *data)
|
||||
{
|
||||
if (instr->type == nir_instr_type_alu) {
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
/* Only the three f2f16 variants are accepted. */
switch (alu->op) {
|
||||
case nir_op_f2f16:
|
||||
case nir_op_f2f16_rtne:
|
||||
case nir_op_f2f16_rtz:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} else if (instr->type == nir_instr_type_intrinsic) {
|
||||
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
|
||||
/* Generic conversions count only when they produce float16. */
return intrin->intrinsic == nir_intrinsic_convert_alu_types &&
|
||||
nir_intrinsic_dest_type(intrin) == nir_type_float16;
|
||||
}
|
||||
/* Any other instruction type is never a float16 cast. */
return false;
|
||||
}
|
||||
|
||||
static nir_ssa_def *
|
||||
half_rounded(nir_builder *b, nir_ssa_def *value, nir_ssa_def *guard, nir_ssa_def *sticky,
|
||||
|
|
@ -188,12 +168,12 @@ float_to_half_impl(nir_builder *b, nir_ssa_def *src, nir_rounding_mode mode)
|
|||
return nir_u2u16(b, nir_ior(b, fp16, nir_ushr(b, sign, nir_imm_int(b, 16))));
|
||||
}
|
||||
|
||||
static nir_ssa_def *
|
||||
static bool
|
||||
lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
|
||||
{
|
||||
nir_ssa_def *src, *dst;
|
||||
uint8_t *swizzle = NULL;
|
||||
nir_rounding_mode mode = nir_rounding_mode_rtne;
|
||||
nir_rounding_mode mode = nir_rounding_mode_undef;
|
||||
|
||||
if (instr->type == nir_instr_type_alu) {
|
||||
nir_alu_instr *alu = nir_instr_as_alu(instr);
|
||||
|
|
@ -202,21 +182,62 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
|
|||
dst = &alu->dest.dest.ssa;
|
||||
switch (alu->op) {
|
||||
case nir_op_f2f16:
|
||||
if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTZ_FP16)
|
||||
mode = nir_rounding_mode_rtz;
|
||||
else if (b->shader->info.float_controls_execution_mode & FLOAT_CONTROLS_ROUNDING_MODE_RTE_FP16)
|
||||
mode = nir_rounding_mode_rtne;
|
||||
break;
|
||||
case nir_op_f2f16_rtne:
|
||||
mode = nir_rounding_mode_rtne;
|
||||
break;
|
||||
case nir_op_f2f16_rtz:
|
||||
mode = nir_rounding_mode_rtz;
|
||||
break;
|
||||
default: unreachable("Should've been filtered");
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
} else if (instr->type == nir_instr_type_intrinsic) {
|
||||
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
|
||||
assert(nir_intrinsic_src_type(intrin) == nir_type_float32);
|
||||
if (intrin->intrinsic != nir_intrinsic_convert_alu_types ||
|
||||
nir_intrinsic_dest_type(intrin) != nir_type_float16)
|
||||
return false;
|
||||
src = intrin->src[0].ssa;
|
||||
dst = &intrin->dest.ssa;
|
||||
mode = nir_intrinsic_rounding_mode(intrin);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
|
||||
nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
|
||||
nir_lower_fp16_cast_options req_option = 0;
|
||||
switch (mode) {
|
||||
case nir_rounding_mode_rtz:
|
||||
req_option = nir_lower_fp16_rtz;
|
||||
break;
|
||||
case nir_rounding_mode_rtne:
|
||||
req_option = nir_lower_fp16_rtne;
|
||||
break;
|
||||
case nir_rounding_mode_ru:
|
||||
req_option = nir_lower_fp16_ru;
|
||||
break;
|
||||
case nir_rounding_mode_rd:
|
||||
req_option = nir_lower_fp16_rd;
|
||||
break;
|
||||
case nir_rounding_mode_undef:
|
||||
if (options == nir_lower_fp16_all) {
|
||||
/* Pick one arbitrarily for lowering */
|
||||
mode = nir_rounding_mode_rtne;
|
||||
req_option = nir_lower_fp16_rtne;
|
||||
}
|
||||
/* Otherwise assume the backend can handle f2f16 with undef rounding */
|
||||
break;
|
||||
default:
|
||||
unreachable("Invalid rounding mode");
|
||||
}
|
||||
if (!(options & req_option))
|
||||
return false;
|
||||
|
||||
b->cursor = nir_before_instr(instr);
|
||||
nir_ssa_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };
|
||||
|
||||
for (unsigned i = 0; i < dst->num_components; i++) {
|
||||
|
|
@ -224,14 +245,16 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
|
|||
rets[i] = float_to_half_impl(b, comp, mode);
|
||||
}
|
||||
|
||||
return nir_vec(b, rets, dst->num_components);
|
||||
nir_ssa_def *new_val = nir_vec(b, rets, dst->num_components);
|
||||
nir_ssa_def_rewrite_uses(dst, new_val);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
nir_lower_fp16_casts(nir_shader *shader)
|
||||
nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options)
|
||||
{
|
||||
return nir_shader_lower_instructions(shader,
|
||||
lower_fp16_casts_filter,
|
||||
lower_fp16_cast_impl,
|
||||
NULL);
|
||||
return nir_shader_instructions_pass(shader,
|
||||
lower_fp16_cast_impl,
|
||||
nir_metadata_none,
|
||||
&options);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2775,7 +2775,7 @@ lp_build_opt_nir(struct nir_shader *nir)
|
|||
NIR_PASS_V(nir, nir_lower_frexp);
|
||||
|
||||
NIR_PASS_V(nir, nir_lower_flrp, 16|32|64, true);
|
||||
NIR_PASS_V(nir, nir_lower_fp16_casts);
|
||||
NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
|
||||
do {
|
||||
progress = false;
|
||||
NIR_PASS(progress, nir, nir_opt_constant_folding);
|
||||
|
|
|
|||
|
|
@ -1085,7 +1085,7 @@ clc_spirv_to_dxil(struct clc_libclc *lib,
|
|||
NIR_PASS_V(nir, dxil_nir_lower_loads_stores_to_dxil);
|
||||
NIR_PASS_V(nir, dxil_nir_opt_alu_deref_srcs);
|
||||
NIR_PASS_V(nir, dxil_nir_lower_atomics_to_dxil);
|
||||
NIR_PASS_V(nir, nir_lower_fp16_casts);
|
||||
NIR_PASS_V(nir, nir_lower_fp16_casts, nir_lower_fp16_all);
|
||||
NIR_PASS_V(nir, nir_lower_convert_alu_types, NULL);
|
||||
|
||||
// Convert pack to pack_split
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ bool dxil_nir_lower_atomics_to_dxil(nir_shader *shader);
|
|||
bool dxil_nir_lower_deref_ssbo(nir_shader *shader);
|
||||
bool dxil_nir_opt_alu_deref_srcs(nir_shader *shader);
|
||||
bool dxil_nir_lower_upcast_phis(nir_shader *shader, unsigned min_bit_size);
|
||||
bool dxil_nir_lower_fp16_casts(nir_shader *shader);
|
||||
bool dxil_nir_split_clip_cull_distance(nir_shader *shader);
|
||||
bool dxil_nir_lower_double_math(nir_shader *shader);
|
||||
bool dxil_nir_lower_system_values_to_zero(nir_shader *shader,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue