aco: implement mqsad_4x8 and shfr

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26251>
This commit is contained in:
Rhys Perry 2024-01-05 17:38:40 +00:00 committed by Marge Bot
parent 08903bbe89
commit 6b301eae36
3 changed files with 37 additions and 1 deletions

View file

@ -3424,6 +3424,41 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
emit_vop3a_instruction(ctx, instr, aco_opcode::v_msad_u8, dst, false, 3u, true);
break;
}
case nir_op_mqsad_4x8: {
assert(dst.regClass() == v4);
Temp ref = get_alu_src(ctx, instr->src[0]);
Temp src = get_alu_src(ctx, instr->src[1], 2);
Temp accum = get_alu_src(ctx, instr->src[2], 4);
Builder::Result res = bld.vop3(aco_opcode::v_mqsad_u32_u8, Definition(dst), as_vgpr(ctx, src),
as_vgpr(ctx, ref), as_vgpr(ctx, accum));
res.instr->operands[0].setLateKill(true);
res.instr->operands[1].setLateKill(true);
res.instr->operands[2].setLateKill(true);
emit_split_vector(ctx, dst, 4);
break;
}
case nir_op_shfr: {
if (dst.regClass() == s1) {
Temp src = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2),
get_alu_src(ctx, instr->src[1]), get_alu_src(ctx, instr->src[0]));
Temp amount;
if (nir_src_is_const(instr->src[2].src)) {
amount = bld.copy(bld.def(s1), Operand::c32(nir_src_as_uint(instr->src[2].src) & 0x1f));
} else {
amount = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc),
get_alu_src(ctx, instr->src[2]), Operand::c32(0x1f));
}
Temp res = bld.sop2(aco_opcode::s_lshr_b64, bld.def(s2), bld.def(s1, scc), src, amount);
bld.pseudo(aco_opcode::p_extract_vector, Definition(dst), res, Operand::zero());
} else if (dst.regClass() == v1) {
emit_vop3a_instruction(ctx, instr, aco_opcode::v_alignbit_b32, dst, false, 3u);
} else {
isel_err(&instr->instr, "Unimplemented NIR instr bit size");
}
break;
}
case nir_op_fquantize2f16: {
Temp src = get_alu_src(ctx, instr->src[0]);
Temp f16;

View file

@ -393,6 +393,7 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_frexp_exp:
case nir_op_cube_amd:
case nir_op_msad_4x8:
case nir_op_mqsad_4x8:
case nir_op_udot_4x8_uadd:
case nir_op_sdot_4x8_iadd:
case nir_op_sudot_4x8_iadd:

View file

@ -1185,7 +1185,7 @@ VOP3 = {
("v_qsad_pk_u16_u8", False, False, dst(2), src(2, 1, 2), op(0x172, gfx8=0x1e5, gfx10=0x172, gfx11=0x23a)),
("v_mqsad_pk_u16_u8", False, False, dst(2), src(2, 1, 2), op(0x173, gfx8=0x1e6, gfx10=0x173, gfx11=0x23b)),
("v_trig_preop_f64", False, False, dst(2), src(2, 2), op(0x174, gfx8=0x292, gfx10=0x174, gfx11=0x32f), InstrClass.ValuDouble),
("v_mqsad_u32_u8", False, False, dst(4), src(2, 1, 4), op(gfx7=0x175, gfx8=0x1e7, gfx10=0x175, gfx11=0x23d)),
("v_mqsad_u32_u8", False, False, dst(4), src(2, 1, 4), op(gfx7=0x175, gfx8=0x1e7, gfx10=0x175, gfx11=0x23d), InstrClass.ValuQuarterRate32),
("v_mad_u64_u32", False, False, dst(2, VCC), src(1, 1, 2), op(gfx7=0x176, gfx8=0x1e8, gfx10=0x176, gfx11=0x2fe), InstrClass.Valu64),
("v_mad_i64_i32", False, False, dst(2, VCC), src(1, 1, 2), op(gfx7=0x177, gfx8=0x1e9, gfx10=0x177, gfx11=0x2ff), InstrClass.Valu64),
("v_mad_legacy_f16", True, True, dst(1), src(1, 1, 1), op(gfx8=0x1ea, gfx10=-1)),