nak: Wire up DP4

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26533>
This commit is contained in:
Faith Ekstrand 2023-12-05 15:09:34 -06:00 committed by Marge Bot
parent fcf31d9c25
commit eb633b4978
5 changed files with 99 additions and 1 deletions

View file

@ -118,6 +118,9 @@ fn nir_options(dev: &nv_device_info) -> nir_shader_compiler_options {
op.lower_scmp = true;
op.lower_uadd_carry = true;
op.lower_usub_borrow = true;
op.has_sdot_4x8 = dev.sm >= 70;
op.has_udot_4x8 = dev.sm >= 70;
op.has_sudot_4x8 = dev.sm >= 70;
op
}

View file

@ -661,6 +661,33 @@ impl SM70Instr {
self.set_pred_src(77..80, 80, op.carry[1]);
}
fn encode_idp4(&mut self, op: &OpIDp4) {
self.encode_alu(
0x026,
Some(op.dst),
ALUSrc::from_src(&op.srcs[0]),
ALUSrc::from_src(&op.srcs[1]),
ALUSrc::from_src(&op.srcs[2]),
);
self.set_bit(
73,
match op.src_types[0] {
IntType::U8 => false,
IntType::I8 => true,
_ => panic!("Invalid DP4 source type"),
},
);
self.set_bit(
74,
match op.src_types[1] {
IntType::U8 => false,
IntType::I8 => true,
_ => panic!("Invalid DP4 source type"),
},
);
}
fn encode_imad(&mut self, op: &OpIMad) {
self.encode_alu(
0x024,
@ -1915,6 +1942,7 @@ impl SM70Instr {
Op::IAbs(op) => si.encode_iabs(&op),
Op::IAdd3(op) => si.encode_iadd3(&op),
Op::IAdd3X(op) => si.encode_iadd3x(&op),
Op::IDp4(op) => si.encode_idp4(&op),
Op::IMad(op) => si.encode_imad(&op),
Op::IMad64(op) => si.encode_imad64(&op),
Op::IMnMx(op) => si.encode_imnmx(&op),

View file

@ -1024,6 +1024,33 @@ impl<'a> ShaderFromNir<'a> {
b.prmt(low.into(), high.into(), [0, 1, 4, 5])
}
nir_op_sdot_4x8_iadd => {
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpIDp4 {
dst: dst.into(),
src_types: [IntType::I8, IntType::I8],
srcs: [srcs[0], srcs[1], srcs[2]],
});
dst
}
nir_op_sudot_4x8_iadd => {
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpIDp4 {
dst: dst.into(),
src_types: [IntType::I8, IntType::U8],
srcs: [srcs[0], srcs[1], srcs[2]],
});
dst
}
nir_op_udot_4x8_uadd => {
let dst = b.alloc_ssa(RegFile::GPR, 1);
b.push_op(OpIDp4 {
dst: dst.into(),
src_types: [IntType::U8, IntType::U8],
srcs: [srcs[0], srcs[1], srcs[2]],
});
dst
}
nir_op_u2f16 | nir_op_u2f32 | nir_op_u2f64 => {
let src_bits = alu.get_src(0).src.bit_size();
let dst_bits = alu.def.bit_size();

View file

@ -2450,6 +2450,32 @@ impl DisplayOp for OpIAdd3X {
}
impl_display_for_op!(OpIAdd3X);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIDp4 {
pub dst: Dst,
pub src_types: [IntType; 2],
#[src_type(I32)]
pub srcs: [Src; 3],
}
impl DisplayOp for OpIDp4 {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"idp4{}{} {} {} {}",
self.src_types[0],
self.src_types[1],
self.srcs[0],
self.srcs[1],
self.srcs[2],
)
}
}
impl_display_for_op!(OpIDp4);
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpIMad {
@ -4299,6 +4325,7 @@ pub enum Op {
INeg(OpINeg),
IAdd3(OpIAdd3),
IAdd3X(OpIAdd3X),
IDp4(OpIDp4),
IMad(OpIMad),
IMad64(OpIMad64),
IMnMx(OpIMnMx),
@ -4731,6 +4758,7 @@ impl Instr {
| Op::INeg(_)
| Op::IAdd3(_)
| Op::IAdd3X(_)
| Op::IDp4(_)
| Op::IMad(_)
| Op::IMad64(_)
| Op::IMnMx(_)

View file

@ -63,9 +63,12 @@ fn copy_src_if_not_reg(b: &mut impl SSABuilder, src: &mut Src, file: RegFile) {
}
}
fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) {
fn swap_srcs_if_not_reg(x: &mut Src, y: &mut Src) -> bool {
if !src_is_reg(x) && src_is_reg(y) {
std::mem::swap(x, y);
true
} else {
false
}
}
@ -153,6 +156,15 @@ fn legalize_instr(
copy_src_if_not_reg(b, src0, RegFile::GPR);
copy_src_if_not_reg(b, src2, RegFile::GPR);
}
Op::IDp4(op) => {
let [ref mut src_type0, ref mut src_type1] = op.src_types;
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
if swap_srcs_if_not_reg(src0, src1) {
std::mem::swap(src_type0, src_type1);
}
copy_src_if_not_reg(b, src0, RegFile::GPR);
copy_src_if_not_reg(b, src2, RegFile::GPR);
}
Op::IMad(op) => {
let [ref mut src0, ref mut src1, ref mut src2] = op.srcs;
swap_srcs_if_not_reg(src0, src1);