diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 2a2dc0687df..07a31319094 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -1274,10 +1274,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_iabs: { + Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == s1) { - bld.sop1(aco_opcode::s_abs_i32, Definition(dst), bld.def(s1, scc), get_alu_src(ctx, instr->src[0])); + bld.sop1(aco_opcode::s_abs_i32, Definition(dst), bld.def(s1, scc), src); } else if (dst.regClass() == v1) { - Temp src = get_alu_src(ctx, instr->src[0]); bld.vop2(aco_opcode::v_max_i32, Definition(dst), src, bld.vsub32(bld.def(v1), Operand(0u), src)); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -1685,7 +1685,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v1) { bld.vop3(aco_opcode::v_mul_hi_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); } else if (dst.regClass() == s1 && ctx->options->chip_class >= GFX9) { - bld.sop2(aco_opcode::s_mul_hi_u32, Definition(dst), get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_hi_u32, dst, false); } else if (dst.regClass() == s1) { Temp tmp = bld.vop3(aco_opcode::v_mul_hi_u32, bld.def(v1), get_alu_src(ctx, instr->src[0]), as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); @@ -1699,7 +1699,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) if (dst.regClass() == v1) { bld.vop3(aco_opcode::v_mul_hi_i32, Definition(dst), get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); } else if (dst.regClass() == s1 && ctx->options->chip_class >= GFX9) { - bld.sop2(aco_opcode::s_mul_hi_i32, Definition(dst), get_alu_src(ctx, instr->src[0]), get_alu_src(ctx, instr->src[1])); + emit_sop2_instruction(ctx, instr, aco_opcode::s_mul_hi_i32, dst, false); } else if (dst.regClass() == s1) { Temp tmp = bld.vop3(aco_opcode::v_mul_hi_i32, bld.def(v1), get_alu_src(ctx, instr->src[0]), as_vgpr(ctx, get_alu_src(ctx, instr->src[1]))); @@ -1710,13 +1710,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmul: { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_mul_f32, dst, true); } else if (dst.regClass() == v2) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); bld.vop3(aco_opcode::v_mul_f64, Definition(dst), src0, src1); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -1724,13 +1724,13 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fadd: { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_add_f32, dst, true); } else if (dst.regClass() == v2) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); bld.vop3(aco_opcode::v_add_f64, Definition(dst), src0, src1); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -1761,14 +1761,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmax: { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_max_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) { Temp tmp = bld.vop3(aco_opcode::v_max_f64, bld.def(v2), src0, src1); bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); @@ -1781,14 +1781,14 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fmin: { - Temp src0 = get_alu_src(ctx, instr->src[0]); - Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (dst.regClass() == v2b) { // TODO: check fp_mode.must_flush_denorms16_64 emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_min_f32, dst, true, false, ctx->block->fp_mode.must_flush_denorms32); } else if (dst.regClass() == v2) { + Temp src0 = get_alu_src(ctx, instr->src[0]); + Temp src1 = as_vgpr(ctx, get_alu_src(ctx, instr->src[1])); if (ctx->block->fp_mode.must_flush_denorms16_64 && ctx->program->chip_class < GFX9) { Temp tmp = bld.vop3(aco_opcode::v_min_f64, bld.def(v2), src0, src1); bld.vop3(aco_opcode::v_mul_f64, Definition(dst), Operand(0x3FF0000000000000lu), tmp); @@ -1831,10 +1831,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frsq: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rsq_f16, dst); } else if (dst.regClass() == v1) { + Temp src = get_alu_src(ctx, instr->src[0]); emit_rsq(ctx, bld, Definition(dst), src); } else if (dst.regClass() == v2) { /* Lowered at NIR level for precision reasons. */ @@ -1906,10 +1906,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_flog2: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_log_f16, dst); } else if (dst.regClass() == v1) { + Temp src = get_alu_src(ctx, instr->src[0]); emit_log2(ctx, bld, Definition(dst), src); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -1917,10 +1917,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frcp: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rcp_f16, dst); } else if (dst.regClass() == v1) { + Temp src = get_alu_src(ctx, instr->src[0]); emit_rcp(ctx, bld, Definition(dst), src); } else if (dst.regClass() == v2) { /* Lowered at NIR level for precision reasons. */ @@ -1941,10 +1941,10 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fsqrt: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_sqrt_f16, dst); } else if (dst.regClass() == v1) { + Temp src = get_alu_src(ctx, instr->src[0]); emit_sqrt(ctx, bld, Definition(dst), src); } else if (dst.regClass() == v2) { /* Lowered at NIR level for precision reasons. */ @@ -1967,12 +1967,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ffloor: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_floor_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_floor_f32, dst); } else if (dst.regClass() == v2) { + Temp src = get_alu_src(ctx, instr->src[0]); emit_floor_f64(ctx, bld, Definition(dst), src); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -1980,7 +1980,6 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fceil: { - Temp src0 = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_ceil_f16, dst); } else if (dst.regClass() == v1) { @@ -1994,6 +1993,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) * if (src0 > 0.0 && src0 != trunc) * trunc += 1.0 */ + Temp src0 = get_alu_src(ctx, instr->src[0]); Temp trunc = emit_trunc_f64(ctx, bld, bld.def(v2), src0); Temp tmp0 = bld.vopc_e64(aco_opcode::v_cmp_gt_f64, bld.def(bld.lm), src0, Operand(0u)); Temp tmp1 = bld.vopc(aco_opcode::v_cmp_lg_f64, bld.hint_vcc(bld.def(bld.lm)), src0, trunc); @@ -2008,12 +2008,12 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_ftrunc: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_trunc_f16, dst); } else if (dst.regClass() == v1) { emit_vop1_instruction(ctx, instr, aco_opcode::v_trunc_f32, dst); } else if (dst.regClass() == v2) { + Temp src = get_alu_src(ctx, instr->src[0]); emit_trunc_f64(ctx, bld, Definition(dst), src); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); @@ -2021,7 +2021,6 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_fround_even: { - Temp src0 = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { emit_vop1_instruction(ctx, instr, aco_opcode::v_rndne_f16, dst); } else if (dst.regClass() == v1) { @@ -2032,6 +2031,7 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) } else { /* GFX6 doesn't support V_RNDNE_F64, lower it. */ Temp src0_lo = bld.tmp(v1), src0_hi = bld.tmp(v1); + Temp src0 = get_alu_src(ctx, instr->src[0]); bld.pseudo(aco_opcode::p_split_vector, Definition(src0_lo), Definition(src0_hi), src0); Temp bitmask = bld.sop1(aco_opcode::s_brev_b32, bld.def(s1), bld.copy(bld.def(s1), Operand(-2u))); @@ -2097,28 +2097,27 @@ void visit_alu_instr(isel_context *ctx, nir_alu_instr *instr) break; } case nir_op_frexp_sig: { - Temp src = get_alu_src(ctx, instr->src[0]); if (dst.regClass() == v2b) { - bld.vop1(aco_opcode::v_frexp_mant_f16, Definition(dst), src); + emit_vop1_instruction(ctx, instr, aco_opcode::v_frexp_mant_f16, dst); } else if (dst.regClass() == v1) { - bld.vop1(aco_opcode::v_frexp_mant_f32, Definition(dst), src); + emit_vop1_instruction(ctx, instr, aco_opcode::v_frexp_mant_f32, dst); } else if (dst.regClass() == v2) { - bld.vop1(aco_opcode::v_frexp_mant_f64, Definition(dst), src); + emit_vop1_instruction(ctx, instr, aco_opcode::v_frexp_mant_f64, dst); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); } break; } case nir_op_frexp_exp: { - Temp src = get_alu_src(ctx, instr->src[0]); if (instr->src[0].src.ssa->bit_size == 16) { + Temp src = get_alu_src(ctx, instr->src[0]); Temp tmp = bld.vop1(aco_opcode::v_frexp_exp_i16_f16, bld.def(v1), src); tmp = bld.pseudo(aco_opcode::p_extract_vector, bld.def(v1b), tmp, Operand(0u)); convert_int(ctx, bld, tmp, 8, 32, true, dst); } else if (instr->src[0].src.ssa->bit_size == 32) { - bld.vop1(aco_opcode::v_frexp_exp_i32_f32, Definition(dst), src); + emit_vop1_instruction(ctx, instr, aco_opcode::v_frexp_exp_i32_f32, dst); } else if (instr->src[0].src.ssa->bit_size == 64) { - bld.vop1(aco_opcode::v_frexp_exp_i32_f64, Definition(dst), src); + emit_vop1_instruction(ctx, instr, aco_opcode::v_frexp_exp_i32_f64, dst); } else { isel_err(&instr->instr, "Unimplemented NIR instr bit size"); }