diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index 9e5c4a29dcd..e2c137b5db5 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -791,7 +791,7 @@ parse_base_offset(opt_ctx& ctx, Instruction* instr, unsigned op_index, Temp* bas } void -skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem) +skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem, uint32_t align) { bool soe = smem->operands.size() >= (!smem->definitions.empty() ? 3 : 4); if (soe && !smem->operands[1].isConstant()) @@ -808,10 +808,11 @@ skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem) if (bitwise_instr->opcode != aco_opcode::s_and_b32) return; - if (bitwise_instr->operands[0].constantEquals(-4) && + uint32_t mask = ~(align - 1u); + if (bitwise_instr->operands[0].constantEquals(mask) && bitwise_instr->operands[1].isOfType(op.regClass().type())) op.setTemp(bitwise_instr->operands[1].getTemp()); - else if (bitwise_instr->operands[1].constantEquals(-4) && + else if (bitwise_instr->operands[1].constantEquals(mask) && bitwise_instr->operands[0].isOfType(op.regClass().type())) op.setTemp(bitwise_instr->operands[0].getTemp()); } @@ -819,9 +820,22 @@ skip_smem_offset_align(opt_ctx& ctx, SMEM_instruction* smem) void smem_combine(opt_ctx& ctx, aco_ptr& instr) { + uint32_t align = 4; + switch (instr->opcode) { + case aco_opcode::s_load_sbyte: + case aco_opcode::s_load_ubyte: + case aco_opcode::s_buffer_load_sbyte: + case aco_opcode::s_buffer_load_ubyte: align = 1; break; + case aco_opcode::s_load_sshort: + case aco_opcode::s_load_ushort: + case aco_opcode::s_buffer_load_sshort: + case aco_opcode::s_buffer_load_ushort: align = 2; break; + default: break; + } + /* skip &-4 before offset additions: load((a + 16) & -4, 0) */ - if (!instr->operands.empty()) - skip_smem_offset_align(ctx, &instr->smem()); + if (!instr->operands.empty() && align > 1) + skip_smem_offset_align(ctx, &instr->smem(), align); /* propagate constants and combine additions */ if (!instr->operands.empty() && instr->operands[1].isTemp()) { @@ -834,7 +848,7 @@ smem_combine(opt_ctx& ctx, aco_ptr& instr) instr->operands[1] = Operand::c32(info.val); } else if (parse_base_offset(ctx, instr.get(), 1, &base, &offset, true) && base.regClass() == s1 && offset <= ctx.program->dev.smem_offset_max && - ctx.program->gfx_level >= GFX9 && offset % 4u == 0) { + ctx.program->gfx_level >= GFX9 && offset % align == 0) { bool soe = smem.operands.size() >= (!smem.definitions.empty() ? 3 : 4); if (soe) { if (ctx.info[smem.operands.back().tempId()].is_constant_or_literal(32) && @@ -860,8 +874,8 @@ smem_combine(opt_ctx& ctx, aco_ptr& instr) } /* skip &-4 after offset additions: load(a & -4, 16) */ - if (!instr->operands.empty()) - skip_smem_offset_align(ctx, &instr->smem()); + if (!instr->operands.empty() && align > 1) + skip_smem_offset_align(ctx, &instr->smem(), align); } Operand