diff --git a/src/amd/compiler/aco_register_allocation.cpp b/src/amd/compiler/aco_register_allocation.cpp index 3e7e87e4ad2..6ba71eb34b6 100644 --- a/src/amd/compiler/aco_register_allocation.cpp +++ b/src/amd/compiler/aco_register_allocation.cpp @@ -43,8 +43,7 @@ void add_subdword_operand(ra_ctx& ctx, aco_ptr& instr, unsigned idx RegClass rc); std::pair get_subdword_definition_info(Program* program, const aco_ptr& instr, RegClass rc); -void add_subdword_definition(Program* program, aco_ptr& instr, unsigned idx, - PhysReg reg); +void add_subdword_definition(Program* program, aco_ptr& instr, PhysReg reg); struct assignment { PhysReg reg; @@ -565,99 +564,124 @@ get_subdword_definition_info(Program* program, const aco_ptr& instr { chip_class chip = program->chip_class; - if (instr->isPseudo() && chip >= GFX8) - return std::make_pair(rc.bytes() % 2 == 0 ? 2 : 1, rc.bytes()); - else if (instr->isPseudo()) - return std::make_pair(4, rc.size() * 4u); - - unsigned bytes_written = chip >= GFX10 ? rc.bytes() : 4u; - switch (instr->opcode) { - case aco_opcode::v_mad_f16: - case aco_opcode::v_mad_u16: - case aco_opcode::v_mad_i16: - case aco_opcode::v_fma_f16: - case aco_opcode::v_div_fixup_f16: - case aco_opcode::v_interp_p2_f16: bytes_written = chip >= GFX9 ? rc.bytes() : 4u; break; - default: break; + if (instr->isPseudo()) { + if (chip >= GFX8) + return std::make_pair(rc.bytes() % 2 == 0 ? 2 : 1, rc.bytes()); + else + return std::make_pair(4, rc.size() * 4u); } - bytes_written = bytes_written > 4 ? align(bytes_written, 4) : bytes_written; - bytes_written = MAX2(bytes_written, instr_info.definition_size[(int)instr->opcode] / 8u); - if (can_use_SDWA(chip, instr, false)) { - return std::make_pair(rc.bytes(), rc.bytes()); - } else if (rc.bytes() == 2 && can_use_opsel(chip, instr->opcode, -1, 1)) { - return std::make_pair(2u, bytes_written); + if (instr->isVALU() || instr->isVINTRP()) { + assert(rc.bytes() <= 2); + + if (can_use_SDWA(chip, instr, false)) + return std::make_pair(rc.bytes(), rc.bytes()); + + unsigned bytes_written = 4u; + if (instr_is_16bit(chip, instr->opcode)) + bytes_written = 2u; + + unsigned stride = 4u; + if (instr->opcode == aco_opcode::v_fma_mixlo_f16 || + can_use_opsel(chip, instr->opcode, -1, true)) + stride = 2u; + + return std::make_pair(stride, bytes_written); } switch (instr->opcode) { - case aco_opcode::buffer_load_ubyte_d16: - case aco_opcode::buffer_load_short_d16: - case aco_opcode::flat_load_ubyte_d16: - case aco_opcode::flat_load_short_d16: - case aco_opcode::scratch_load_ubyte_d16: - case aco_opcode::scratch_load_short_d16: - case aco_opcode::global_load_ubyte_d16: - case aco_opcode::global_load_short_d16: case aco_opcode::ds_read_u8_d16: + case aco_opcode::ds_read_i8_d16: case aco_opcode::ds_read_u16_d16: - if (chip >= GFX9 && !program->dev.sram_ecc_enabled) + case aco_opcode::flat_load_ubyte_d16: + case aco_opcode::flat_load_sbyte_d16: + case aco_opcode::flat_load_short_d16: + case aco_opcode::global_load_ubyte_d16: + case aco_opcode::global_load_sbyte_d16: + case aco_opcode::global_load_short_d16: + case aco_opcode::scratch_load_ubyte_d16: + case aco_opcode::scratch_load_sbyte_d16: + case aco_opcode::scratch_load_short_d16: + case aco_opcode::buffer_load_ubyte_d16: + case aco_opcode::buffer_load_sbyte_d16: + case aco_opcode::buffer_load_short_d16: { + assert(chip >= GFX9); + if (!program->dev.sram_ecc_enabled) return std::make_pair(2u, 2u); else return std::make_pair(2u, 4u); - case aco_opcode::v_fma_mixlo_f16: return std::make_pair(2u, 2u); - default: break; } - return std::make_pair(4u, bytes_written); + default: return std::make_pair(4, rc.size() * 4u); + } } void -add_subdword_definition(Program* program, aco_ptr& instr, unsigned idx, PhysReg reg) +add_subdword_definition(Program* program, aco_ptr& instr, PhysReg reg) { - RegClass rc = instr->definitions[idx].regClass(); - chip_class chip = program->chip_class; + if (instr->isPseudo()) + return; - if (instr->isPseudo()) { - return; - } else if (can_use_SDWA(chip, instr, false)) { - unsigned def_size = instr_info.definition_size[(int)instr->opcode]; - if (reg.byte() || chip < GFX10 || def_size > rc.bytes() * 8u) - convert_to_SDWA(chip, instr); - return; - } else if (reg.byte() && rc.bytes() == 2 && - can_use_opsel(chip, instr->opcode, -1, reg.byte() / 2)) { - VOP3_instruction& vop3 = instr->vop3(); - if (reg.byte() == 2) - vop3.opsel |= (1 << 3); /* dst in high half */ - return; - } + if (instr->isVALU()) { + chip_class chip = program->chip_class; + assert(instr->definitions[0].bytes() <= 2); - if (reg.byte() == 2) { - if (instr->opcode == aco_opcode::v_fma_mixlo_f16) + if (reg.byte() == 0 && instr_is_16bit(chip, instr->opcode)) + return; + + /* check if we can use opsel */ + if (instr->format == Format::VOP3) { + assert(reg.byte() == 2); + assert(can_use_opsel(chip, instr->opcode, -1, true)); + instr->vop3().opsel |= (1 << 3); /* dst in high half */ + return; + } + + if (instr->opcode == aco_opcode::v_fma_mixlo_f16) { instr->opcode = aco_opcode::v_fma_mixhi_f16; - else if (instr->opcode == aco_opcode::buffer_load_ubyte_d16) - instr->opcode = aco_opcode::buffer_load_ubyte_d16_hi; - else if (instr->opcode == aco_opcode::buffer_load_short_d16) - instr->opcode = aco_opcode::buffer_load_short_d16_hi; - else if (instr->opcode == aco_opcode::flat_load_ubyte_d16) - instr->opcode = aco_opcode::flat_load_ubyte_d16_hi; - else if (instr->opcode == aco_opcode::flat_load_short_d16) - instr->opcode = aco_opcode::flat_load_short_d16_hi; - else if (instr->opcode == aco_opcode::scratch_load_ubyte_d16) - instr->opcode = aco_opcode::scratch_load_ubyte_d16_hi; - else if (instr->opcode == aco_opcode::scratch_load_short_d16) - instr->opcode = aco_opcode::scratch_load_short_d16_hi; - else if (instr->opcode == aco_opcode::global_load_ubyte_d16) - instr->opcode = aco_opcode::global_load_ubyte_d16_hi; - else if (instr->opcode == aco_opcode::global_load_short_d16) - instr->opcode = aco_opcode::global_load_short_d16_hi; - else if (instr->opcode == aco_opcode::ds_read_u8_d16) - instr->opcode = aco_opcode::ds_read_u8_d16_hi; - else if (instr->opcode == aco_opcode::ds_read_u16_d16) - instr->opcode = aco_opcode::ds_read_u16_d16_hi; - else - unreachable("Something went wrong: Impossible register assignment."); + return; + } + + /* use SDWA */ + assert(can_use_SDWA(chip, instr, false)); + convert_to_SDWA(chip, instr); + return; } + + if (reg.byte() == 0) + return; + else if (instr->opcode == aco_opcode::buffer_load_ubyte_d16) + instr->opcode = aco_opcode::buffer_load_ubyte_d16_hi; + else if (instr->opcode == aco_opcode::buffer_load_sbyte_d16) + instr->opcode = aco_opcode::buffer_load_sbyte_d16_hi; + else if (instr->opcode == aco_opcode::buffer_load_short_d16) + instr->opcode = aco_opcode::buffer_load_short_d16_hi; + else if (instr->opcode == aco_opcode::flat_load_ubyte_d16) + instr->opcode = aco_opcode::flat_load_ubyte_d16_hi; + else if (instr->opcode == aco_opcode::flat_load_sbyte_d16) + instr->opcode = aco_opcode::flat_load_sbyte_d16_hi; + else if (instr->opcode == aco_opcode::flat_load_short_d16) + instr->opcode = aco_opcode::flat_load_short_d16_hi; + else if (instr->opcode == aco_opcode::scratch_load_ubyte_d16) + instr->opcode = aco_opcode::scratch_load_ubyte_d16_hi; + else if (instr->opcode == aco_opcode::scratch_load_sbyte_d16) + instr->opcode = aco_opcode::scratch_load_sbyte_d16_hi; + else if (instr->opcode == aco_opcode::scratch_load_short_d16) + instr->opcode = aco_opcode::scratch_load_short_d16_hi; + else if (instr->opcode == aco_opcode::global_load_ubyte_d16) + instr->opcode = aco_opcode::global_load_ubyte_d16_hi; + else if (instr->opcode == aco_opcode::global_load_sbyte_d16) + instr->opcode = aco_opcode::global_load_sbyte_d16_hi; + else if (instr->opcode == aco_opcode::global_load_short_d16) + instr->opcode = aco_opcode::global_load_short_d16_hi; + else if (instr->opcode == aco_opcode::ds_read_u8_d16) + instr->opcode = aco_opcode::ds_read_u8_d16_hi; + else if (instr->opcode == aco_opcode::ds_read_i8_d16) + instr->opcode = aco_opcode::ds_read_i8_d16_hi; + else if (instr->opcode == aco_opcode::ds_read_u16_d16) + instr->opcode = aco_opcode::ds_read_u16_d16_hi; + else + unreachable("Something went wrong: Impossible register assignment."); } void @@ -2576,7 +2600,7 @@ register_allocation(Program* program, std::vector& live_out_per_block, ra PhysReg reg = get_reg(ctx, register_file, tmp, parallelcopy, instr); definition->setFixed(reg); if (reg.byte() || register_file.test(reg, 4)) { - add_subdword_definition(program, instr, i, reg); + add_subdword_definition(program, instr, reg); definition = &instr->definitions[i]; /* add_subdword_definition can invalidate the reference */ }