diff --git a/src/amd/compiler/aco_instruction_selection.cpp b/src/amd/compiler/aco_instruction_selection.cpp index b92ab49feeb..c7e881926f2 100644 --- a/src/amd/compiler/aco_instruction_selection.cpp +++ b/src/amd/compiler/aco_instruction_selection.cpp @@ -4502,6 +4502,21 @@ lds_load_callback(Builder& bld, const LoadEmitInfo& info, Temp offset, unsigned const EmitLoadParameters lds_load_params{lds_load_callback, UINT32_MAX}; +std::pair +get_smem_opcode(amd_gfx_level level, unsigned bytes, bool buffer, bool round_down) +{ + if (bytes <= (round_down ? 7 : 4)) + return {buffer ? aco_opcode::s_buffer_load_dword : aco_opcode::s_load_dword, 4}; + else if (bytes <= (round_down ? 15 : 8)) + return {buffer ? aco_opcode::s_buffer_load_dwordx2 : aco_opcode::s_load_dwordx2, 8}; + else if (bytes <= (round_down ? 31 : 16)) + return {buffer ? aco_opcode::s_buffer_load_dwordx4 : aco_opcode::s_load_dwordx4, 16}; + else if (bytes <= (round_down ? 63 : 32)) + return {buffer ? aco_opcode::s_buffer_load_dwordx8 : aco_opcode::s_load_dwordx8, 32}; + else + return {buffer ? aco_opcode::s_buffer_load_dwordx16 : aco_opcode::s_load_dwordx16, 64}; +} + Temp smem_load_callback(Builder& bld, const LoadEmitInfo& info, Temp offset, unsigned bytes_needed, unsigned align, unsigned const_offset, Temp dst_hint) @@ -4517,25 +4532,15 @@ smem_load_callback(Builder& bld, const LoadEmitInfo& info, Temp offset, unsigned offset = Temp(); } - bytes_needed = MIN2(bytes_needed, 64); - unsigned needed_round_up = util_next_power_of_two(bytes_needed); - unsigned needed_round_down = needed_round_up >> (needed_round_up != bytes_needed ? 1 : 0); - /* Only round-up global loads if it's aligned so that it won't cross pages */ - bytes_needed = buffer || align % needed_round_up == 0 ? needed_round_up : needed_round_down; + std::pair smaller = + get_smem_opcode(bld.program->gfx_level, bytes_needed, buffer, true); + std::pair larger = + get_smem_opcode(bld.program->gfx_level, bytes_needed, buffer, false); + /* Only round-up global loads if it's aligned so that it won't cross pages */ aco_opcode op; - if (bytes_needed <= 4) { - op = buffer ? aco_opcode::s_buffer_load_dword : aco_opcode::s_load_dword; - } else if (bytes_needed <= 8) { - op = buffer ? aco_opcode::s_buffer_load_dwordx2 : aco_opcode::s_load_dwordx2; - } else if (bytes_needed <= 16) { - op = buffer ? aco_opcode::s_buffer_load_dwordx4 : aco_opcode::s_load_dwordx4; - } else if (bytes_needed <= 32) { - op = buffer ? aco_opcode::s_buffer_load_dwordx8 : aco_opcode::s_load_dwordx8; - } else { - assert(bytes_needed == 64); - op = buffer ? aco_opcode::s_buffer_load_dwordx16 : aco_opcode::s_load_dwordx16; - } + std::tie(op, bytes_needed) = + buffer || (align % util_next_power_of_two(larger.second) == 0) ? larger : smaller; aco_ptr load{create_instruction(op, Format::SMEM, 2, 1)}; if (buffer) { @@ -7101,28 +7106,15 @@ visit_load_smem(isel_context* ctx, nir_intrinsic_instr* instr) Operand::c32(ctx->options->address32_hi)); } - aco_opcode opcode = aco_opcode::s_load_dword; - unsigned size = 1; - + aco_opcode opcode; + unsigned size; assert(dst.bytes() <= 64); + std::tie(opcode, size) = get_smem_opcode(ctx->program->gfx_level, dst.bytes(), false, false); - if (dst.bytes() > 32) { - opcode = aco_opcode::s_load_dwordx16; - size = 16; - } else if (dst.bytes() > 16) { - opcode = aco_opcode::s_load_dwordx8; - size = 8; - } else if (dst.bytes() > 8) { - opcode = aco_opcode::s_load_dwordx4; - size = 4; - } else if (dst.bytes() > 4) { - opcode = aco_opcode::s_load_dwordx2; - size = 2; - } - - if (dst.size() != size) { + if (dst.size() != DIV_ROUND_UP(size, 4)) { bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), - bld.smem(opcode, bld.def(RegType::sgpr, size), base, offset), Operand::c32(0u)); + bld.smem(opcode, bld.def(RegClass::get(RegType::sgpr, size)), base, offset), + Operand::c32(0u)); } else { bld.smem(opcode, Definition(dst), base, offset); }