From fc4b23130cc5710302a036f255359d750419d900 Mon Sep 17 00:00:00 2001 From: Georg Lehmann Date: Fri, 13 Sep 2024 18:55:13 +0200 Subject: [PATCH] aco/isel: add function to create builder for alu Reviewed-by: Rhys Perry Part-of: --- .../compiler/aco_instruction_selection.cpp | 110 ++++++++---------- 1 file changed, 51 insertions(+), 59 deletions(-) diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index 12bef5a9740..635609bfa00 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -116,6 +116,14 @@ get_ssa_temp(struct isel_context* ctx, nir_def* def) return Temp(id, ctx->program->temp_rc[id]); } +static Builder +create_alu_builder(isel_context* ctx, nir_alu_instr* instr) +{ + Builder bld(ctx->program, ctx->block); + bld.is_precise = instr->exact; + return bld; +} + Temp emit_mbcnt(isel_context* ctx, Temp dst, Operand mask = Operand(), Operand base = Operand::zero()) { @@ -794,26 +802,23 @@ void emit_sop2_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst, bool writes_scc, uint8_t uses_ub = 0) { - aco_ptr sop2{create_instruction(op, Format::SOP2, 2, writes_scc ? 2 : 1)}; - sop2->operands[0] = Operand(get_alu_src(ctx, instr->src[0])); - sop2->operands[1] = Operand(get_alu_src(ctx, instr->src[1])); - sop2->definitions[0] = Definition(dst); - if (instr->no_unsigned_wrap) - sop2->definitions[0].setNUW(true); - if (writes_scc) - sop2->definitions[1] = Definition(ctx->program->allocateId(s1), scc, s1); + Builder bld = create_alu_builder(ctx, instr); + bld.is_nuw = instr->no_unsigned_wrap; - for (int i = 0; i < 2; i++) { - if (uses_ub & (1 << i)) { - uint32_t src_ub = get_alu_src_ub(ctx, instr, i); - if (src_ub <= 0xffff) - sop2->operands[i].set16bit(true); - else if (src_ub <= 0xffffff) - sop2->operands[i].set24bit(true); - } + Operand operands[2] = {Operand(get_alu_src(ctx, instr->src[0])), + Operand(get_alu_src(ctx, instr->src[1]))}; + u_foreach_bit (i, uses_ub) { + uint32_t src_ub = get_alu_src_ub(ctx, instr, i); + if (src_ub <= 0xffff) + operands[i].set16bit(true); + else if (src_ub <= 0xffffff) + operands[i].set24bit(true); } - ctx->block->instructions.emplace_back(std::move(sop2)); + if (writes_scc) + bld.sop2(op, Definition(dst), bld.def(s1, scc), operands[0], operands[1]); + else + bld.sop2(op, Definition(dst), operands[0], operands[1]); } void @@ -821,54 +826,46 @@ emit_vop2_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode opc, T bool commutative, bool swap_srcs = false, bool flush_denorms = false, bool nuw = false, uint8_t uses_ub = 0) { - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); + bld.is_nuw = nuw; - Temp src0 = get_alu_src(ctx, instr->src[swap_srcs ? 1 : 0]); - Temp src1 = get_alu_src(ctx, instr->src[swap_srcs ? 0 : 1]); - if (src1.type() == RegType::sgpr) { - if (commutative && src0.type() == RegType::vgpr) { - Temp t = src0; - src0 = src1; - src1 = t; - } else { - src1 = as_vgpr(ctx, src1); - } + Operand operands[2] = {Operand(get_alu_src(ctx, instr->src[0])), + Operand(get_alu_src(ctx, instr->src[1]))}; + u_foreach_bit (i, uses_ub) { + uint32_t src_ub = get_alu_src_ub(ctx, instr, i); + if (src_ub <= 0xffff) + operands[i].set16bit(true); + else if (src_ub <= 0xffffff) + operands[i].set24bit(true); } - Operand op[2] = {Operand(src0), Operand(src1)}; + if (swap_srcs) + std::swap(operands[0], operands[1]); - for (int i = 0; i < 2; i++) { - if (uses_ub & (1 << i)) { - uint32_t src_ub = get_alu_src_ub(ctx, instr, swap_srcs ? !i : i); - if (src_ub <= 0xffff) - op[i].set16bit(true); - else if (src_ub <= 0xffffff) - op[i].set24bit(true); + if (operands[1].isOfType(RegType::sgpr)) { + if (commutative && operands[0].isOfType(RegType::vgpr)) { + std::swap(operands[0], operands[1]); + } else { + operands[1] = bld.copy(bld.def(RegType::vgpr, operands[1].size()), operands[1]); } } if (flush_denorms && ctx->program->gfx_level < GFX9) { assert(dst.size() == 1); - Temp tmp = bld.vop2(opc, bld.def(dst.regClass()), op[0], op[1]); + Temp tmp = bld.vop2(opc, bld.def(dst.regClass()), operands[0], operands[1]); if (dst.bytes() == 2) bld.vop2(aco_opcode::v_mul_f16, Definition(dst), Operand::c16(0x3c00), tmp); else bld.vop2(aco_opcode::v_mul_f32, Definition(dst), Operand::c32(0x3f800000u), tmp); } else { - if (nuw) { - bld.nuw().vop2(opc, Definition(dst), op[0], op[1]); - } else { - bld.vop2(opc, Definition(dst), op[0], op[1]); - } + bld.vop2(opc, Definition(dst), operands[0], operands[1]); } } void emit_vop2_instruction_logic64(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst) { - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); @@ -904,8 +901,7 @@ emit_vop3a_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, T has_sgpr = src[i].type() == RegType::sgpr; } - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); if (flush_denorms && ctx->program->gfx_level < GFX9) { Temp tmp; if (num_sources == 3) @@ -940,8 +936,7 @@ emit_vop3p_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, T unsigned opsel_hi = (instr->src[!swap_srcs].swizzle[1] & 1) << 1 | (instr->src[swap_srcs].swizzle[1] & 1); - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); Builder::Result res = bld.vop3p(op, Definition(dst), src0, src1, opsel_lo, opsel_hi); emit_split_vector(ctx, dst, 2); return res; @@ -961,8 +956,7 @@ emit_idot_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Te has_sgpr = src[i].type() == RegType::sgpr; } - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); VALU_instruction& vop3p = bld.vop3p(op, Definition(dst), src[0], src[1], src[2], 0x0, 0x7)->valu(); vop3p.clamp = clamp; @@ -972,8 +966,7 @@ emit_idot_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Te void emit_vop1_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Temp dst) { - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); if (dst.type() == RegType::sgpr) bld.pseudo(aco_opcode::p_as_uniform, Definition(dst), bld.vop1(op, bld.def(RegType::vgpr, dst.size()), get_alu_src(ctx, instr->src[0]))); @@ -1001,7 +994,7 @@ emit_vopc_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Te } } - Builder bld(ctx->program, ctx->block); + Builder bld = create_alu_builder(ctx, instr); bld.vopc(op, Definition(dst), src0, src1); } @@ -1010,7 +1003,7 @@ emit_sopc_instruction(isel_context* ctx, nir_alu_instr* instr, aco_opcode op, Te { Temp src0 = get_alu_src(ctx, instr->src[0]); Temp src1 = get_alu_src(ctx, instr->src[1]); - Builder bld(ctx->program, ctx->block); + Builder bld = create_alu_builder(ctx, instr); assert(dst.regClass() == bld.lm); assert(src0.type() == RegType::sgpr); @@ -1344,7 +1337,7 @@ usub32_sat(Builder& bld, Definition dst, Temp src0, Temp src1) void emit_vec2_f2f16(isel_context* ctx, nir_alu_instr* instr, Temp dst) { - Builder bld(ctx->program, ctx->block); + Builder bld = create_alu_builder(ctx, instr); Temp src = get_ssa_temp(ctx, instr->src[0].src.ssa); RegClass rc = RegClass(src.regClass().type(), instr->src[0].src.ssa->bit_size / 32); Temp src0 = emit_extract_vector(ctx, src, instr->src[0].swizzle[0], rc); @@ -1365,8 +1358,7 @@ emit_vec2_f2f16(isel_context* ctx, nir_alu_instr* instr, Temp dst) void visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) { - Builder bld(ctx->program, ctx->block); - bld.is_precise = instr->exact; + Builder bld = create_alu_builder(ctx, instr); Temp dst = get_ssa_temp(ctx, &instr->def); switch (instr->op) { case nir_op_vec2: @@ -1713,7 +1705,7 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr) emit_vop3p_instruction(ctx, instr, aco_opcode::v_pk_lshlrev_b16, dst, true); } else if (dst.regClass() == v1) { emit_vop2_instruction(ctx, instr, aco_opcode::v_lshlrev_b32, dst, false, true, false, - false, 2); + false, 1); } else if (dst.regClass() == v2 && ctx->program->gfx_level >= GFX8) { bld.vop3(aco_opcode::v_lshlrev_b64_e64, Definition(dst), get_alu_src(ctx, instr->src[1]), get_alu_src(ctx, instr->src[0]));