diff --git a/src/compiler/nir/nir_lower_alu_to_scalar.c b/src/compiler/nir/nir_lower_alu_to_scalar.c
index 2833774c621..1203e53a33e 100644
--- a/src/compiler/nir/nir_lower_alu_to_scalar.c
+++ b/src/compiler/nir/nir_lower_alu_to_scalar.c
@@ -93,6 +93,52 @@ lower_reduction(nir_alu_instr *alu, nir_op chan_op, nir_op merge_op,
    return last;
 }
 
+static inline bool /* Returns the shader's lower_ffma16/32/64 option that matches bit_size. */
+will_lower_ffma(nir_shader *shader, unsigned bit_size)
+{
+   switch (bit_size) {
+   case 16:
+      return shader->options->lower_ffma16;
+   case 32:
+      return shader->options->lower_ffma32;
+   case 64:
+      return shader->options->lower_ffma64;
+   }
+   unreachable("bad bit size"); /* only 16, 32 and 64 are handled above */
+}
+/* Scalarizes an fdot: an fmul/fadd reduction when ffma is lowered for this bit size, otherwise an fmul followed by chained ffma. */
+static nir_ssa_def *
+lower_fdot(nir_alu_instr *alu, nir_builder *builder)
+{
+   /* When ffma will not be lowered, emit the ffma chain directly below
+    * instead of fmul+fadd to be fused later: fusing is not possible for
+    * exact fdot instructions. Otherwise use the generic reduction. */
+   if (will_lower_ffma(builder->shader, alu->dest.dest.ssa.bit_size))
+      return lower_reduction(alu, nir_op_fmul, nir_op_fadd, builder);
+
+   unsigned num_components = nir_op_infos[alu->op].input_sizes[0]; /* component count of source 0 for this opcode */
+
+   nir_ssa_def *prev = NULL; /* running result; NULL until the first instruction is emitted */
+   for (int i = num_components - 1; i >= 0; i--) { /* walk components from last to first */
+      nir_alu_instr *instr = nir_alu_instr_create(
+         builder->shader, prev ? nir_op_ffma : nir_op_fmul); /* no accumulator yet on the first pass, so plain fmul */
+      nir_alu_ssa_dest_init(instr, 1, alu->dest.dest.ssa.bit_size); /* one-component result, same bit size as the fdot */
+      for (unsigned j = 0; j < 2; j++) {
+         nir_alu_src_copy(&instr->src[j], &alu->src[j], instr);
+         instr->src[j].swizzle[0] = alu->src[j].swizzle[i]; /* select component i of the original operand */
+      }
+      if (i != num_components - 1) /* true exactly when prev is non-NULL, i.e. this is an ffma */
+         instr->src[2].src = nir_src_for_ssa(prev); /* third ffma source is the previous result */
+      instr->exact = builder->exact;
+
+      nir_builder_instr_insert(builder, &instr->instr);
+
+      prev = &instr->dest.dest.ssa;
+   }
+
+   return prev; /* SSA def of the last instruction emitted */
+}
+
 static nir_ssa_def *
 lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
 {
@@ -244,7 +290,13 @@ lower_alu_instr_scalar(nir_builder *b, nir_instr *instr, void *_data)
    case nir_op_unpack_double_2x32_dxil:
       return NULL;
 
-      LOWER_REDUCTION(nir_op_fdot, nir_op_fmul, nir_op_fadd);
+   case nir_op_fdot2:
+   case nir_op_fdot3:
+   case nir_op_fdot4:
+   case nir_op_fdot8:
+   case nir_op_fdot16:
+      return lower_fdot(alu, b); /* fdot gets its own lowering so it can emit ffma directly */
+
       LOWER_REDUCTION(nir_op_ball_fequal, nir_op_feq, nir_op_iand);
       LOWER_REDUCTION(nir_op_ball_iequal, nir_op_ieq, nir_op_iand);
       LOWER_REDUCTION(nir_op_bany_fnequal, nir_op_fneu, nir_op_ior);