aco: fix fddx/y with uniform inf/nan input

inf or nan subtracted by itself is not zero.

I don't think Vulkan requires this, but this better matches NIR's constant
folding and the divergent implementation.

fossil-db (navi31):
Totals from 3 (0.00% of 79395) affected shaders:
Instrs: 537 -> 588 (+9.50%)
CodeSize: 3132 -> 3380 (+7.92%)
Latency: 2806 -> 2819 (+0.46%)
InvThroughput: 286 -> 316 (+10.49%)
Copies: 24 -> 39 (+62.50%)
VALU: 262 -> 289 (+10.31%)
SALU: 33 -> 51 (+54.55%)

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29418>
This commit is contained in:
Rhys Perry 2024-05-27 12:18:49 +01:00 committed by Marge Bot
parent 09fb55ea92
commit 1829d74ad3

View file

@ -3889,14 +3889,6 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
case nir_op_fddy_fine:
case nir_op_fddx_coarse:
case nir_op_fddy_coarse: {
if (!nir_src_is_divergent(instr->src[0].src)) {
/* Source is the same in all lanes, so the derivative is zero.
* This also avoids emitting invalid IR.
*/
bld.copy(Definition(dst), Operand::zero(dst.bytes()));
break;
}
uint16_t dpp_ctrl1, dpp_ctrl2;
if (instr->op == nir_op_fddx_fine) {
dpp_ctrl1 = dpp_quad_perm(0, 0, 2, 2);
@ -3923,8 +3915,12 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
opsel_lo |= opsel_lo << 1;
opsel_hi |= opsel_hi << 1;
Temp tl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl1);
Temp tr = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl2);
Temp tl = src;
Temp tr = src;
if (nir_src_is_divergent(instr->src[0].src)) {
tl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl1);
tr = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl2);
}
VALU_instruction& sub =
bld.vop3p(aco_opcode::v_pk_add_f16, Definition(dst), tr, tl, opsel_lo, opsel_hi)
@ -3935,9 +3931,11 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
} else {
Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0]));
if (ctx->program->gfx_level >= GFX8) {
aco_opcode sub =
instr->def.bit_size == 16 ? aco_opcode::v_sub_f16 : aco_opcode::v_sub_f32;
aco_opcode sub =
instr->def.bit_size == 16 ? aco_opcode::v_sub_f16 : aco_opcode::v_sub_f32;
if (!nir_src_is_divergent(instr->src[0].src)) {
bld.vop2(sub, Definition(dst), src, src);
} else if (ctx->program->gfx_level >= GFX8) {
Temp tl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl1);
bld.vop2_dpp(sub, Definition(dst), src, tl, dpp_ctrl2);
} else {