diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 79a2d65df3b..d9a5958c4e2 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -2064,6 +2064,32 @@ parse_operand(opt_ctx& ctx, Temp tmp, alu_opt_op& op_info, aco_type& type) return true; } + if (info.parent_instr->opcode == aco_opcode::v_cvt_f32_f16 || + info.parent_instr->opcode == aco_opcode::s_cvt_f32_f16 || + info.parent_instr->opcode == aco_opcode::s_cvt_hi_f32_f16) { + Instruction* instr = info.parent_instr; + if (instr->isVALU() && (instr->valu().clamp || instr->valu().omod)) + return false; + if (instr->isDPP() || (instr->isSDWA() && instr->sdwa().dst_sel.size() != 4)) + return false; + + if (instr->isVALU() && instr->valu().abs[0]) + op_info.abs[0] = true; + if (instr->isVALU() && instr->valu().neg[0]) + op_info.neg[0] = true; + + if (instr->isSDWA()) + op_info.extract[0] = instr->sdwa().sel[0]; + else if (instr->isVALU() && instr->valu().opsel[0]) + op_info.extract[0] = SubdwordSel::uword1; + else if (info.parent_instr->opcode == aco_opcode::s_cvt_hi_f32_f16) + op_info.extract[0] = SubdwordSel::uword1; + + op_info.f16_to_f32 = true; + op_info.op = instr->operands[0]; + return true; + } + if (info.is_temp() || info.is_fcanonicalize() || info.is_abs() || info.is_neg()) { op_info.op = Operand(info.temp); if (info.is_abs()) @@ -2094,6 +2120,12 @@ combine_operand(opt_ctx& ctx, alu_opt_op& inner, const aco_type& inner_type, if (has_imod && outer_type.bit_size != inner_type.bit_size) return false; + if (outer.f16_to_f32) { + if (inner_type.num_components != 1 || inner.extract[0].size() != 4 || inner.f16_to_f32) + return false; + inner.f16_to_f32 = true; + } + for (unsigned i = 0; i < inner_type.num_components; i++) { unsigned offset = inner.extract[i].offset() * 8; unsigned size = MIN2(inner.extract[i].size() * 8, inner_type.bit_size); @@ -2208,7 +2240,8 @@ alu_propagate_temp_const(opt_ctx& ctx, aco_ptr& instr, bool uses_va alu_opt_op outer; aco_type outer_type; - if (!parse_operand(ctx, info.operands[i].op.getTemp(), outer, outer_type)) { + if (!parse_operand(ctx, info.operands[i].op.getTemp(), outer, outer_type) || + (!uses_valid && outer.f16_to_f32)) { operand_mask &= ~BITFIELD_BIT(i); continue; } @@ -4360,68 +4393,6 @@ combine_output_conversion(opt_ctx& ctx, aco_ptr& instr) return true; } -void -combine_mad_mix(opt_ctx& ctx, aco_ptr& instr) -{ - if (!can_use_mad_mix(ctx, instr)) - return; - - for (unsigned i = 0; i < instr->operands.size(); i++) { - if (!instr->operands[i].isTemp()) - continue; - Temp tmp = instr->operands[i].getTemp(); - - Instruction* conv = ctx.info[tmp.id()].parent_instr; - if (conv->opcode != aco_opcode::v_cvt_f32_f16 || !conv->operands[0].isTemp() || - conv->valu().clamp || conv->valu().omod) { - continue; - } else if (conv->isSDWA() && - (conv->sdwa().dst_sel.size() != 4 || conv->sdwa().sel[0].size() != 2)) { - continue; - } else if (conv->isDPP()) { - continue; - } - - if (get_operand_type(instr, i).bit_size != 32) - continue; - - /* Conversion to VOP3P will add inline constant operands, but that shouldn't affect - * check_vop3_operands(). */ - Operand op[3]; - for (unsigned j = 0; j < instr->operands.size(); j++) - op[j] = instr->operands[j]; - op[i] = conv->operands[0]; - if (!check_vop3_operands(ctx, instr->operands.size(), op)) - continue; - if (!conv->operands[0].isOfType(RegType::vgpr) && instr->isDPP()) - continue; - - if (!instr->isVOP3P()) { - bool is_add = - instr->opcode != aco_opcode::v_mul_f32 && instr->opcode != aco_opcode::v_fma_f32; - to_mad_mix(ctx, instr); - i += is_add; - } - - if (--ctx.uses[tmp.id()]) - ctx.uses[conv->operands[0].tempId()]++; - instr->operands[i].setTemp(conv->operands[0].getTemp()); - if (conv->definitions[0].isPrecise()) - instr->definitions[0].setPrecise(true); - instr->valu().opsel_hi[i] = true; - if (conv->isSDWA() && conv->sdwa().sel[0].offset() == 2) - instr->valu().opsel_lo[i] = true; - else - instr->valu().opsel_lo[i] = conv->valu().opsel[0]; - bool neg = conv->valu().neg[0]; - bool abs = conv->valu().abs[0]; - if (!instr->valu().abs[i]) { - instr->valu().neg[i] ^= neg; - instr->valu().abs[i] = abs; - } - } -} - // TODO: we could possibly move the whole label_instruction pass to combine_instruction: // this would mean that we'd have to fix the instruction uses while value propagation @@ -4486,7 +4457,6 @@ combine_instruction(opt_ctx& ctx, aco_ptr& instr) } if (instr->isVALU()) { - combine_mad_mix(ctx, instr); while (apply_omod_clamp(ctx, instr) || combine_output_conversion(ctx, instr)) ; apply_insert(ctx, instr);