diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 418f61bd61f..25914429f48 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -5089,8 +5089,10 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) add_opt(v_mul_f32, v_mad_f32, 0x3, "120"); add_opt(v_mul_legacy_f32, v_mad_legacy_f32, 0x3, "120"); } - if (ctx.program->dev.has_fast_fma32) + if (ctx.program->dev.has_fast_fma32) { add_opt(v_mul_f32, v_fma_f32, 0x3, "120", create_fma_cb); + add_opt(s_mul_f32, v_fma_f32, 0x3, "120", create_fma_cb); + } if (ctx.program->gfx_level >= GFX10_3) add_opt(v_mul_legacy_f32, v_fma_legacy_f32, 0x3, "120", create_fma_cb); } else if (info.opcode == aco_opcode::v_add_f16) { @@ -5099,14 +5101,20 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) else if (ctx.program->gfx_level < GFX10 && ctx.fp_mode.denorm16_64 == 0) add_opt(v_mul_f16, v_mad_f16, 0x3, "120"); - if (ctx.program->gfx_level < GFX9) + if (ctx.program->gfx_level < GFX9) { add_opt(v_mul_f16, v_fma_legacy_f16, 0x3, "120", create_fma_cb); - else + } else { add_opt(v_mul_f16, v_fma_f16, 0x3, "120", create_fma_cb); + add_opt(s_mul_f16, v_fma_f16, 0x3, "120", create_fma_cb); + } } else if (info.opcode == aco_opcode::v_add_f64) { add_opt(v_mul_f64, v_fma_f64, 0x3, "120", create_fma_cb); } else if (info.opcode == aco_opcode::v_add_f64_e64) { add_opt(v_mul_f64_e64, v_fma_f64, 0x3, "120", create_fma_cb); + } else if (info.opcode == aco_opcode::s_add_f32) { + add_opt(s_mul_f32, s_fmac_f32, 0x3, "120", create_fma_cb); + } else if (info.opcode == aco_opcode::s_add_f16) { + add_opt(s_mul_f16, s_fmac_f16, 0x3, "120", create_fma_cb); } if (match_and_apply_patterns(ctx, info, patterns)) {