From 2f94353735b5ddfe2a72499e7bf6c7bbc80b9a00 Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Wed, 12 Aug 2020 14:35:15 +0100 Subject: [PATCH] aco: add p_extract/p_insert MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit These will let us make the SDWA optimizer much simpler than if we were to recognize combinations of shift/and/bfe. Signed-off-by: Rhys Perry Reviewed-by: Timur Kristóf Part-of: --- src/amd/compiler/aco_lower_to_hw_instr.cpp | 97 ++++++++++++++++++++++ src/amd/compiler/aco_opcodes.py | 8 ++ src/amd/compiler/aco_optimizer.cpp | 86 +++++++++++++++++-- src/amd/compiler/aco_validate.cpp | 23 +++++ 4 files changed, 207 insertions(+), 7 deletions(-) diff --git a/src/amd/compiler/aco_lower_to_hw_instr.cpp b/src/amd/compiler/aco_lower_to_hw_instr.cpp index c4b8120ed16..f9dfe0b2d29 100644 --- a/src/amd/compiler/aco_lower_to_hw_instr.cpp +++ b/src/amd/compiler/aco_lower_to_hw_instr.cpp @@ -1994,6 +1994,103 @@ void lower_to_hw_instr(Program* program) Operand(reg.advance(4), s1), Operand(0u), Operand(scc, s1)); break; } + case aco_opcode::p_extract: + { + assert(instr->operands[1].isConstant()); + assert(instr->operands[2].isConstant()); + assert(instr->operands[3].isConstant()); + if (instr->definitions[0].regClass() == s1) + assert(instr->definitions.size() >= 2 && instr->definitions[1].physReg() == scc); + Definition dst = instr->definitions[0]; + Operand op = instr->operands[0]; + unsigned bits = instr->operands[2].constantValue(); + unsigned index = instr->operands[1].constantValue(); + unsigned offset = index * bits; + bool signext = !instr->operands[3].constantEquals(0); + + if (dst.regClass() == s1) { + if (offset == (32 - bits)) { + bld.sop2(signext ? aco_opcode::s_ashr_i32 : aco_opcode::s_lshr_b32, + dst, bld.def(s1, scc), op, Operand(offset)); + } else if (offset == 0 && signext && (bits == 8 || bits == 16)) { + bld.sop1(bits == 8 ? aco_opcode::s_sext_i32_i8 : aco_opcode::s_sext_i32_i16, dst, op); + } else { + bld.sop2(signext ? aco_opcode::s_bfe_i32 : aco_opcode::s_bfe_u32, + dst, bld.def(s1, scc), op, Operand((bits << 16) | offset)); + } + } else if (dst.regClass() == v1 || ctx.program->chip_class <= GFX7) { + assert(op.physReg().byte() == 0 && dst.physReg().byte() == 0); + if (offset == (32 - bits) && op.regClass() != s1) { + bld.vop2(signext ? aco_opcode::v_ashrrev_i32 : aco_opcode::v_lshrrev_b32, + dst, Operand(offset), op); + } else { + bld.vop3(signext ? aco_opcode::v_bfe_i32 : aco_opcode::v_bfe_u32, + dst, op, Operand(offset), Operand(bits)); + } + } else if (dst.regClass() == v2b) { + aco_ptr sdwa{create_instruction( + aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)}; + sdwa->operands[0] = Operand(op.physReg().advance(-op.physReg().byte()), + RegClass::get(op.regClass().type(), 4)); + sdwa->definitions[0] = dst; + sdwa->sel[0] = sdwa_ubyte0 + op.physReg().byte() + index; + if (signext) + sdwa->sel[0] |= sdwa_sext; + sdwa->dst_sel = sdwa_uword; + bld.insert(std::move(sdwa)); + } + break; + } + case aco_opcode::p_insert: + { + assert(instr->operands[1].isConstant()); + assert(instr->operands[2].isConstant()); + if (instr->definitions[0].regClass() == s1) + assert(instr->definitions.size() >= 2 && instr->definitions[1].physReg() == scc); + Definition dst = instr->definitions[0]; + Operand op = instr->operands[0]; + unsigned bits = instr->operands[2].constantValue(); + unsigned index = instr->operands[1].constantValue(); + unsigned offset = index * bits; + + if (dst.regClass() == s1) { + if (offset == (32 - bits)) { + bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), op, Operand(offset)); + } else if (offset == 0) { + bld.sop2(aco_opcode::s_bfe_u32, dst, bld.def(s1, scc), op, Operand(bits << 16)); + } else { + bld.sop2(aco_opcode::s_bfe_u32, dst, bld.def(s1, scc), op, Operand(bits << 16)); + bld.sop2(aco_opcode::s_lshl_b32, dst, bld.def(s1, scc), Operand(dst.physReg(), s1), Operand(offset)); + } + } else if (dst.regClass() == v1 || ctx.program->chip_class <= GFX7) { + if (offset == (dst.bytes() * 8u - bits)) { + bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand(offset), op); + } else if (offset == 0) { + bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand(0u), Operand(bits)); + } else if (program->chip_class >= GFX9 || (op.regClass() != s1 && program->chip_class >= GFX8)) { + aco_ptr sdwa{create_instruction(aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)}; + sdwa->operands[0] = op; + sdwa->definitions[0] = dst; + sdwa->sel[0] = sdwa_udword; + sdwa->dst_sel = (bits == 8 ? sdwa_ubyte0 : sdwa_uword0) + (offset / bits); + bld.insert(std::move(sdwa)); + } else { + bld.vop3(aco_opcode::v_bfe_u32, dst, op, Operand(0u), Operand(bits)); + bld.vop2(aco_opcode::v_lshlrev_b32, dst, Operand(offset), Operand(dst.physReg(), v1)); + } + } else { + assert(dst.regClass() == v2b); + aco_ptr sdwa{create_instruction( + aco_opcode::v_mov_b32, (Format)((uint16_t)Format::VOP1|(uint16_t)Format::SDWA), 1, 1)}; + sdwa->operands[0] = op; + sdwa->definitions[0] = Definition(dst.physReg().advance(-dst.physReg().byte()), v1); + sdwa->sel[0] = sdwa_uword; + sdwa->dst_sel = sdwa_ubyte0 + dst.physReg().byte() + index; + sdwa->dst_preserve = 1; + bld.insert(std::move(sdwa)); + } + break; + } default: break; } diff --git a/src/amd/compiler/aco_opcodes.py b/src/amd/compiler/aco_opcodes.py index a28f1d5d765..07ac9cf104c 100644 --- a/src/amd/compiler/aco_opcodes.py +++ b/src/amd/compiler/aco_opcodes.py @@ -320,6 +320,14 @@ opcode("p_bpermute") opcode("p_constaddr") +# These don't have to be pseudo-ops, but it makes optimization easier to only +# have to consider two instructions. +# (src0 >> (index * bits)) & ((1 << bits) - 1) with optional sign extension +opcode("p_extract") # src1=index, src2=bits, src3=signext +# (src0 & ((1 << bits) - 1)) << (index * bits) +opcode("p_insert") # src1=index, src2=bits + + # SOP2 instructions: 2 scalar inputs, 1 scalar output (+optional scc) SOP2 = { # GFX6, GFX7, GFX8, GFX9, GFX10, name diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 42f50cae8fc..b1fadd33c31 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -763,6 +763,8 @@ bool alu_can_accept_constant(aco_opcode opcode, unsigned operand) case aco_opcode::v_readlane_b32: case aco_opcode::v_readlane_b32_e64: case aco_opcode::v_readfirstlane_b32: + case aco_opcode::p_extract: + case aco_opcode::p_insert: return operand != 0; default: return true; @@ -1610,6 +1612,16 @@ void label_instruction(opt_ctx &ctx, aco_ptr& instr) if (instr->operands[0].constantEquals(0x3f800000u)) ctx.info[instr->definitions[0].tempId()].set_canonicalized(); break; + case aco_opcode::p_extract: { + if (instr->operands[0].isTemp()) + ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get()); + break; + } + case aco_opcode::p_insert: { + if (instr->operands[0].isTemp()) + ctx.info[instr->definitions[0].tempId()].set_bitwise(instr.get()); + break; + } default: break; } @@ -2210,6 +2222,70 @@ bool combine_three_valu_op(opt_ctx& ctx, aco_ptr& instr, aco_opcode return false; } +/* creates v_lshl_add_u32, v_lshl_or_b32 or v_and_or_b32 */ +bool combine_add_or_then_and_lshl(opt_ctx& ctx, aco_ptr& instr) +{ + bool is_or = instr->opcode == aco_opcode::v_or_b32; + aco_opcode new_op_lshl = is_or ? aco_opcode::v_lshl_or_b32 : aco_opcode::v_lshl_add_u32; + + if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) + return true; + if (is_or && combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) + return true; + if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, new_op_lshl, "120", 1 | 2)) + return true; + if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, new_op_lshl, "210", 1 | 2)) + return true; + + if (instr->isSDWA()) + return false; + + /* v_or_b32(p_extract(a, 0, 8/16, 0), b) -> v_and_or_b32(a, 0xff/0xffff, b) + * v_or_b32(p_insert(a, 0, 8/16), b) -> v_and_or_b32(a, 0xff/0xffff, b) + * v_or_b32(p_insert(a, 24/16, 8/16), b) -> v_lshl_or_b32(a, 24/16, b) + * v_add_u32(p_insert(a, 24/16, 8/16), b) -> v_lshl_add_b32(a, 24/16, b) + */ + for (unsigned i = 0; i < 2; i++) { + Instruction *extins = follow_operand(ctx, instr->operands[i]); + if (!extins) + continue; + + aco_opcode op; + Operand operands[3]; + + if (extins->opcode == aco_opcode::p_insert && + (extins->operands[1].constantValue() + 1) * extins->operands[2].constantValue() == 32) { + op = new_op_lshl; + operands[1] = Operand(extins->operands[1].constantValue() * extins->operands[2].constantValue()); + } else if (is_or && (extins->opcode == aco_opcode::p_insert || + (extins->opcode == aco_opcode::p_extract && extins->operands[3].constantEquals(0))) && + extins->operands[1].constantEquals(0)) { + op = aco_opcode::v_and_or_b32; + operands[1] = Operand(extins->operands[2].constantEquals(8) ? 0xffu : 0xffffu); + } else { + continue; + } + + operands[0] = extins->operands[0]; + operands[2] = instr->operands[!i]; + + if (!check_vop3_operands(ctx, 3, operands)) + continue; + + bool neg[3] = {}, abs[3] = {}; + uint8_t opsel = 0, omod = 0; + bool clamp = false; + if (instr->isVOP3()) + clamp = instr->vop3().clamp; + + ctx.uses[instr->operands[i].tempId()]--; + create_vop3_for_op3(ctx, op, instr, operands, neg, abs, opsel, clamp, omod); + return true; + } + + return false; +} + bool combine_minmax(opt_ctx& ctx, aco_ptr& instr, aco_opcode opposite, aco_opcode minmax3) { if (combine_three_valu_op(ctx, instr, instr->opcode, minmax3, "012", 1 | 2)) @@ -3198,10 +3274,7 @@ void combine_instruction(opt_ctx &ctx, aco_ptr& instr) } else if (instr->opcode == aco_opcode::v_or_b32 && ctx.program->chip_class >= GFX9) { if (combine_three_valu_op(ctx, instr, aco_opcode::s_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ; else if (combine_three_valu_op(ctx, instr, aco_opcode::v_or_b32, aco_opcode::v_or3_b32, "012", 1 | 2)) ; - else if (combine_three_valu_op(ctx, instr, aco_opcode::s_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ; - else if (combine_three_valu_op(ctx, instr, aco_opcode::v_and_b32, aco_opcode::v_and_or_b32, "120", 1 | 2)) ; - else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_or_b32, "120", 1 | 2)) ; - else combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_or_b32, "210", 1 | 2); + else combine_add_or_then_and_lshl(ctx, instr) ; } else if (instr->opcode == aco_opcode::v_xor_b32 && ctx.program->chip_class >= GFX10) { if (combine_three_valu_op(ctx, instr, aco_opcode::v_xor_b32, aco_opcode::v_xor3_b32, "012", 1 | 2)) ; else combine_three_valu_op(ctx, instr, aco_opcode::s_xor_b32, aco_opcode::v_xor3_b32, "012", 1 | 2); @@ -3215,9 +3288,8 @@ void combine_instruction(opt_ctx &ctx, aco_ptr& instr) else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_i32, aco_opcode::v_add3_u32, "012", 1 | 2)) ; else if (combine_three_valu_op(ctx, instr, aco_opcode::s_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ; else if (combine_three_valu_op(ctx, instr, aco_opcode::v_add_u32, aco_opcode::v_add3_u32, "012", 1 | 2)) ; - else if (combine_three_valu_op(ctx, instr, aco_opcode::s_lshl_b32, aco_opcode::v_lshl_add_u32, "120", 1 | 2)) ; - else if (combine_three_valu_op(ctx, instr, aco_opcode::v_lshlrev_b32, aco_opcode::v_lshl_add_u32, "210", 1 | 2)) ; - else combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16, aco_opcode::v_mad_u32_u16, "120", 1 | 2) ; + else if (combine_three_valu_op(ctx, instr, aco_opcode::v_mul_lo_u16, aco_opcode::v_mad_u32_u16, "120", 1 | 2)) ; + else combine_add_or_then_and_lshl(ctx, instr) ; } } else if (instr->opcode == aco_opcode::v_add_co_u32 || instr->opcode == aco_opcode::v_add_co_u32_e64) { diff --git a/src/amd/compiler/aco_validate.cpp b/src/amd/compiler/aco_validate.cpp index f5ba8ab4958..51ca2a35ae1 100644 --- a/src/amd/compiler/aco_validate.cpp +++ b/src/amd/compiler/aco_validate.cpp @@ -376,6 +376,29 @@ bool validate_ir(Program* program) check(instr->definitions[0].size() == op.size(), "Operand sizes must match Definition size", instr.get()); } check(instr->operands.size() == block.linear_preds.size(), "Number of Operands does not match number of predecessors", instr.get()); + } else if (instr->opcode == aco_opcode::p_extract || instr->opcode == aco_opcode::p_insert) { + check(instr->operands[0].isTemp(), + "Data operand must be temporary", instr.get()); + check(instr->operands[1].isConstant(), "Index must be constant", instr.get()); + if (instr->opcode == aco_opcode::p_extract) + check(instr->operands[3].isConstant(), "Sign-extend flag must be constant", instr.get()); + + check(instr->definitions[0].getTemp().type() != RegType::sgpr || + instr->operands[0].getTemp().type() == RegType::sgpr, + "Can't extract/insert VGPR to SGPR", instr.get()); + + if (instr->operands[0].getTemp().type() == RegType::vgpr) + check(instr->operands[0].bytes() == instr->definitions[0].bytes(), + "Sizes of operand and definition must match", instr.get()); + + if (instr->definitions[0].getTemp().type() == RegType::sgpr) + check(instr->definitions.size() >= 2 && instr->definitions[1].isFixed() && instr->definitions[1].physReg() == scc, "SGPR extract/insert needs a SCC definition", instr.get()); + + check(instr->operands[2].constantEquals(8) || instr->operands[2].constantEquals(16), "Size must be 8 or 16", instr.get()); + check(instr->operands[2].constantValue() < instr->operands[0].getTemp().bytes() * 8u, "Size must be smaller than source", instr.get()); + + unsigned comp = instr->operands[0].bytes() * 8u / MAX2(instr->operands[2].constantValue(), 1); + check(instr->operands[1].constantValue() < comp, "Index must be in-bounds", instr.get()); } break; }