diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 6cfa4becefc..9b07b1bdaff 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -4343,112 +4343,6 @@ combine_vop3p(opt_ctx& ctx, aco_ptr& instr) return; } } - - if (instr->opcode == aco_opcode::v_pk_add_f16 || instr->opcode == aco_opcode::v_pk_add_u16) { - bool fadd = instr->opcode == aco_opcode::v_pk_add_f16; - if (fadd && instr->definitions[0].isPrecise()) - return; - if (!fadd && instr->valu().clamp) - return; - - Instruction* mul_instr = nullptr; - unsigned add_op_idx = 0; - bitarray8 mul_neg_lo = 0, mul_neg_hi = 0, mul_opsel_lo = 0, mul_opsel_hi = 0; - uint32_t uses = UINT32_MAX; - - /* find the 'best' mul instruction to combine with the add */ - for (unsigned i = 0; i < 2; i++) { - Instruction* op_instr = follow_operand(ctx, instr->operands[i], true); - if (!op_instr) - continue; - - if (op_instr->isVOP3P()) { - if (fadd) { - if (op_instr->opcode != aco_opcode::v_pk_mul_f16 || - op_instr->definitions[0].isPrecise()) - continue; - } else { - if (op_instr->opcode != aco_opcode::v_pk_mul_lo_u16) - continue; - } - - /* no clamp allowed between mul and add */ - if (op_instr->valu().clamp) - continue; - - Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]}; - if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op)) - continue; - - mul_instr = op_instr; - add_op_idx = 1 - i; - uses = ctx.uses[instr->operands[i].tempId()]; - mul_neg_lo = mul_instr->valu().neg_lo; - mul_neg_hi = mul_instr->valu().neg_hi; - mul_opsel_lo = mul_instr->valu().opsel_lo; - mul_opsel_hi = mul_instr->valu().opsel_hi; - } else if (instr->operands[i].bytes() == 2) { - if ((fadd && (op_instr->opcode != aco_opcode::v_mul_f16 || - op_instr->definitions[0].isPrecise())) || - (!fadd && op_instr->opcode != aco_opcode::v_mul_lo_u16 && - op_instr->opcode != aco_opcode::v_mul_lo_u16_e64)) - continue; - - if (op_instr->valu().clamp || op_instr->valu().omod || op_instr->valu().abs) - continue; - - if (op_instr->isDPP() || (op_instr->isSDWA() && (op_instr->sdwa().sel[0].size() < 2 || - op_instr->sdwa().sel[1].size() < 2))) - continue; - - Operand op[3] = {op_instr->operands[0], op_instr->operands[1], instr->operands[1 - i]}; - if (ctx.uses[instr->operands[i].tempId()] >= uses || !check_vop3_operands(ctx, 3, op)) - continue; - - mul_instr = op_instr; - add_op_idx = 1 - i; - uses = ctx.uses[instr->operands[i].tempId()]; - mul_neg_lo = mul_instr->valu().neg; - mul_neg_hi = mul_instr->valu().neg; - if (mul_instr->isSDWA()) { - for (unsigned j = 0; j < 2; j++) - mul_opsel_lo[j] = mul_instr->sdwa().sel[j].offset(); - } else { - mul_opsel_lo = mul_instr->valu().opsel; - } - mul_opsel_hi = mul_opsel_lo; - } - } - - if (!mul_instr) - return; - - /* turn mul + packed add into v_pk_fma_f16 */ - aco_opcode mad = fadd ? aco_opcode::v_pk_fma_f16 : aco_opcode::v_pk_mad_u16; - aco_ptr fma{create_instruction(mad, Format::VOP3P, 3, 1)}; - fma->operands[0] = copy_operand(ctx, mul_instr->operands[0]); - fma->operands[1] = copy_operand(ctx, mul_instr->operands[1]); - fma->operands[2] = instr->operands[add_op_idx]; - fma->valu().clamp = vop3p->clamp; - fma->valu().neg_lo = mul_neg_lo; - fma->valu().neg_hi = mul_neg_hi; - fma->valu().opsel_lo = mul_opsel_lo; - fma->valu().opsel_hi = mul_opsel_hi; - propagate_swizzles(&fma->valu(), vop3p->opsel_lo[1 - add_op_idx], - vop3p->opsel_hi[1 - add_op_idx]); - fma->valu().opsel_lo[2] = vop3p->opsel_lo[add_op_idx]; - fma->valu().opsel_hi[2] = vop3p->opsel_hi[add_op_idx]; - fma->valu().neg_lo[2] = vop3p->neg_lo[add_op_idx]; - fma->valu().neg_hi[2] = vop3p->neg_hi[add_op_idx]; - fma->valu().neg_lo[1] = fma->valu().neg_lo[1] ^ vop3p->neg_lo[1 - add_op_idx]; - fma->valu().neg_hi[1] = fma->valu().neg_hi[1] ^ vop3p->neg_hi[1 - add_op_idx]; - fma->definitions[0] = instr->definitions[0]; - fma->pass_flags = instr->pass_flags; - instr = std::move(fma); - ctx.info[instr->definitions[0].tempId()].parent_instr = instr.get(); - decrease_and_dce(ctx, mul_instr->definitions[0].getTemp()); - return; - } } bool @@ -4699,8 +4593,9 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) } if (instr->isVOP3P() && instr->opcode != aco_opcode::v_fma_mix_f32 && - instr->opcode != aco_opcode::v_fma_mixlo_f16) - return combine_vop3p(ctx, instr); + instr->opcode != aco_opcode::v_fma_mixlo_f16) { + combine_vop3p(ctx, instr); + } if (instr->isDPP()) return; @@ -4874,16 +4769,19 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) if (ctx.program->gfx_level >= GFX10_3) add_opt(v_mul_legacy_f32, v_fma_legacy_f32, 0x3, "120", create_fma_cb); } else if (info.opcode == aco_opcode::v_add_f16) { - if (ctx.program->gfx_level < GFX9 && ctx.fp_mode.denorm16_64 == 0) + if (ctx.program->gfx_level < GFX9 && ctx.fp_mode.denorm16_64 == 0) { add_opt(v_mul_f16, v_mad_legacy_f16, 0x3, "120"); - else if (ctx.program->gfx_level < GFX10 && ctx.fp_mode.denorm16_64 == 0) + } else if (ctx.program->gfx_level < GFX10 && ctx.fp_mode.denorm16_64 == 0) { add_opt(v_mul_f16, v_mad_f16, 0x3, "120"); + add_opt(v_pk_mul_f16, v_mad_f16, 0x3, "120"); + } if (ctx.program->gfx_level < GFX9) { add_opt(v_mul_f16, v_fma_legacy_f16, 0x3, "120", create_fma_cb); } else { add_opt(v_mul_f16, v_fma_f16, 0x3, "120", create_fma_cb); add_opt(s_mul_f16, v_fma_f16, 0x3, "120", create_fma_cb); + add_opt(v_pk_mul_f16, v_fma_f16, 0x3, "120", create_fma_cb); } } else if (info.opcode == aco_opcode::v_add_f64) { add_opt(v_mul_f64, v_fma_f64, 0x3, "120", create_fma_cb); @@ -4893,6 +4791,10 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) add_opt(s_mul_f32, s_fmac_f32, 0x3, "120", create_fma_cb); } else if (info.opcode == aco_opcode::s_add_f16) { add_opt(s_mul_f16, s_fmac_f16, 0x3, "120", create_fma_cb); + } else if (info.opcode == aco_opcode::v_pk_add_f16) { + add_opt(v_pk_mul_f16, v_pk_fma_f16, 0x3, "120", create_fma_cb); + add_opt(v_mul_f16, v_pk_fma_f16, 0x3, "120", create_fma_cb); + add_opt(s_mul_f16, v_pk_fma_f16, 0x3, "120", create_fma_cb); } else if (info.opcode == aco_opcode::v_max_f32) { add_opt(v_max_f32, v_max3_f32, 0x3, "120", nullptr, true); add_opt(s_max_f32, v_max3_f32, 0x3, "120", nullptr, true); @@ -5001,12 +4903,21 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) add_opt(v_cndmask_b32, v_cndmask_b32, 0x3, "1032", and_cb, remove_const_cb<0x3f800000>>, true); } else if (info.opcode == aco_opcode::v_add_u16 && !info.clamp) { - if (ctx.program->gfx_level < GFX9) + if (ctx.program->gfx_level < GFX9) { add_opt(v_mul_lo_u16, v_mad_legacy_u16, 0x3, "120"); - else + } else { add_opt(v_mul_lo_u16, v_mad_u16, 0x3, "120"); + add_opt(v_pk_mul_lo_u16, v_mad_u16, 0x3, "120"); + } } else if (info.opcode == aco_opcode::v_add_u16_e64 && !info.clamp) { add_opt(v_mul_lo_u16_e64, v_mad_u16, 0x3, "120"); + add_opt(v_pk_mul_lo_u16, v_mad_u16, 0x3, "120"); + } else if (info.opcode == aco_opcode::v_pk_add_u16 && !info.clamp) { + add_opt(v_pk_mul_lo_u16, v_pk_mad_u16, 0x3, "120"); + if (ctx.program->gfx_level < GFX10) + add_opt(v_mul_lo_u16, v_pk_mad_u16, 0x3, "120"); + else + add_opt(v_mul_lo_u16_e64, v_pk_mad_u16, 0x3, "120"); } if (match_and_apply_patterns(ctx, info, patterns)) {