diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 3b3233b1b22..f8f73a0d6d9 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -75,11 +75,10 @@ perfwarn(Program* program, bool cond, const char* msg, Instruction* instr) struct mad_info { aco_ptr add_instr; uint32_t mul_temp_id; - uint16_t literal_idx; - bool check_literal; + uint16_t literal_mask; mad_info(aco_ptr instr, uint32_t id) - : add_instr(std::move(instr)), mul_temp_id(id), literal_idx(0), check_literal(false) + : add_instr(std::move(instr)), mul_temp_id(id), literal_mask(0) {} }; @@ -4398,49 +4397,42 @@ select_instruction(opt_ctx& ctx, aco_ptr& instr) if (instr->opcode == aco_opcode::v_fma_legacy_f16) return; - uint32_t literal_idx = 0; + uint32_t literal_mask = 0; + uint32_t sgpr_mask = 0; + uint32_t vgpr_mask = 0; uint32_t literal_uses = UINT32_MAX; + uint32_t literal_value = 0; - /* Try using v_madak/v_fmaak */ - if (instr->operands[2].isTemp() && - ctx.info[instr->operands[2].tempId()].is_literal(get_operand_size(instr, 2))) { - bool has_sgpr = false; - bool has_vgpr = false; - for (unsigned i = 0; i < 2; i++) { - if (!instr->operands[i].isTemp()) + /* Iterate in reverse to prefer v_madak/v_fmaak. */ + for (int i = 2; i >= 0; i--) { + Operand& op = instr->operands[i]; + if (!op.isTemp()) + continue; + if (ctx.info[op.tempId()].is_literal(get_operand_size(instr, i))) { + uint32_t new_literal = ctx.info[op.tempId()].val; + if (!literal_mask || literal_value == new_literal) { + literal_value = new_literal; + literal_uses = MIN2(literal_uses, ctx.uses[op.tempId()]); + literal_mask |= 1 << i; continue; - has_sgpr |= instr->operands[i].getTemp().type() == RegType::sgpr; - has_vgpr |= instr->operands[i].getTemp().type() == RegType::vgpr; - } - /* Encoding limitations requires a VGPR operand. The constant bus limitations before - * GFX10 disallows SGPRs. - */ - if ((!has_sgpr || ctx.program->gfx_level >= GFX10) && has_vgpr) { - literal_idx = 2; - literal_uses = ctx.uses[instr->operands[2].tempId()]; - } - } - - /* Try using v_madmk/v_fmamk */ - /* Encoding limitations requires a VGPR operand. */ - if (instr->operands[2].isTemp() && instr->operands[2].getTemp().type() == RegType::vgpr) { - for (unsigned i = 0; i < 2; i++) { - if (!instr->operands[i].isTemp()) - continue; - - /* The constant bus limitations before GFX10 disallows SGPRs. */ - if (ctx.program->gfx_level < GFX10 && instr->operands[!i].isTemp() && - instr->operands[!i].getTemp().type() == RegType::sgpr) - continue; - - if (ctx.info[instr->operands[i].tempId()].is_literal(get_operand_size(instr, i)) && - ctx.uses[instr->operands[i].tempId()] < literal_uses) { - literal_idx = i; - literal_uses = ctx.uses[instr->operands[i].tempId()]; } } + sgpr_mask |= op.isOfType(RegType::sgpr) << i; + vgpr_mask |= op.isOfType(RegType::vgpr) << i; } + /* The constant bus limitations before GFX10 disallows SGPRs. */ + if (sgpr_mask && ctx.program->gfx_level < GFX10) + literal_mask = 0; + + /* Encoding needs a vgpr. */ + if (!vgpr_mask) + literal_mask = 0; + + /* v_madmk/v_fmamk needs a vgpr in the third source. */ + if (!(literal_mask & 0b100) && !(vgpr_mask & 0b100)) + literal_mask = 0; + /* Limit the number of literals to apply to not increase the code * size too much, but always apply literals for v_mad->v_madak * because both instructions are 64-bit and this doesn't increase @@ -4448,10 +4440,10 @@ select_instruction(opt_ctx& ctx, aco_ptr& instr) * TODO: try to apply the literals earlier to lower the number of * uses below threshold */ - if (literal_uses < threshold || literal_idx == 2) { - ctx.uses[instr->operands[literal_idx].tempId()]--; - mad_info->check_literal = true; - mad_info->literal_idx = literal_idx; + if (literal_mask && (literal_uses < threshold || (literal_mask & 0b100))) { + u_foreach_bit (i, literal_mask) + ctx.uses[instr->operands[i].tempId()]--; + mad_info->literal_mask = literal_mask; return; } } @@ -4733,33 +4725,39 @@ apply_literals(opt_ctx& ctx, aco_ptr& instr) /* apply literals on MAD */ if (!instr->definitions.empty() && ctx.info[instr->definitions[0].tempId()].is_mad()) { mad_info* info = &ctx.mad_infos[ctx.info[instr->definitions[0].tempId()].instr->pass_flags]; - if (info->check_literal && - (ctx.uses[instr->operands[info->literal_idx].tempId()] == 0 || info->literal_idx == 2)) { + const bool madak = (info->literal_mask & 0b100); + bool has_dead_literal = false; + u_foreach_bit (i, info->literal_mask) + has_dead_literal |= ctx.uses[instr->operands[i].tempId()] == 0; + if (has_dead_literal || madak) { aco_ptr new_mad; - aco_opcode new_op = - info->literal_idx == 2 ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32; + aco_opcode new_op = madak ? aco_opcode::v_madak_f32 : aco_opcode::v_madmk_f32; if (instr->opcode == aco_opcode::v_fma_f32) - new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32; + new_op = madak ? aco_opcode::v_fmaak_f32 : aco_opcode::v_fmamk_f32; else if (instr->opcode == aco_opcode::v_mad_f16 || instr->opcode == aco_opcode::v_mad_legacy_f16) - new_op = info->literal_idx == 2 ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16; + new_op = madak ? aco_opcode::v_madak_f16 : aco_opcode::v_madmk_f16; else if (instr->opcode == aco_opcode::v_fma_f16) - new_op = info->literal_idx == 2 ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16; + new_op = madak ? aco_opcode::v_fmaak_f16 : aco_opcode::v_fmamk_f16; + uint32_t literal = ctx.info[instr->operands[ffs(info->literal_mask) - 1].tempId()].val; new_mad.reset(create_instruction(new_op, Format::VOP2, 3, 1)); - if (info->literal_idx == 2) { /* add literal -> madak */ - new_mad->operands[0] = instr->operands[0]; - new_mad->operands[1] = instr->operands[1]; + for (unsigned i = 0; i < 3; i++) { + if (info->literal_mask & (1 << i)) + new_mad->operands[i] = Operand::literal32(literal); + else + new_mad->operands[i] = instr->operands[i]; + } + if (madak) { /* add literal -> madak */ if (!new_mad->operands[1].isTemp() || new_mad->operands[1].getTemp().type() == RegType::sgpr) std::swap(new_mad->operands[0], new_mad->operands[1]); } else { /* mul literal -> madmk */ - new_mad->operands[0] = instr->operands[1 - info->literal_idx]; - new_mad->operands[1] = instr->operands[2]; + if (!(info->literal_mask & 0b10)) + std::swap(new_mad->operands[0], new_mad->operands[1]); + std::swap(new_mad->operands[1], new_mad->operands[2]); } - new_mad->operands[2] = - Operand::c32(ctx.info[instr->operands[info->literal_idx].tempId()].val); new_mad->definitions[0] = instr->definitions[0]; ctx.instructions.emplace_back(std::move(new_mad)); return;