diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp
index f4bdb249e5b..0488eff8bef 100644
--- a/src/amd/compiler/aco_optimizer.cpp
+++ b/src/amd/compiler/aco_optimizer.cpp
@@ -3991,6 +3991,53 @@ create_med3_cb(opt_ctx& ctx, alu_opt_info& info)
    return false;
 }
 
+bool
+can_reassoc_omod(opt_ctx& ctx, const alu_opt_info& info, unsigned bit_size)
+{
+   unsigned denorm = bit_size == 32 ? ctx.fp_mode.denorm32 : ctx.fp_mode.denorm16_64;
+   bool no_signed_zero =
+      info.opcode == aco_opcode::v_mul_legacy_f32 || !info.defs[0].isSZPreserve();
+
+   return no_signed_zero && !info.omod && !info.defs[0].isPrecise() && denorm == fp_denorm_flush;
+}
+
+template <bool is_rcp>
+bool
+reassoc_omod_cb(opt_ctx& ctx, alu_opt_info& info)
+{
+   if (info.defs[0].isPrecise())
+      return false;
+
+   aco_type type = instr_info.alu_opcode_infos[(int)info.opcode].def_types[0];
+
+   for (unsigned op_idx = 0; op_idx < 2; op_idx++) {
+      uint64_t constant = 0;
+      if (!op_info_get_constant(ctx, info.operands[op_idx], type, &constant))
+         continue;
+
+      double val = extract_float(constant, type.bit_size);
+      if (val < 0.0) {
+         info.operands[!op_idx].neg[0] ^= true;
+         val = fabs(val);
+      }
+
+      if (val == (is_rcp ? 0.5 : 2.0))
+         info.omod = 1;
+      else if (val == (is_rcp ? 0.25 : 4.0))
+         info.omod = 2;
+      else if (val == (is_rcp ? 2.0 : 0.5))
+         info.omod = 3;
+      else
+         return false;
+
+      info.operands.erase(std::next(info.operands.begin(), op_idx));
+
+      return true;
+   }
+
+   return false;
+}
+
 template
 bool
 shift_to_mad_cb(opt_ctx& ctx, alu_opt_info& info)
@@ -4293,16 +4340,38 @@ combine_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
    } else if (info.opcode == aco_opcode::v_min_i16_e64) {
       add_opt(v_min_i16_e64, v_min3_i16, 0x3, "120", nullptr, true);
       add_opt(v_max_i16_e64, v_med3_i16, 0x3, "012", create_med3_cb, true);
-   } else if (((info.opcode == aco_opcode::v_mul_f32 && !info.defs[0].isNaNPreserve() &&
-                !info.defs[0].isInfPreserve()) ||
-               (info.opcode == aco_opcode::v_mul_legacy_f32 && !info.defs[0].isSZPreserve())) &&
-              !info.clamp && !info.omod && !ctx.fp_mode.must_flush_denorms32) {
-      /* v_mul_f32(a, v_cndmask_b32(0, 1.0, cond)) -> v_cndmask_b32(0, a, cond) */
-      add_opt(v_cndmask_b32, v_cndmask_b32, 0x3, "1032",
-              and_cb, remove_const_cb<0x3f800000>>, true);
-      /* v_mul_f32(a, v_cndmask_b32(1.0, 0, cond)) -> v_cndmask_b32(a, 0, cond) */
-      add_opt(v_cndmask_b32, v_cndmask_b32, 0x3, "0231",
-              and_cb, remove_const_cb<0x3f800000>>, true);
+   } else if (info.opcode == aco_opcode::v_mul_f32 || info.opcode == aco_opcode::v_mul_legacy_f32) {
+      bool legacy = info.opcode == aco_opcode::v_mul_legacy_f32;
+
+      if ((legacy ? !info.defs[0].isSZPreserve()
+                  : (!info.defs[0].isNaNPreserve() && !info.defs[0].isInfPreserve())) &&
+          !info.clamp && !info.omod && !ctx.fp_mode.must_flush_denorms32) {
+         /* v_mul_f32(a, v_cndmask_b32(0, 1.0, cond)) -> v_cndmask_b32(0, a, cond) */
+         add_opt(v_cndmask_b32, v_cndmask_b32, 0x3, "1032",
+                 and_cb, remove_const_cb<0x3f800000>>, true);
+         /* v_mul_f32(a, v_cndmask_b32(1.0, 0, cond)) -> v_cndmask_b32(a, 0, cond) */
+         add_opt(v_cndmask_b32, v_cndmask_b32, 0x3, "0231",
+                 and_cb, remove_const_cb<0x3f800000>>, true);
+      }
+
+      if (can_reassoc_omod(ctx, info, 32)) {
+         if (legacy) {
+            add_opt(v_mul_f32, v_mul_legacy_f32, 0x3, "120", reassoc_omod_cb<false>, true);
+            add_opt(v_mul_legacy_f32, v_mul_legacy_f32, 0x3, "120", reassoc_omod_cb<false>, true);
+            add_opt(s_mul_f32, v_mul_legacy_f32, 0x3, "120", reassoc_omod_cb<false>, true);
+         } else {
+            add_opt(v_mul_f32, v_mul_f32, 0x3, "120", reassoc_omod_cb<false>, true);
+            add_opt(v_mul_legacy_f32, v_mul_f32, 0x3, "120", reassoc_omod_cb<false>, true);
+            add_opt(s_mul_f32, v_mul_f32, 0x3, "120", reassoc_omod_cb<false>, true);
+         }
+      }
+   } else if (info.opcode == aco_opcode::v_mul_f16 && can_reassoc_omod(ctx, info, 16)) {
+      add_opt(v_mul_f16, v_mul_f16, 0x3, "120", reassoc_omod_cb<false>, true);
+      add_opt(s_mul_f16, v_mul_f16, 0x3, "120", reassoc_omod_cb<false>, true);
+   } else if (info.opcode == aco_opcode::v_mul_f64 && can_reassoc_omod(ctx, info, 64)) {
+      add_opt(v_mul_f64, v_mul_f64, 0x3, "120", reassoc_omod_cb<false>, true);
+   } else if (info.opcode == aco_opcode::v_mul_f64_e64 && can_reassoc_omod(ctx, info, 64)) {
+      add_opt(v_mul_f64_e64, v_mul_f64_e64, 0x3, "120", reassoc_omod_cb<false>, true);
    } else if (info.opcode == aco_opcode::v_add_u16 && !info.clamp) {
       if (ctx.program->gfx_level < GFX9) {
          add_opt(v_mul_lo_u16, v_mad_legacy_u16, 0x3, "120");
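
Standalone sketch, not part of the patch above: the reassociation relies on the identity that multiplying a product by 2.0, 4.0 or 0.5 is the same as applying the hardware output modifier (omod *2, *4, /2) to the instruction that produced the product, which is only safe when denormals are flushed and the result is not marked precise. The helper below (apply_omod is an invented name, not an ACO function) mirrors the constant-to-omod mapping used by reassoc_omod_cb in the non-rcp case and checks the identity on sample values.

#include <cassert>
#include <cstdio>

/* Invented helper: models the effect of the GPU output modifier on a result.
 * omod 1 = *2, omod 2 = *4, omod 3 = /2, anything else = no modifier. */
static float apply_omod(float x, unsigned omod)
{
   switch (omod) {
   case 1: return x * 2.0f;
   case 2: return x * 4.0f;
   case 3: return x * 0.5f;
   default: return x;
   }
}

int main()
{
   float a = 1.25f, b = -3.0f;
   /* v_mul_f32(2.0, v_mul_f32(a, b)) -> v_mul_f32(a, b) with omod set to *2 */
   assert(2.0f * (a * b) == apply_omod(a * b, 1));
   /* v_mul_f32(4.0, v_mul_f32(a, b)) -> v_mul_f32(a, b) with omod set to *4 */
   assert(4.0f * (a * b) == apply_omod(a * b, 2));
   /* v_mul_f32(0.5, v_mul_f32(a, b)) -> v_mul_f32(a, b) with omod set to /2 */
   assert(0.5f * (a * b) == apply_omod(a * b, 3));
   printf("constant-to-omod mapping holds for the sample inputs\n");
   return 0;
}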