diff --git a/src/amd/compiler/aco_optimizer.cpp b/src/amd/compiler/aco_optimizer.cpp index cfa287df9f6..70f333cc049 100644 --- a/src/amd/compiler/aco_optimizer.cpp +++ b/src/amd/compiler/aco_optimizer.cpp @@ -1046,6 +1046,10 @@ can_apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_i if (!sel) { return false; + } else if (sel.size() == instr->operands[idx].bytes() && sel.size() == tmp.bytes() && + tmp.type() == instr->operands[idx].regClass().type()) { + assert(tmp.type() != RegType::sgpr); /* No sub-dword SGPR regclasses */ + return true; } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 || instr->opcode == aco_opcode::v_cvt_f32_i32) && sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) { @@ -1063,8 +1067,13 @@ can_apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_i return true; } else if (idx < 2 && can_use_SDWA(ctx.program->gfx_level, instr, true) && (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) { - if (instr->isSDWA() && instr->sdwa().sel[idx] != SubdwordSel::dword) - return false; + if (instr->isSDWA()) { + /* TODO: if we knew how many bytes this operand actually uses, we could pass a smaller + * second_dst parameter and apply more sign-extended sels. 
+ */ + return apply_extract_twice(sel, instr->operands[idx].getTemp(), instr->sdwa().sel[idx], + Temp(0, v1)) != SubdwordSel(); + } return true; } else if (instr->isVALU() && sel.size() == 2 && !instr->valu().opsel[idx] && can_use_opsel(ctx.program->gfx_level, instr->opcode, idx)) { @@ -1103,8 +1112,9 @@ apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_info& ctx.info[tmp.id()].label &= ~label_insert; - if (sel.size() == 4 && tmp.type() == instr->operands[idx].regClass().type()) { - /* full dword selection */ + if (sel.size() == instr->operands[idx].bytes() && sel.size() == tmp.bytes() && + tmp.type() == instr->operands[idx].regClass().type()) { + /* extract is a no-op: sel spans the whole operand and the whole source temp, and both have the same register type */ } else if ((instr->opcode == aco_opcode::v_cvt_f32_u32 || instr->opcode == aco_opcode::v_cvt_f32_i32) && sel.size() == 1 && !sel.sign_extend() && !instr->usesModifiers()) { @@ -1137,7 +1147,8 @@ apply_extract(opt_ctx& ctx, aco_ptr& instr, unsigned idx, ssa_info& } else if (can_use_SDWA(ctx.program->gfx_level, instr, true) && (tmp.type() == RegType::vgpr || ctx.program->gfx_level >= GFX9)) { convert_to_SDWA(ctx.program->gfx_level, instr); - instr->sdwa().sel[idx] = sel; + instr->sdwa().sel[idx] = apply_extract_twice(sel, instr->operands[idx].getTemp(), + instr->sdwa().sel[idx], Temp(0, v1)); } else if (instr->isVALU()) { if (sel.offset()) { instr->valu().opsel[idx] = true; @@ -2042,15 +2053,16 @@ label_instruction(opt_ctx& ctx, aco_ptr& instr) ctx.info[instr->definitions[0].tempId()].set_canonicalized(); break; case aco_opcode::p_extract: { - if (instr->definitions[0].bytes() == 4 && instr->operands[0].isTemp()) { + if (instr->operands[0].isTemp()) { ctx.info[instr->definitions[0].tempId()].set_extract(instr.get()); - if (instr->operands[0].regClass() == v1 && parse_insert(instr.get())) + if (instr->definitions[0].bytes() == 4 && instr->operands[0].regClass() == v1 && + parse_insert(instr.get())) ctx.info[instr->operands[0].tempId()].set_insert(instr.get()); } break; } case 
aco_opcode::p_insert: { - if (instr->operands[0].bytes() == 4 && instr->operands[0].isTemp()) { + if (instr->operands[0].isTemp()) { if (instr->operands[0].regClass() == v1) ctx.info[instr->operands[0].tempId()].set_insert(instr.get()); if (parse_extract(instr.get())) diff --git a/src/amd/compiler/tests/test_sdwa.cpp b/src/amd/compiler/tests/test_sdwa.cpp index 1ab9ab37acb..a7fe358dec0 100644 --- a/src/amd/compiler/tests/test_sdwa.cpp +++ b/src/amd/compiler/tests/test_sdwa.cpp @@ -658,3 +658,44 @@ BEGIN_TEST(optimize.sdwa.extract_sgpr_limits) finish_opt_test(); END_TEST + +BEGIN_TEST(optimize.sdwa.subdword_extract) /* Feeds p_extract / p_extract_vector results with sub-dword (v2b, v1b) definitions into VALU instructions; the expected output below shows them folded into operand selections. */ + //>> v1: %a, v1: %b, s2: %c = p_startpgm + if (!setup_cs("v1 v1 s2", GFX10_3)) + return; + + Temp a = inputs[0]; + Temp b = inputs[1]; + + //! v2b: %res0 = v_lshlrev_b16_e64 4, hi(%a) + //! p_unit_test 0, %res0 + writeout(0, bld.vop3(aco_opcode::v_lshlrev_b16_e64, bld.def(v2b), Operand::c32(4), + bld.pseudo(aco_opcode::p_extract, bld.def(v2b), a, Operand::c32(1), + Operand::c32(16), Operand::c32(false)))); + + //! v2b: %res1 = v_add_f16 %a, %b dst_sel:uword0 dst_preserve src0_sel:uword1 src1_sel:uword1 + //! p_unit_test 1, %res1 + writeout(1, + bld.vop2(aco_opcode::v_add_f16, bld.def(v2b), + bld.pseudo(aco_opcode::p_extract_vector, bld.def(v2b), a, Operand::c32(1)), + bld.pseudo(aco_opcode::p_extract_vector, bld.def(v2b), b, Operand::c32(1)))); + + //! v2b: %res2 = v_cndmask_b32 %a, %b, %c:vcc dst_sel:uword0 dst_preserve src0_sel:ubyte0 src1_sel:ubyte1 + //! p_unit_test 2, %res2 + writeout(2, bld.vop2(aco_opcode::v_cndmask_b32, bld.def(v2b), + bld.pseudo(aco_opcode::p_extract, bld.def(v2b), a, Operand::c32(0), + Operand::c32(8), Operand::c32(0)), + bld.pseudo(aco_opcode::p_extract, bld.def(v2b), b, Operand::c32(1), + Operand::c32(8), Operand::c32(0)), + inputs[2])); + + //! v1b: %res3 = v_or_b32 %a, %b dst_sel:ubyte0 dst_preserve src0_sel:ubyte0 src1_sel:ubyte2 + //! 
p_unit_test 3, %res3 + writeout(3, bld.vop2(aco_opcode::v_or_b32, bld.def(v1b), + bld.pseudo(aco_opcode::p_extract, bld.def(v1b), a, Operand::c32(0), + Operand::c32(16), Operand::c32(0)), + bld.pseudo(aco_opcode::p_extract, bld.def(v1b), b, Operand::c32(1), + Operand::c32(16), Operand::c32(0)))); + + finish_opt_test(); +END_TEST