From 4249daeedd91767a1009bc3850342458d5556c8e Mon Sep 17 00:00:00 2001 From: Rhys Perry Date: Mon, 13 Apr 2026 16:05:36 +0100 Subject: [PATCH] aco: add helpers to get instruction subdword capabilities MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Rhys Perry Reviewed-by: Daniel Schürmann Part-of: --- src/amd/compiler/aco_ir.cpp | 163 +++++++++++++++++- src/amd/compiler/aco_ir.h | 12 ++ src/amd/compiler/aco_register_allocation.cpp | 168 ++----------------- 3 files changed, 183 insertions(+), 160 deletions(-) diff --git a/src/amd/compiler/aco_ir.cpp b/src/amd/compiler/aco_ir.cpp index 46ed5cbc61c..ff3785395a7 100644 --- a/src/amd/compiler/aco_ir.cpp +++ b/src/amd/compiler/aco_ir.cpp @@ -285,7 +285,7 @@ get_sync_info(const Instruction* instr) } bool -can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr& instr, bool pre_ra) +can_use_SDWA(amd_gfx_level gfx_level, const Instruction* instr, bool pre_ra) { if (!instr->isVALU()) return false; @@ -297,7 +297,7 @@ can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr& instr, bool pr return true; if (instr->isVOP3()) { - VALU_instruction& vop3 = instr->valu(); + const VALU_instruction& vop3 = instr->valu(); if (instr->format == Format::VOP3) return false; if (vop3.clamp && instr->isVOPC() && gfx_level != GFX8) @@ -351,6 +351,12 @@ can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr& instr, bool pr instr->opcode != aco_opcode::v_clrexcp && instr->opcode != aco_opcode::v_swap_b32; } +bool +can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr& instr, bool pre_ra) +{ + return can_use_SDWA(gfx_level, instr.get(), pre_ra); +} + /* updates "instr" and returns the old instruction (or NULL if no update was needed) */ aco_ptr convert_to_SDWA(amd_gfx_level gfx_level, aco_ptr& instr) @@ -836,6 +842,159 @@ get_gfx11_true16_mask(aco_opcode op) } } +unsigned +get_subdword_operand_stride(Program* program, const Instruction* instr, unsigned idx, RegClass rc) +{ + assert(rc.is_subdword()); + 
assert(program->gfx_level >= GFX8); + + /* Pseudo instructions are a bit special here, because instr might just be a dummy instruction, + * so we shouldn't try accessing its operands. */ + if (instr->isPseudo()) { + /* v_readfirstlane_b32 cannot use SDWA */ + if (instr->opcode == aco_opcode::p_as_uniform || + instr->opcode == aco_opcode::p_permlane64_shared_vgpr) + return 4; + else + return rc.bytes() % 2 == 0 ? 2 : 1; + } + + assert(instr->operands[idx].regClass() == rc); + + if (rc.bytes() > 2) + return 4; + + if (instr->isVALU()) { + if (can_use_SDWA(program->gfx_level, instr, false)) + return rc.bytes(); + if (can_use_opsel(program->gfx_level, instr->opcode, idx)) + return 2; + if (instr->isVOP3P()) + return 2; + } + + switch (instr->opcode) { + case aco_opcode::v_mov_b32: + case aco_opcode::v_not_b32: + case aco_opcode::v_and_b32: + case aco_opcode::v_or_b32: + case aco_opcode::v_xor_b32: + case aco_opcode::v_cndmask_b32: + return program->gfx_level >= GFX11 && instr->definitions[0].bytes() <= 2 ? 2 : 4; + case aco_opcode::v_cvt_f32_ubyte0: return 1; + case aco_opcode::ds_write_b8: + case aco_opcode::ds_write_b16: + case aco_opcode::buffer_store_byte: + case aco_opcode::buffer_store_short: + case aco_opcode::buffer_store_format_d16_x: + case aco_opcode::flat_store_byte: + case aco_opcode::flat_store_short: + case aco_opcode::scratch_store_byte: + case aco_opcode::scratch_store_short: + case aco_opcode::global_store_byte: + case aco_opcode::global_store_short: return program->gfx_level >= GFX9 ? 2 : 4; + default: return 4; + } +} + +SubdwordCaps +get_subdword_definition_caps(Program* program, const Instruction* instr, unsigned idx, RegClass rc) +{ + amd_gfx_level gfx_level = program->gfx_level; + assert(rc.is_subdword()); + assert(gfx_level >= GFX8); + + SubdwordCaps caps; + caps.placement_stride = rc.bytes() % 2 == 0 ? 
2 : 1; + caps.overwrite_bytes = rc.bytes(); + + /* Pseudo instructions are a bit special here, because instr might just be a dummy instruction, + * so we shouldn't try accessing its definitions. Pseudo instructions are also the only ones + * with multiple definitions, and "idx" isn't always correct. */ + if (instr->isPseudo()) { + if (instr->opcode == aco_opcode::p_interp_gfx11 || + instr->opcode == aco_opcode::p_permlane64_shared_vgpr) { + caps.overwrite_bytes = rc.size() * 4; + caps.placement_stride = 4; + } + return caps; + } + + assert(instr->definitions[idx].regClass() == rc); + + if (instr->isVALU()) { + if (rc.bytes() == 3) { + caps.overwrite_bytes = 4; + caps.placement_stride = 4; + return caps; + } + assert(rc.bytes() <= 2); + + if (can_use_SDWA(gfx_level, instr, false)) + return caps; + + if ((instr->opcode == aco_opcode::v_cndmask_b32 || instr->opcode == aco_opcode::v_mov_b32 || + instr->opcode == aco_opcode::v_not_b32 || instr->opcode == aco_opcode::v_and_b32 || + instr->opcode == aco_opcode::v_or_b32 || instr->opcode == aco_opcode::v_xor_b32) && + program->gfx_level >= GFX11) { + /* Convert to 16bit opcode on demand. */ + caps.overwrite_bytes = 2; + caps.placement_stride = 2; + return caps; + } + + bool preserve = instr_is_16bit(gfx_level, instr->opcode); + bool can_write_hi16 = instr->opcode == aco_opcode::v_fma_mixlo_f16 || + instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || + can_use_opsel(gfx_level, instr->opcode, -1); + caps.placement_stride = can_write_hi16 ? 2 : 4; + caps.overwrite_bytes = preserve ? 
2 : 4; + return caps; + } + + switch (instr->opcode) { + case aco_opcode::v_interp_p2_f16: assert(rc.bytes() == 2); return caps; + /* D16 loads with _hi version */ + case aco_opcode::ds_read_u8_d16: + case aco_opcode::ds_read_i8_d16: + case aco_opcode::ds_read_u16_d16: + case aco_opcode::flat_load_ubyte_d16: + case aco_opcode::flat_load_sbyte_d16: + case aco_opcode::flat_load_short_d16: + case aco_opcode::global_load_ubyte_d16: + case aco_opcode::global_load_sbyte_d16: + case aco_opcode::global_load_short_d16: + case aco_opcode::scratch_load_ubyte_d16: + case aco_opcode::scratch_load_sbyte_d16: + case aco_opcode::scratch_load_short_d16: + case aco_opcode::buffer_load_ubyte_d16: + case aco_opcode::buffer_load_sbyte_d16: + case aco_opcode::buffer_load_short_d16: + case aco_opcode::buffer_load_format_d16_x: { + assert(gfx_level >= GFX9); + caps.overwrite_bytes = program->dev.sram_ecc_enabled ? 4 : 2; + caps.placement_stride = 2; + return caps; + } + /* 3-component D16 loads */ + case aco_opcode::buffer_load_format_d16_xyz: + case aco_opcode::tbuffer_load_format_d16_xyz: { + assert(gfx_level >= GFX9); + caps.overwrite_bytes = program->dev.sram_ecc_enabled ? 
8 : 6; + caps.placement_stride = 4; + return caps; + } + default: break; + } + + caps.placement_stride = 4; + if (instr->isMIMG() && instr->mimg().d16 && !program->dev.sram_ecc_enabled) + assert(gfx_level >= GFX9); + else + caps.overwrite_bytes = rc.size() * 4; + return caps; +} + uint32_t get_reduction_identity(ReduceOp op, unsigned idx) { diff --git a/src/amd/compiler/aco_ir.h b/src/amd/compiler/aco_ir.h index fef6fbf3a07..2f1c58f5c61 100644 --- a/src/amd/compiler/aco_ir.h +++ b/src/amd/compiler/aco_ir.h @@ -2052,7 +2052,19 @@ bool can_use_input_modifiers(amd_gfx_level gfx_level, aco_opcode op, int idx); bool can_use_opsel(amd_gfx_level gfx_level, aco_opcode op, int idx); bool instr_is_16bit(amd_gfx_level gfx_level, aco_opcode op); uint8_t get_gfx11_true16_mask(aco_opcode op); +bool can_use_SDWA(amd_gfx_level gfx_level, const Instruction* instr, bool pre_ra); bool can_use_SDWA(amd_gfx_level gfx_level, const aco_ptr& instr, bool pre_ra); + +struct SubdwordCaps { + unsigned placement_stride; + unsigned overwrite_bytes; +}; + +unsigned get_subdword_operand_stride(Program* program, const Instruction* instr, unsigned idx, + RegClass rc); +SubdwordCaps get_subdword_definition_caps(Program* program, const Instruction* instr, unsigned idx, + RegClass rc); + bool opcode_supports_dpp(amd_gfx_level gfx_level, aco_opcode opcode, bool vop3p); bool can_use_DPP(amd_gfx_level gfx_level, const aco_ptr& instr, bool dpp8); bool can_write_m0(const aco_ptr& instr); diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index b8c4f4c5850..92a717b5853 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -22,8 +22,6 @@ namespace { struct ra_ctx; struct DefInfo; -unsigned get_subdword_operand_stride(amd_gfx_level gfx_level, const aco_ptr& instr, - unsigned idx, RegClass rc); void add_subdword_operand(ra_ctx& ctx, aco_ptr& instr, unsigned idx, unsigned byte, RegClass rc); void 
add_subdword_definition(Program* program, aco_ptr& instr, PhysReg reg, @@ -289,9 +287,12 @@ struct DefInfo { if (rc.is_subdword() && operand >= 0) { /* stride in bytes */ - stride = get_subdword_operand_stride(ctx.program->gfx_level, instr, operand, rc); + stride = get_subdword_operand_stride(ctx.program, instr.get(), operand, rc); } else if (rc.is_subdword()) { - get_subdword_definition_info(ctx.program, instr); + SubdwordCaps caps = get_subdword_definition_caps(ctx.program, instr.get(), 0, rc); + stride = rc.bytes() == caps.overwrite_bytes ? caps.placement_stride : 4; + rc = rc.resize(caps.overwrite_bytes); + data_stride = caps.placement_stride; } else if (instr->isMIMG() && instr->mimg().d16 && ctx.program->gfx_level <= GFX9) { /* Workaround GFX9 hardware bug for D16 image instructions: FeatureImageGather4D16Bug * @@ -317,9 +318,6 @@ struct DefInfo { if (!data_stride) data_stride = stride; } - -private: - void get_subdword_definition_info(Program* program, const aco_ptr& instr); }; class RegisterFile { @@ -631,56 +629,6 @@ convert_bitwise_to_16bit(Instruction* instr) return true; } -unsigned -get_subdword_operand_stride(amd_gfx_level gfx_level, const aco_ptr& instr, - unsigned idx, RegClass rc) -{ - assert(gfx_level >= GFX8); - if (instr->isPseudo()) { - /* v_readfirstlane_b32 cannot use SDWA */ - if (instr->opcode == aco_opcode::p_as_uniform || - instr->opcode == aco_opcode::p_permlane64_shared_vgpr) - return 4; - else - return rc.bytes() % 2 == 0 ? 
2 : 1; - } - - if (rc.bytes() > 2) - return 4; - - if (instr->isVALU()) { - if (can_use_SDWA(gfx_level, instr, false)) - return rc.bytes(); - if (can_use_opsel(gfx_level, instr->opcode, idx)) - return 2; - if (instr->isVOP3P()) - return 2; - } - - switch (instr->opcode) { - case aco_opcode::v_mov_b32: - case aco_opcode::v_not_b32: - case aco_opcode::v_and_b32: - case aco_opcode::v_or_b32: - case aco_opcode::v_xor_b32: - case aco_opcode::v_cndmask_b32: - return gfx_level >= GFX11 && instr->definitions[0].bytes() <= 2 ? 2 : 4; - case aco_opcode::v_cvt_f32_ubyte0: return 1; - case aco_opcode::ds_write_b8: - case aco_opcode::ds_write_b16: return gfx_level >= GFX9 ? 2 : 4; - case aco_opcode::buffer_store_byte: - case aco_opcode::buffer_store_short: - case aco_opcode::buffer_store_format_d16_x: - case aco_opcode::flat_store_byte: - case aco_opcode::flat_store_short: - case aco_opcode::scratch_store_byte: - case aco_opcode::scratch_store_short: - case aco_opcode::global_store_byte: - case aco_opcode::global_store_short: return gfx_level >= GFX9 ? 2 : 4; - default: return 4; - } -} - void add_subdword_operand(ra_ctx& ctx, aco_ptr& instr, unsigned idx, unsigned byte, RegClass rc) @@ -750,102 +698,6 @@ add_subdword_operand(ra_ctx& ctx, aco_ptr& instr, unsigned idx, uns return; } -void -DefInfo::get_subdword_definition_info(Program* program, const aco_ptr& instr) -{ - amd_gfx_level gfx_level = program->gfx_level; - assert(gfx_level >= GFX8); - - stride = rc.bytes() % 2 == 0 ? 2 : 1; - - if (instr->isPseudo()) { - if (instr->opcode == aco_opcode::p_interp_gfx11 || - instr->opcode == aco_opcode::p_permlane64_shared_vgpr) { - rc = RegClass(RegType::vgpr, rc.size()); - stride = 4; - } - return; - } - - if (instr->isVALU()) { - if (rc.bytes() == 3) { - rc = v1; - stride = 4; - return; - } - - if (can_use_SDWA(gfx_level, instr, false)) - return; - - rc = instr_is_16bit(gfx_level, instr->opcode) ? 
v2b : v1; - stride = 4; - if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || - instr->opcode == aco_opcode::p_v_fma_mixlo_f16_rtz || - can_use_opsel(gfx_level, instr->opcode, -1)) { - data_stride = 2; - stride = rc == v2b ? 2 : stride; - } else if ((instr->opcode == aco_opcode::v_cndmask_b32 || - instr->opcode == aco_opcode::v_mov_b32 || - instr->opcode == aco_opcode::v_not_b32 || - instr->opcode == aco_opcode::v_and_b32 || instr->opcode == aco_opcode::v_or_b32 || - instr->opcode == aco_opcode::v_xor_b32) && - program->gfx_level >= GFX11) { - /* Convert to 16bit opcode on demand. */ - rc = v2b; - data_stride = 2; - stride = 2; - } - return; - } - - switch (instr->opcode) { - case aco_opcode::v_interp_p2_f16: return; - /* D16 loads with _hi version */ - case aco_opcode::ds_read_u8_d16: - case aco_opcode::ds_read_i8_d16: - case aco_opcode::ds_read_u16_d16: - case aco_opcode::flat_load_ubyte_d16: - case aco_opcode::flat_load_sbyte_d16: - case aco_opcode::flat_load_short_d16: - case aco_opcode::global_load_ubyte_d16: - case aco_opcode::global_load_sbyte_d16: - case aco_opcode::global_load_short_d16: - case aco_opcode::scratch_load_ubyte_d16: - case aco_opcode::scratch_load_sbyte_d16: - case aco_opcode::scratch_load_short_d16: - case aco_opcode::buffer_load_ubyte_d16: - case aco_opcode::buffer_load_sbyte_d16: - case aco_opcode::buffer_load_short_d16: - case aco_opcode::buffer_load_format_d16_x: { - assert(gfx_level >= GFX9); - if (program->dev.sram_ecc_enabled) { - rc = v1; - stride = 4; - data_stride = 2; - } else { - stride = 2; - } - return; - } - /* 3-component D16 loads */ - case aco_opcode::buffer_load_format_d16_xyz: - case aco_opcode::tbuffer_load_format_d16_xyz: { - assert(gfx_level >= GFX9); - stride = 4; - if (program->dev.sram_ecc_enabled) - rc = v2; - return; - } - default: break; - } - - stride = 4; - if (instr->isMIMG() && instr->mimg().d16 && !program->dev.sram_ecc_enabled) - assert(gfx_level >= GFX9); - else - rc = RegClass(RegType::vgpr, rc.size()); 
-} - void add_subdword_definition(Program* program, aco_ptr& instr, PhysReg reg, bool allow_16bit_write) @@ -2387,11 +2239,12 @@ handle_pseudo(ra_ctx& ctx, const RegisterFile& reg_file, Instruction* instr) } bool -operand_can_use_reg(amd_gfx_level gfx_level, aco_ptr& instr, unsigned idx, PhysReg reg, +operand_can_use_reg(ra_ctx& ctx, aco_ptr& instr, unsigned idx, PhysReg reg, RegClass rc) { + amd_gfx_level gfx_level = ctx.program->gfx_level; if (reg.byte()) { - unsigned stride = get_subdword_operand_stride(gfx_level, instr, idx, rc); + unsigned stride = get_subdword_operand_stride(ctx.program, instr.get(), idx, rc); if (reg.byte() % stride) return false; } @@ -3699,8 +3552,7 @@ undo_renames(ra_ctx& ctx, std::vector& parallelcopies, } bool use_original = !op.isPrecolored() && !op.isLateKill(); - use_original &= operand_can_use_reg(ctx.program->gfx_level, instr, i, copy.op.physReg(), - copy.op.regClass()); + use_original &= operand_can_use_reg(ctx, instr, i, copy.op.physReg(), copy.op.regClass()); if (use_original) { const PhysRegInterval copy_reg = {copy.op.physReg(), copy.op.size()}; @@ -4128,7 +3980,7 @@ register_allocation(Program* program, ra_test_policy policy) } PhysReg reg = ctx.assignments[operand.tempId()].reg; - if (operand_can_use_reg(program->gfx_level, instr, i, reg, operand.regClass())) + if (operand_can_use_reg(ctx, instr, i, reg, operand.regClass())) operand.setFixed(reg); else get_reg_for_operand(ctx, register_file, parallelcopy, instr, operand, i);