mesa/src/amd/compiler/aco_builder_h.py
Natalie Vock 6d799ac283 aco: Add pass for spilling call-related registers
This is a post-RA pass that tracks registers that are preserved by the
ABI, but clobbered by shader code. The pass inserts scratch spills and
reloads in appropriate locations to ensure the register values at the
end of the shader are the same as they were at the start.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38281>
2025-12-08 19:12:55 +00:00

696 lines
23 KiB
Python

template = """\
/*
* Copyright (c) 2019 Valve Corporation
*
* SPDX-License-Identifier: MIT
*
* This file was generated by aco_builder_h.py
*/
#ifndef _ACO_BUILDER_
#define _ACO_BUILDER_
#include "aco_ir.h"
namespace aco {
## Encodings for the DPP control field of VALU DPP instructions.
## Enumerators prefixed with '_' are base values that still need a lane
## selection / shift amount / mask OR'd into their low bits; build those
## with the dpp_quad_perm()/dpp_row_*() helpers below.
enum dpp_ctrl {
_dpp_quad_perm = 0x000,
_dpp_row_sl = 0x100,
_dpp_row_sr = 0x110,
_dpp_row_rr = 0x120,
dpp_wf_sl1 = 0x130,
dpp_wf_rl1 = 0x134,
dpp_wf_sr1 = 0x138,
dpp_wf_rr1 = 0x13C,
dpp_row_mirror = 0x140,
dpp_row_half_mirror = 0x141,
dpp_row_bcast15 = 0x142,
dpp_row_bcast31 = 0x143,
_dpp_row_share = 0x150,
_dpp_row_xmask = 0x160,
};
## quad_perm control: each 2-bit field selects which lane of its quad
## each of lanes 0-3 reads from.
inline dpp_ctrl
dpp_quad_perm(unsigned lane0, unsigned lane1, unsigned lane2, unsigned lane3)
{
assert(lane0 < 4 && lane1 < 4 && lane2 < 4 && lane3 < 4);
return (dpp_ctrl)(lane0 | (lane1 << 2) | (lane2 << 4) | (lane3 << 6));
}
## Row shift-left control, amount in 1..15 lanes.
inline dpp_ctrl
dpp_row_sl(unsigned amount)
{
assert(amount > 0 && amount < 16);
return (dpp_ctrl)(((unsigned) _dpp_row_sl) | amount);
}
## Row shift-right control, amount in 1..15 lanes.
inline dpp_ctrl
dpp_row_sr(unsigned amount)
{
assert(amount > 0 && amount < 16);
return (dpp_ctrl)(((unsigned) _dpp_row_sr) | amount);
}
## Row rotate-right control, amount in 1..15 lanes.
inline dpp_ctrl
dpp_row_rr(unsigned amount)
{
assert(amount > 0 && amount < 16);
return (dpp_ctrl)(((unsigned) _dpp_row_rr) | amount);
}
## row_share control with a source lane index in 0..15.
## NOTE(review): presumably all lanes of a row read the given lane of
## that row -- confirm against the ISA manual.
inline dpp_ctrl
dpp_row_share(unsigned lane)
{
assert(lane < 16);
return (dpp_ctrl)(((unsigned) _dpp_row_share) | lane);
}
## row_xmask control with a 4-bit mask.
## NOTE(review): presumably each lane reads lane (id ^ mask) within its
## row -- semantics inferred from the name, confirm against the ISA manual.
inline dpp_ctrl
dpp_row_xmask(unsigned mask)
{
assert(mask < 16);
return (dpp_ctrl)(((unsigned) _dpp_row_xmask) | mask);
}
## Packs an AND/OR/XOR lane-swizzle pattern ("bit mode"):
## bits [4:0] = AND mask, bits [9:5] = OR mask, bits [14:10] = XOR mask.
inline unsigned
ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
{
assert(and_mask < 32 && or_mask < 32 && xor_mask < 32);
return and_mask | (or_mask << 5) | (xor_mask << 10);
}
## Packs a rotate pattern: bits [4:0] = mask, bits [9:5] = rotate delta,
## plus the 0xc000 tag.  NOTE(review): 0xc000 presumably selects the
## rotate mode -- confirm against the ds_swizzle ISA documentation.
inline unsigned
ds_pattern_rotate(unsigned delta, unsigned mask)
{
assert(delta < 32 && mask < 32);
return mask | (delta << 5) | 0xc000;
}
## Declared here, implemented elsewhere: builds a single s_mov
## instruction and returns it as an owning pointer (it is not inserted
## into any instruction list).
aco_ptr<Instruction> create_s_mov(Definition dst, Operand src);
## Immediate message IDs for s_sendmsg.  Several numeric IDs are reused
## between hardware generations; the per-entry comments list the gfx
## levels on which each name is valid.  sendmsg_id_mask extracts the
## message-ID field.
enum sendmsg {
sendmsg_none = 0,
sendmsg_gs = 2, /* gfx6 to gfx10.3 */
sendmsg_gs_done = 3, /* gfx6 to gfx10.3 */
sendmsg_hs_tessfactor = 2, /* gfx11+ */
sendmsg_dealloc_vgprs = 3, /* gfx11+ */
sendmsg_save_wave = 4, /* gfx8 to gfx10.3 */
sendmsg_stall_wave_gen = 5, /* gfx9+ */
sendmsg_halt_waves = 6, /* gfx9+ */
sendmsg_ordered_ps_done = 7, /* gfx9+ */
sendmsg_early_prim_dealloc = 8, /* gfx9 to gfx10 */
sendmsg_gs_alloc_req = 9, /* gfx9+ */
sendmsg_get_doorbell = 10, /* gfx9 to gfx10.3 */
sendmsg_get_ddid = 11, /* gfx10 to gfx10.3 */
sendmsg_id_mask = 0xf,
};
/* gfx11+ */
## Message IDs for the gfx11+ returning s_sendmsg variants (per the
## comment above).
enum sendmsg_rtn {
sendmsg_rtn_get_doorbell = 0,
sendmsg_rtn_get_ddid = 1,
sendmsg_rtn_get_tma = 2,
sendmsg_rtn_get_realtime = 3,
sendmsg_rtn_save_wave = 4,
sendmsg_rtn_get_tba = 5,
sendmsg_rtn_mask = 0xff,
};
## Byte-selector values for byte-permute swizzles.  NOTE(review): the
## names suggest selectors 8-11 replicate the sign bit of source bytes
## 1/3/5/7 and 12/13 produce the constants 0x00/0xff (the v_perm_b32
## selector encoding) -- inferred from the names, confirm in the ISA manual.
enum bperm_swiz {
bperm_b1_sign = 8,
bperm_b3_sign = 9,
bperm_b5_sign = 10,
bperm_b7_sign = 11,
bperm_0 = 12,
bperm_255 = 13,
};
## Dependency/cycle wait encodings, presumably for the gfx11+
## s_delay_alu instruction's delay operand -- TODO confirm.
enum class alu_delay_wait {
NO_DEP = 0,
VALU_DEP_1 = 1,
VALU_DEP_2 = 2,
VALU_DEP_3 = 3,
VALU_DEP_4 = 4,
TRANS32_DEP_1 = 5,
TRANS32_DEP_2 = 6,
TRANS32_DEP_3 = 7,
FMA_ACCUM_CYCLE_1 = 8,
SALU_CYCLE_1 = 9,
SALU_CYCLE_2 = 10,
SALU_CYCLE_3 = 11,
};
## Convenience IR builder: one method per instruction format/shape that
## allocates an instruction and inserts it at a configurable position in
## an instruction list.
class Builder {
public:
## Wraps the Instruction* produced by every builder call.  Implicit
## conversions to Instruction*, to the first definition's Temp, and to
## an Operand of that Temp allow builder calls to be nested directly as
## operands of other builder calls.
struct Result {
Instruction *instr;
Result(Instruction *instr_) : instr(instr_) {}
operator Instruction *() const {
return instr;
}
operator Temp() const {
return instr->definitions[0].getTemp();
}
operator Operand() const {
return Operand((Temp)*this);
}
Definition& def(unsigned index) const {
return instr->definitions[index];
}
## NOTE(review): this wraps the raw pointer in a fresh owning aco_ptr;
## if the instruction was already inserted into a list (which stores
## its own aco_ptr), that double-owns it -- use only on results that
## were never inserted.
aco_ptr<Instruction> get_ptr() const {
return aco_ptr<Instruction>(instr);
}
Instruction * operator * () const {
return instr;
}
Instruction * operator -> () const {
return instr;
}
};
## Implicit adapter so builder methods accept a Temp, an Operand, or a
## previous Result anywhere an operand is expected.
struct Op {
Operand op;
Op(Temp tmp) : op(tmp) {}
Op(Operand op_) : op(op_) {}
Op(Result res) : op((Temp)res) {}
};
## Wave-size-agnostic scalar opcodes.  The enumerator values are the
## wave64 (b64/u64) opcodes; w32() maps each one to its b32/u32
## counterpart and w64or32() selects based on the program's wave size.
enum WaveSpecificOpcode {
s_cselect = (unsigned) aco_opcode::s_cselect_b64,
s_cmp_lg = (unsigned) aco_opcode::s_cmp_lg_u64,
s_and = (unsigned) aco_opcode::s_and_b64,
s_andn2 = (unsigned) aco_opcode::s_andn2_b64,
s_or = (unsigned) aco_opcode::s_or_b64,
s_orn2 = (unsigned) aco_opcode::s_orn2_b64,
s_not = (unsigned) aco_opcode::s_not_b64,
s_mov = (unsigned) aco_opcode::s_mov_b64,
s_wqm = (unsigned) aco_opcode::s_wqm_b64,
s_and_saveexec = (unsigned) aco_opcode::s_and_saveexec_b64,
s_or_saveexec = (unsigned) aco_opcode::s_or_saveexec_b64,
s_andn2_wrexec = (unsigned) aco_opcode::s_andn2_wrexec_b64,
s_xnor = (unsigned) aco_opcode::s_xnor_b64,
s_xor = (unsigned) aco_opcode::s_xor_b64,
s_bcnt1_i32 = (unsigned) aco_opcode::s_bcnt1_i32_b64,
s_bitcmp1 = (unsigned) aco_opcode::s_bitcmp1_b64,
s_ff1_i32 = (unsigned) aco_opcode::s_ff1_i32_b64,
s_flbit_i32 = (unsigned) aco_opcode::s_flbit_i32_b64,
s_lshl = (unsigned) aco_opcode::s_lshl_b64,
};
## Program being compiled; consulted for gfx_level/wave_size and used to
## allocate temporaries.
Program *program;
## Insertion mode: with use_iterator set, new instructions go before
## 'it'; otherwise they are appended at the end, or prepended when
## 'start' is set.
bool use_iterator;
bool start; // only when use_iterator == false
## Lane-mask register class, taken from program->lane_mask (defaults to
## s2 when constructed without a program).
RegClass lm;
## Target instruction list; may be NULL, in which case built
## instructions are not stored anywhere (see insert()).
std::vector<aco_ptr<Instruction>> *instructions;
std::vector<aco_ptr<Instruction>>::iterator it;
## Modifier flags copied onto every definition of every built
## instruction (see the generated methods below).
bool is_precise = false;
bool is_sz_preserve = false;
bool is_inf_preserve = false;
bool is_nan_preserve = false;
bool is_nuw = false;
## Constructors: no target list (only temp allocation works), a block's
## instruction list, or an explicit list; all start in append-at-end mode.
Builder(Program *pgm) : program(pgm), use_iterator(false), start(false), lm(pgm ? pgm->lane_mask : s2), instructions(NULL) {}
Builder(Program *pgm, Block *block) : program(pgm), use_iterator(false), start(false), lm(pgm ? pgm->lane_mask : s2), instructions(&block->instructions) {}
Builder(Program *pgm, std::vector<aco_ptr<Instruction>> *instrs) : program(pgm), use_iterator(false), start(false), lm(pgm ? pgm->lane_mask : s2), instructions(instrs) {}
## Returns a copy of this builder whose instructions' definitions are
## marked 'precise'.
Builder precise() const {
Builder res = *this;
res.is_precise = true;
return res;
};
## Returns a copy of this builder whose instructions' definitions are
## marked no-unsigned-wrap.
Builder nuw() const {
Builder res = *this;
res.is_nuw = true;
return res;
}
## Re-targets insertion to the given block's instruction list (appending
## at the end, since use_iterator/start are left untouched).
void moveEnd(Block *block) {
instructions = &block->instructions;
}
## The reset() overloads clear or re-target the insertion point.  All of
## them return to append-at-end mode, except the iterator overload which
## makes subsequent inserts go before 'instr_it'.
void reset() {
use_iterator = false;
start = false;
instructions = NULL;
}
void reset(Block *block) {
use_iterator = false;
start = false;
instructions = &block->instructions;
}
void reset(std::vector<aco_ptr<Instruction>> *instrs) {
use_iterator = false;
start = false;
instructions = instrs;
}
void reset(std::vector<aco_ptr<Instruction>> *instrs, std::vector<aco_ptr<Instruction>>::iterator instr_it) {
use_iterator = true;
start = false;
instructions = instrs;
it = instr_it;
}
## Inserts an owned instruction at the current insertion position
## (before 'it', at the end, or at the front) and returns a Result
## wrapping the raw pointer; after an iterator insert, 'it' is advanced
## past the new instruction.
## NOTE(review): with a NULL 'instructions' list the aco_ptr destroys
## the instruction on return and the Result dangles -- callers using a
## list-less Builder must not use the Result.
Result insert(aco_ptr<Instruction> instr) {
Instruction *instr_ptr = instr.get();
if (instructions) {
if (use_iterator) {
it = instructions->emplace(it, std::move(instr));
it = std::next(it);
} else if (!start) {
instructions->emplace_back(std::move(instr));
} else {
instructions->emplace(instructions->begin(), std::move(instr));
}
}
return Result(instr_ptr);
}
## Same, but takes ownership of a raw pointer (with a NULL list the
## pointer is simply not stored).
Result insert(Instruction* instr) {
if (instructions) {
if (use_iterator) {
it = instructions->emplace(it, aco_ptr<Instruction>(instr));
it = std::next(it);
} else if (!start) {
instructions->emplace_back(aco_ptr<Instruction>(instr));
} else {
instructions->emplace(instructions->begin(), aco_ptr<Instruction>(instr));
}
}
return Result(instr);
}
## Allocates a fresh temporary of the given register class (or
## type+size).
Temp tmp(RegClass rc) {
return program->allocateTmp(rc);
}
Temp tmp(RegType type, unsigned size) {
return tmp(RegClass(type, size));
}
## Creates a Definition of a fresh temporary, optionally bound to a
## specific physical register.
Definition def(RegClass rc) {
return Definition(program->allocateTmp(rc));
}
Definition def(RegType type, unsigned size) {
return def(RegClass(type, size));
}
Definition def(RegClass rc, PhysReg reg) {
return Definition(tmp(rc), reg);
}
## Wave64 form of a wave-specific opcode: the enumerator value itself.
inline aco_opcode w64(WaveSpecificOpcode opcode) const {
return (aco_opcode) opcode;
}
## Maps a wave-specific opcode to its wave32 (b32/u32) form.
inline aco_opcode w32(WaveSpecificOpcode opcode) const {
switch (opcode) {
case s_cselect:
return aco_opcode::s_cselect_b32;
case s_cmp_lg:
return aco_opcode::s_cmp_lg_u32;
case s_and:
return aco_opcode::s_and_b32;
case s_andn2:
return aco_opcode::s_andn2_b32;
case s_or:
return aco_opcode::s_or_b32;
case s_orn2:
return aco_opcode::s_orn2_b32;
case s_not:
return aco_opcode::s_not_b32;
case s_mov:
return aco_opcode::s_mov_b32;
case s_wqm:
return aco_opcode::s_wqm_b32;
case s_and_saveexec:
return aco_opcode::s_and_saveexec_b32;
case s_or_saveexec:
return aco_opcode::s_or_saveexec_b32;
case s_andn2_wrexec:
return aco_opcode::s_andn2_wrexec_b32;
case s_xnor:
return aco_opcode::s_xnor_b32;
case s_xor:
return aco_opcode::s_xor_b32;
case s_bcnt1_i32:
return aco_opcode::s_bcnt1_i32_b32;
case s_bitcmp1:
return aco_opcode::s_bitcmp1_b32;
case s_ff1_i32:
return aco_opcode::s_ff1_i32_b32;
case s_flbit_i32:
return aco_opcode::s_flbit_i32_b32;
case s_lshl:
return aco_opcode::s_lshl_b32;
default:
UNREACHABLE("Unsupported wave specific opcode.");
}
}
## Selects the b64 or b32 form to match the program's wave size.
inline aco_opcode w64or32(WaveSpecificOpcode opcode) const {
if (program->wave_size == 64)
return w64(opcode);
else
return w32(opcode);
}
## Mako: generate m0()/vcc()/exec()/scc() helpers that pin an Operand or
## a Definition to the corresponding fixed register via setPrecolored().
% for fixed in ['m0', 'vcc', 'exec', 'scc']:
Operand ${fixed}(Temp tmp) {
% if fixed == 'vcc' or fixed == 'exec':
//vcc_hi and exec_hi can still be used in wave32
assert(tmp.type() == RegType::sgpr && tmp.bytes() <= 8);
% endif
Operand op(tmp);
op.setPrecolored(aco::${fixed});
return op;
}
Definition ${fixed}(Definition def) {
% if fixed == 'vcc' or fixed == 'exec':
//vcc_hi and exec_hi can still be used in wave32
assert(def.regClass().type() == RegType::sgpr && def.bytes() <= 8);
% endif
def.setPrecolored(aco::${fixed});
return def;
}
% endfor
## Flags an operand as a 16-bit (resp. 24-bit) value.
Operand set16bit(Operand op) {
op.set16bit(true);
return op;
}
Operand set24bit(Operand op) {
op.set24bit(true);
return op;
}
/* hand-written helpers */
## Returns the value of 'op' as an SGPR temporary: VGPR temporaries are
## converted through a p_as_uniform pseudo-instruction, SGPR temporaries
## are passed through unchanged.  'op' must hold a Temp.
Temp as_uniform(Op op)
{
assert(op.op.isTemp());
if (op.op.getTemp().type() == RegType::vgpr)
return pseudo(aco_opcode::p_as_uniform, def(RegType::sgpr, op.op.size()), op);
else
return op.op.getTemp();
}
## Multiplies the VGPR 'tmp' by the constant 'imm', choosing the
## cheapest available sequence: copies for 0/1, a subtract for -1,
## a shift for powers of two, 24-bit multiplies when precision allows,
## shift+add/sub for 2^n +/- 1, a per-set-bit shift/add decomposition
## when it beats v_mul_lo_u32's estimated cost, and v_mul_lo_u32 as the
## fallback.  tmpu24/tmpi24 assert that 'tmp' fits in 24 bits
## (unsigned/signed respectively).
Result v_mul_imm(Definition dst, Temp tmp, uint32_t imm, bool tmpu24=false, bool tmpi24=false)
{
assert(tmp.type() == RegType::vgpr);
/* Assume 24bit if high 8 bits of tmp don't impact the result. */
if ((imm & 0xff) == 0) {
tmpu24 = true;
tmpi24 = true;
}
## The 24-bit paths also require imm itself to fit in u24 / i24.
tmpu24 &= imm <= 0xffffffu;
tmpi24 &= imm <= 0x7fffffu || imm >= 0xff800000u;
bool has_lshl_add = program->gfx_level >= GFX9;
/* v_mul_lo_u32 has 1.6x the latency of most VALU on GFX10 (8 vs 5 cycles),
* compared to 4x the latency on <GFX10. */
unsigned mul_cost = program->gfx_level >= GFX10 ? 1 : (4 + Operand::c32(imm).isLiteral());
if (imm == 0) {
return copy(dst, Operand::zero());
} else if (imm == 1) {
return copy(dst, Operand(tmp));
} else if (imm == 0xffffffff) {
## x * -1 == 0 - x (mod 2^32)
return vsub32(dst, Operand::zero(), tmp);
} else if (util_is_power_of_two_or_zero(imm)) {
return vop2(aco_opcode::v_lshlrev_b32, dst, Operand::c32(ffs(imm) - 1u), tmp);
} else if (tmpu24) {
return vop2(aco_opcode::v_mul_u32_u24, dst, Operand::c32(imm), tmp);
} else if (tmpi24) {
return vop2(aco_opcode::v_mul_i32_i24, dst, Operand::c32(imm), tmp);
} else if (util_is_power_of_two_nonzero(imm - 1u)) {
## x * (2^n + 1) == (x << n) + x
return vadd32(dst, vop2(aco_opcode::v_lshlrev_b32, def(v1), Operand::c32(ffs(imm - 1u) - 1u), tmp), tmp);
} else if (mul_cost > 2 && util_is_power_of_two_nonzero(imm + 1u)) {
## x * (2^n - 1) == (x << n) - x
return vsub32(dst, vop2(aco_opcode::v_lshlrev_b32, def(v1), Operand::c32(ffs(imm + 1u) - 1u), tmp), tmp);
}
## Shift/add decomposition: one shifted copy of 'tmp' per set bit of
## imm, all summed.  Without a fused lshl_add, shifts and additions are
## counted as separate instructions when comparing against mul_cost.
unsigned instrs_required = util_bitcount(imm);
if (!has_lshl_add) {
instrs_required = util_bitcount(imm) - (imm & 0x1); /* shifts */
instrs_required += util_bitcount(imm) - 1; /* additions */
}
if (instrs_required < mul_cost) {
Result res(NULL);
Temp cur;
while (imm) {
unsigned shift = u_bit_scan(&imm);
## The last partial sum is written to the real destination.
Definition tmp_dst = imm ? def(v1) : dst;
if (shift && cur.id())
res = vadd32(Definition(tmp_dst), vop2(aco_opcode::v_lshlrev_b32, def(v1), Operand::c32(shift), tmp), cur);
else if (shift)
res = vop2(aco_opcode::v_lshlrev_b32, Definition(tmp_dst), Operand::c32(shift), tmp);
else if (cur.id())
res = vadd32(Definition(tmp_dst), tmp, cur);
else
tmp_dst = Definition(tmp);
cur = tmp_dst.getTemp();
}
return res;
}
## Fallback: full 32-bit multiply with the constant held in an SGPR.
Temp imm_tmp = copy(def(s1), Operand::c32(imm));
return vop3(aco_opcode::v_mul_lo_u32, dst, imm_tmp, tmp);
}
## 24-bit multiply: masks the constant to 24 bits and takes the
## unsigned-24-bit fast path of v_mul_imm.
Result v_mul24_imm(Definition dst, Temp tmp, uint32_t imm)
{
return v_mul_imm(dst, tmp, imm & 0xffffffu, true);
}
## Copies op into dst via a p_parallelcopy pseudo-instruction.
Result copy(Definition dst, Op op) {
return pseudo(aco_opcode::p_parallelcopy, dst, op);
}
## 32-bit add with optional carry-in/carry-out.  Operands are commuted
## so the VGPR source (if any) ends up second; before RA, a remaining
## constant/SGPR second source is first copied into a VGPR.  Selects
## v_addc_co (carry-in present), the VOP3 v_add_co_u32_e64 (GFX10+
## carry-out), v_add_co (pre-GFX9 or carry-out), or plain v_add.
Result vadd32(Definition dst, Op a, Op b, bool carry_out=false, Op carry_in=Op(Operand(s2)), bool post_ra=false) {
if (b.op.isConstant() || b.op.regClass().type() != RegType::vgpr)
std::swap(a, b);
if (!post_ra && (!b.op.hasRegClass() || b.op.regClass().type() == RegType::sgpr))
b = copy(def(v1), b);
if (!carry_in.op.isUndefined())
return vop2(aco_opcode::v_addc_co_u32, Definition(dst), def(lm), a, b, carry_in);
else if (program->gfx_level >= GFX10 && carry_out)
return vop3(aco_opcode::v_add_co_u32_e64, Definition(dst), def(lm), a, b);
else if (program->gfx_level < GFX9 || carry_out)
return vop2(aco_opcode::v_add_co_u32, Definition(dst), def(lm), a, b);
else
return vop2(aco_opcode::v_add_u32, Definition(dst), a, b);
}
## 32-bit subtract a - b with optional borrow-in ('borrow') and
## carry-out.  Pre-GFX9, or whenever a borrow-in is present, the
## carry-out form is mandatory.  When b is not a VGPR temp, the operands
## are swapped and the *rev (reversed) opcode variants are used so the
## subtraction order is preserved; GFX10+ carry-out subtracts use the
## VOP3 _e64 encodings.
Result vsub32(Definition dst, Op a, Op b, bool carry_out=false, Op borrow=Op(Operand(s2)))
{
if (!borrow.op.isUndefined() || program->gfx_level < GFX9)
carry_out = true;
## If b cannot be the second source, swap and use a reversed opcode.
bool reverse = !b.op.isTemp() || b.op.regClass().type() != RegType::vgpr;
if (reverse)
std::swap(a, b);
if (!b.op.hasRegClass() || b.op.regClass().type() == RegType::sgpr)
b = copy(def(v1), b);
aco_opcode op;
Temp carry;
if (carry_out) {
carry = tmp(lm);
if (borrow.op.isUndefined())
op = reverse ? aco_opcode::v_subrev_co_u32 : aco_opcode::v_sub_co_u32;
else
op = reverse ? aco_opcode::v_subbrev_co_u32 : aco_opcode::v_subb_co_u32;
} else {
op = reverse ? aco_opcode::v_subrev_u32 : aco_opcode::v_sub_u32;
}
bool vop3 = false;
if (program->gfx_level >= GFX10 && op == aco_opcode::v_subrev_co_u32) {
vop3 = true;
op = aco_opcode::v_subrev_co_u32_e64;
} else if (program->gfx_level >= GFX10 && op == aco_opcode::v_sub_co_u32) {
vop3 = true;
op = aco_opcode::v_sub_co_u32_e64;
}
## Assemble the VOP2/VOP3 instruction by hand: 2 or 3 operands
## (borrow-in adds one), 1 or 2 definitions (carry-out adds one).
int num_ops = borrow.op.isUndefined() ? 2 : 3;
int num_defs = carry_out ? 2 : 1;
aco_ptr<Instruction> sub;
if (vop3)
sub.reset(create_instruction(op, Format::VOP3, num_ops, num_defs));
else
sub.reset(create_instruction(op, Format::VOP2, num_ops, num_defs));
sub->operands[0] = a.op;
sub->operands[1] = b.op;
if (!borrow.op.isUndefined())
sub->operands[2] = borrow.op;
sub->definitions[0] = dst;
if (carry_out)
sub->definitions[1] = Definition(carry);
return insert(std::move(sub));
}
## Reads one lane of VGPR 'vsrc' into a scalar destination.  GFX8+ uses
## the _e64 (VOP3) encoding, earlier chips the VOP2 form.
Result readlane(Definition dst, Op vsrc, Op lane)
{
if (program->gfx_level >= GFX8)
return vop3(aco_opcode::v_readlane_b32_e64, dst, vsrc, lane);
else
return vop2(aco_opcode::v_readlane_b32, dst, vsrc, lane);
}
## Writes 'val' into one lane of 'vsrc', producing 'dst'.  Same encoding
## split as readlane().
Result writelane(Definition dst, Op val, Op lane, Op vsrc) {
if (program->gfx_level >= GFX8)
return vop3(aco_opcode::v_writelane_b32_e64, dst, val, lane, vsrc);
else
return vop2(aco_opcode::v_writelane_b32, dst, val, lane, vsrc);
}
<%
import itertools
# Builder-method factory table: one entry per instruction format, as
# (method_name, [Format flags], shapes[, extra_field_setup]), where each
# shape is a (num_definitions, num_operands) pair to emit an overload
# for, and extra_field_setup is an optional C++ snippet executed after
# the common field initialization in the generated method.
formats = [("pseudo", [Format.PSEUDO], list(itertools.product(range(5), range(7))) + [(8, 1), (1, 8), (1, 7)]),
("sop1", [Format.SOP1], [(0, 1), (1, 0), (1, 1), (1, 2), (2, 1), (3, 2)]),
("sop2", [Format.SOP2], itertools.product([1, 2], [2, 3])),
("sopk", [Format.SOPK], itertools.product([0, 1, 2], [0, 1])),
("sopp", [Format.SOPP], [(0, 0), (0, 1)]),
("sopc", [Format.SOPC], [(1, 2)]),
("smem", [Format.SMEM], [(0, 4), (0, 3), (1, 0), (1, 3), (1, 2), (1, 1), (0, 0)]),
("ds", [Format.DS], [(1, 0), (1, 1), (1, 2), (1, 3), (0, 2), (0, 3), (0, 4), (2, 3)]),
("ldsdir", [Format.LDSDIR], [(1, 1)]),
("mubuf", [Format.MUBUF], [(0, 4), (1, 3), (1, 4)]),
("mtbuf", [Format.MTBUF], [(0, 4), (1, 3)]),
("mimg", [Format.MIMG], list(itertools.product([0, 1], [3, 4, 5, 6, 7])) + [(3, 8)] + [(3, 14)]),
("exp", [Format.EXP], [(0, 4), (0, 5)]),
("branch", [Format.PSEUDO_BRANCH], [(0, 0), (0, 1)]),
("barrier", [Format.PSEUDO_BARRIER], [(0, 0)]),
("reduction", [Format.PSEUDO_REDUCTION], [(3, 3)]),
("call", [Format.PSEUDO_CALL], [(0, 0)]),
("vop1", [Format.VOP1], [(0, 0), (1, 1), (1, 2), (2, 2)]),
("vop1_sdwa", [Format.VOP1, Format.SDWA], [(1, 1)]),
("vop2", [Format.VOP2], itertools.product([1, 2], [2, 3])),
("vop2_sdwa", [Format.VOP2, Format.SDWA], itertools.product([1, 2], [2, 3])),
("vopc", [Format.VOPC], itertools.product([1, 2], [2])),
("vopc_sdwa", [Format.VOPC, Format.SDWA], itertools.product([1, 2], [2])),
("vop3", [Format.VOP3], [(1, 3), (1, 2), (1, 1), (2, 2)]),
("vop3p", [Format.VOP3P], [(1, 2), (1, 3)]),
("vopd", [Format.VOPD], [(2, 2), (2, 3), (2, 4), (2, 5), (2, 6)]),
("vinterp_inreg", [Format.VINTERP_INREG], [(1, 3)]),
("vintrp", [Format.VINTRP], [(1, 2), (1, 3)]),
("vop1_dpp", [Format.VOP1, Format.DPP16], [(1, 1)]),
("vop2_dpp", [Format.VOP2, Format.DPP16], itertools.product([1, 2], [2, 3])),
("vopc_dpp", [Format.VOPC, Format.DPP16], itertools.product([1, 2], [2])),
("vop3_dpp", [Format.VOP3, Format.DPP16], [(1, 3), (1, 2), (1, 1), (2, 2)]),
("vop3p_dpp", [Format.VOP3P, Format.DPP16], [(1, 2), (1, 3)]),
("vop1_dpp8", [Format.VOP1, Format.DPP8], [(1, 1)]),
("vop2_dpp8", [Format.VOP2, Format.DPP8], itertools.product([1, 2], [2, 3])),
("vopc_dpp8", [Format.VOPC, Format.DPP8], itertools.product([1, 2], [2])),
("vop3_dpp8", [Format.VOP3, Format.DPP8], [(1, 3), (1, 2), (1, 1), (2, 2)]),
("vop3p_dpp8", [Format.VOP3P, Format.DPP8], [(1, 2), (1, 3)]),
("vop1_e64", [Format.VOP1, Format.VOP3], itertools.product([1], [1])),
("vop2_e64", [Format.VOP2, Format.VOP3], itertools.product([1, 2], [2, 3])),
("vopc_e64", [Format.VOPC, Format.VOP3], itertools.product([1, 2], [2])),
("vop1_e64_dpp", [Format.VOP1, Format.VOP3, Format.DPP16], itertools.product([1], [1])),
("vop2_e64_dpp", [Format.VOP2, Format.VOP3, Format.DPP16], itertools.product([1, 2], [2, 3])),
("vopc_e64_dpp", [Format.VOPC, Format.VOP3, Format.DPP16], itertools.product([1, 2], [2])),
("vop1_e64_dpp8", [Format.VOP1, Format.VOP3, Format.DPP8], itertools.product([1], [1])),
("vop2_e64_dpp8", [Format.VOP2, Format.VOP3, Format.DPP8], itertools.product([1, 2], [2, 3])),
("vopc_e64_dpp8", [Format.VOPC, Format.VOP3, Format.DPP8], itertools.product([1, 2], [2])),
("flat", [Format.FLAT], [(0, 3), (1, 2), (1, 3)]),
("global", [Format.GLOBAL], [(0, 3), (1, 2), (1, 3)]),
("scratch", [Format.SCRATCH], [(0, 3), (1, 2), (1, 3)])]
# Normalize every entry to exactly four elements so the 4-way unpack in
# the loop below ("name, formats, shapes, extra_field_setup") always
# succeeds: entries without an extra_field_setup string get an empty
# snippet appended.  (The previous check compared len(f) against 5,
# which would have padded an already-4-element entry to five elements
# and broken the unpack as soon as an entry carried a setup snippet.)
formats = [(f if len(f) == 4 else f + ('',)) for f in formats]
%>\\
## Mako: for every table entry above, emit one Builder method per
## (num_definitions, num_operands) shape.
% for name, formats, shapes, extra_field_setup in formats:
% for num_definitions, num_operands in shapes:
## Build the C++ parameter list for this overload: the opcode, then the
## definitions, then the operands, then any format-specific extra
## builder fields.
<%
args = ['aco_opcode opcode']
has_disable_wqm = False
for i in range(num_definitions):
args.append('Definition def%d' % i)
for i in range(num_operands):
args.append('Op op%d' % i)
for f in formats:
args += f.get_builder_field_decls()
has_disable_wqm |= f.has_disable_wqm()
%>\\
Result ${name}(${', '.join(args)})
{
unsigned num_ops = ${num_operands};
% if has_disable_wqm:
## Reserve two extra operand slots for the exact/WQM mask operands.
num_ops += disable_wqm * 2;
%endif
Instruction* instr = create_instruction(opcode, (Format)(${'|'.join('(int)Format::%s' % f.name for f in formats)}), num_ops, ${num_definitions});
## Fill the definitions, applying this builder's precise/preserve/nuw
## flags, then copy the operands and format-specific fields.
% for i in range(num_definitions):
instr->definitions[${i}] = def${i};
instr->definitions[${i}].setPrecise(is_precise);
instr->definitions[${i}].setSZPreserve(is_sz_preserve);
instr->definitions[${i}].setInfPreserve(is_inf_preserve);
instr->definitions[${i}].setNaNPreserve(is_nan_preserve);
instr->definitions[${i}].setNUW(is_nuw);
% endfor
% for i in range(num_operands):
instr->operands[${i}] = op${i}.op;
% endfor
% if has_disable_wqm:
if (disable_wqm) {
instr_exact_mask(instr) = Operand();
instr_wqm_mask(instr) = Operand();
}
%endif
% for f in formats:
% for dest, field_name in zip(f.get_builder_field_dests(), f.get_builder_field_names()):
instr->${f.get_accessor()}().${dest} = ${field_name};
% endfor
${f.get_builder_initialization(num_operands)}
% endfor
${extra_field_setup}
return insert(instr);
}
## For the scalar ALU formats, add an overload taking a
## WaveSpecificOpcode that dispatches through w64or32().
% if name == 'sop1' or name == 'sop2' or name == 'sopc':
<%
args[0] = 'WaveSpecificOpcode opcode'
params = []
for i in range(num_definitions):
params.append('def%d' % i)
for i in range(num_operands):
params.append('op%d' % i)
%>\\
inline Result ${name}(${', '.join(args)})
{
return ${name}(w64or32(opcode), ${', '.join(params)});
}
% endif
% endfor
% endfor
};
## Declared here, implemented elsewhere in the compiler; presumably
## emits the hardware scratch (temporary memory) setup sequence from the
## given address/offset operands -- see its definition.
void hw_init_scratch(Builder& bld, Definition def, Operand scratch_addr, Operand scratch_offset);
} // namespace aco
#endif /* _ACO_BUILDER_ */"""
from aco_opcodes import Format
from mako.template import Template
print(Template(template).render(Format=Format))