diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index f801bce50c9..b1b8f054658 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -2765,13 +2765,15 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) if (!ctx.program->needs_wqm) ctx.info[instr->definitions[0].tempId()].set_constant(0u); break; + case aco_opcode::s_mul_f16: + case aco_opcode::s_mul_f32: case aco_opcode::v_mul_f16: case aco_opcode::v_mul_f32: case aco_opcode::v_mul_legacy_f32: case aco_opcode::v_mul_f64: case aco_opcode::v_mul_f64_e64: { bool uses_mods = instr->usesModifiers(); - bool fp16 = instr->opcode == aco_opcode::v_mul_f16; + bool fp16 = instr->opcode == aco_opcode::v_mul_f16 || instr->opcode == aco_opcode::s_mul_f16; bool fp64 = instr->opcode == aco_opcode::v_mul_f64 || instr->opcode == aco_opcode::v_mul_f64_e64; unsigned bit_size = fp16 ? 16 : (fp64 ? 64 : 32); @@ -2783,22 +2785,27 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) double constant = extract_float(instr->operands[!i].constantValue64(), bit_size); - if (!instr->isDPP() && !instr->isSDWA() && !instr->valu().opsel && fabs(constant) == 1.0) { - bool neg1 = constant == -1.0; + if (!instr->isDPP() && !instr->isSDWA() && (!instr->isVALU() || !instr->valu().opsel) && + fabs(constant) == 1.0) { + bool neg = constant == -1.0; + bool abs = false; - VALU_instruction* valu = &instr->valu(); - if (valu->abs[!i] || valu->neg[!i] || valu->omod || valu->clamp) - continue; + if (instr->isVALU()) { + VALU_instruction* valu = &instr->valu(); + if (valu->abs[!i] || valu->neg[!i] || valu->omod || valu->clamp) + continue; + + abs = valu->abs[i]; + neg ^= valu->neg[i]; + } - bool abs = valu->abs[i]; - bool neg = neg1 ^ valu->neg[i]; Temp other = instr->operands[i].getTemp(); - if (abs && neg && other.type() == RegType::vgpr) + if (abs && neg && other.type() == instr->definitions[0].getTemp().type()) ctx.info[instr->definitions[0].tempId()].set_neg_abs(other, bit_size); - else if (abs && !neg && other.type() == RegType::vgpr) + else if (abs && !neg && other.type() == instr->definitions[0].getTemp().type()) ctx.info[instr->definitions[0].tempId()].set_abs(other, bit_size); - else if (!abs && neg && other.type() == RegType::vgpr) + else if (!abs && neg && other.type() == instr->definitions[0].getTemp().type()) ctx.info[instr->definitions[0].tempId()].set_neg(other, bit_size); else if (!abs && !neg) { if (denorm_mode == fp_denorm_keep || ctx.info[other.id()].is_canonicalized(bit_size))