nir/lower_fp16_casts: add option to split fp64 casts

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Ivan Briano <ivan.briano@intel.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25566>
This commit is contained in:
Rhys Perry 2023-10-04 14:23:59 +01:00 committed by Marge Bot
parent fce434818a
commit 288e9db053
2 changed files with 38 additions and 9 deletions

View file

@ -6144,6 +6144,7 @@ typedef enum {
nir_lower_fp16_ru = (1 << 2),
nir_lower_fp16_rd = (1 << 3),
nir_lower_fp16_all = 0xf,
nir_lower_fp16_split_fp64 = (1 << 4),
} nir_lower_fp16_cast_options;
bool nir_lower_fp16_casts(nir_shader *shader, nir_lower_fp16_cast_options options);
bool nir_normalize_cubemap_coords(nir_shader *shader);

View file

@ -227,13 +227,15 @@ split_f2f16_conversion(nir_builder *b, nir_def *src, nir_rounding_mode rnd)
static bool
lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
{
nir_def *src, *dst;
nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
nir_src *src;
nir_def *dst;
uint8_t *swizzle = NULL;
nir_rounding_mode mode = nir_rounding_mode_undef;
if (instr->type == nir_instr_type_alu) {
nir_alu_instr *alu = nir_instr_as_alu(instr);
src = alu->src[0].src.ssa;
src = &alu->src[0].src;
swizzle = alu->src[0].swizzle;
dst = &alu->def;
switch (alu->op) {
@ -249,22 +251,48 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
case nir_op_f2f16_rtz:
mode = nir_rounding_mode_rtz;
break;
case nir_op_f2f64:
if (src->ssa->bit_size == 16 && (options & nir_lower_fp16_split_fp64)) {
b->cursor = nir_before_instr(instr);
nir_src_rewrite(src, nir_f2f32(b, src->ssa));
return true;
}
return false;
default:
return false;
}
} else if (instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_convert_alu_types ||
nir_intrinsic_dest_type(intrin) != nir_type_float16)
if (intrin->intrinsic != nir_intrinsic_convert_alu_types)
return false;
src = intrin->src[0].ssa;
src = &intrin->src[0];
dst = &intrin->def;
mode = nir_intrinsic_rounding_mode(intrin);
if (nir_intrinsic_src_type(intrin) == nir_type_float16 &&
nir_intrinsic_dest_type(intrin) == nir_type_float64 &&
(options & nir_lower_fp16_split_fp64)) {
b->cursor = nir_before_instr(instr);
nir_src_rewrite(src, nir_f2f32(b, src->ssa));
return true;
}
if (nir_intrinsic_dest_type(intrin) != nir_type_float16)
return false;
} else {
return false;
}
nir_lower_fp16_cast_options options = *(nir_lower_fp16_cast_options *)data;
bool progress = false;
if (src->ssa->bit_size == 64 && (options & nir_lower_fp16_split_fp64)) {
b->cursor = nir_before_instr(instr);
nir_src_rewrite(src, split_f2f16_conversion(b, src->ssa, mode));
if (instr->type == nir_instr_type_intrinsic)
nir_intrinsic_set_src_type(nir_instr_as_intrinsic(instr), nir_type_float32);
progress = true;
}
nir_lower_fp16_cast_options req_option = 0;
switch (mode) {
case nir_rounding_mode_rtz:
@ -280,7 +308,7 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
req_option = nir_lower_fp16_rd;
break;
case nir_rounding_mode_undef:
if (options == nir_lower_fp16_all) {
if ((options & nir_lower_fp16_all) == nir_lower_fp16_all) {
/* Pick one arbitrarily for lowering */
mode = nir_rounding_mode_rtne;
req_option = nir_lower_fp16_rtne;
@ -291,13 +319,13 @@ lower_fp16_cast_impl(nir_builder *b, nir_instr *instr, void *data)
unreachable("Invalid rounding mode");
}
if (!(options & req_option))
return false;
return progress;
b->cursor = nir_before_instr(instr);
nir_def *rets[NIR_MAX_VEC_COMPONENTS] = { NULL };
for (unsigned i = 0; i < dst->num_components; i++) {
nir_def *comp = nir_channel(b, src, swizzle ? swizzle[i] : i);
nir_def *comp = nir_channel(b, src->ssa, swizzle ? swizzle[i] : i);
if (comp->bit_size == 64)
comp = split_f2f16_conversion(b, comp, mode);
rets[i] = float_to_half_impl(b, comp, mode);