nak: Add base support for 8 and 16-bit types

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26348>
Faith Ekstrand 2023-11-22 14:32:21 -06:00 committed by Faith Ekstrand
parent 082c986614
commit 9e84e9e44b
2 changed files with 238 additions and 61 deletions


@@ -233,6 +233,44 @@ pub trait SSABuilder: Builder {
         dst
     }
 
+    fn prmt4(&mut self, src: [Src; 4], sel: [u8; 4]) -> SSARef {
+        let max_sel = *sel.iter().max().unwrap();
+        if max_sel < 8 {
+            self.prmt(src[0], src[1], sel)
+        } else if max_sel < 12 {
+            let mut sel_a = [0_u8; 4];
+            let mut sel_b = [0_u8; 4];
+            for i in 0..4_u8 {
+                if sel[usize::from(i)] < 8 {
+                    sel_a[usize::from(i)] = sel[usize::from(i)];
+                    sel_b[usize::from(i)] = i;
+                } else {
+                    sel_b[usize::from(i)] = (sel[usize::from(i)] - 8) + 4;
+                }
+            }
+            let a = self.prmt(src[0], src[1], sel_a);
+            self.prmt(a.into(), src[2], sel_b)
+        } else if max_sel < 16 {
+            let mut sel_a = [0_u8; 4];
+            let mut sel_b = [0_u8; 4];
+            let mut sel_c = [0_u8; 4];
+            for i in 0..4_u8 {
+                if sel[usize::from(i)] < 8 {
+                    sel_a[usize::from(i)] = sel[usize::from(i)];
+                    sel_c[usize::from(i)] = i;
+                } else {
+                    sel_b[usize::from(i)] = sel[usize::from(i)] - 8;
+                    sel_c[usize::from(i)] = 4 + i;
+                }
+            }
+            let a = self.prmt(src[0], src[1], sel_a);
+            let b = self.prmt(src[2], src[3], sel_b);
+            self.prmt(a.into(), b.into(), sel_c)
+        } else {
+            panic!("Invalid permute value: {max_sel}");
+        }
+    }
+
     fn sel(&mut self, cond: Src, x: Src, y: Src) -> SSARef {
         assert!(cond.src_ref.is_predicate());
         assert!(x.is_predicate() == y.is_predicate());
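PRMT background (not part of the change): the hardware PRMT instruction reads 8 bytes from two 32-bit sources, and each 4-bit lane of the selector picks one of those bytes. Selector values 8-15 therefore cannot be resolved by a single PRMT, which is why prmt4 above chains a second (and possibly third) PRMT to reach 16 bytes across four sources. A minimal Rust model of that byte selection, ignoring PRMT's MSB-replication flag (prmt_model is a hypothetical helper, not a NAK API):

fn prmt_model(a: u32, b: u32, sel: [u8; 4]) -> u32 {
    // Concatenate the two sources into 8 addressable bytes, little-endian.
    let bytes = [a.to_le_bytes(), b.to_le_bytes()].concat();
    // Each selector lane picks one byte; only the low 3 bits index here.
    u32::from_le_bytes(std::array::from_fn(|i| {
        bytes[usize::from(sel[i] & 0x7)]
    }))
}

#[test]
fn prmt_model_picks_bytes() {
    // Lane selectors 0-3 read src a, 4-7 read src b.
    assert_eq!(
        prmt_model(0x4433_2211, 0x8877_6655, [0, 5, 2, 7]),
        0x8833_6611
    );
}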


@@ -243,12 +243,21 @@ impl<'a> ShaderFromNir<'a> {
             .or_insert(vec);
     }
 
-    fn get_ssa_comp(&mut self, def: &nir_def, c: u8) -> SSARef {
+    fn get_ssa_comp(&mut self, def: &nir_def, c: u8) -> (SSARef, u8) {
         let vec = self.get_ssa(def);
         match def.bit_size {
-            1 | 32 => vec[usize::from(c)].into(),
-            64 => [vec[usize::from(c) * 2], vec[usize::from(c) * 2 + 1]].into(),
-            _ => panic!("Unsupported bit size"),
+            1 => (vec[usize::from(c)].into(), 0),
+            8 => (vec[usize::from(c / 4)].into(), c % 4),
+            16 => (vec[usize::from(c / 2)].into(), (c * 2) % 4),
+            32 => (vec[usize::from(c)].into(), 0),
+            64 => {
+                let comps = [
+                    vec[usize::from(c) * 2 + 0],
+                    vec[usize::from(c) * 2 + 1],
+                ];
+                (comps.into(), 0)
+            }
+            _ => panic!("Unsupported bit size: {}", def.bit_size),
         }
     }
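For clarity (not part of the change): get_ssa_comp now maps a NIR component index to a 32-bit SSA value plus a byte offset within it, since 8 and 16-bit components are packed little-endian, four or two per dword. The mapping as a standalone sketch (comp_location is a hypothetical name):

fn comp_location(bit_size: u8, c: u8) -> (u8, u8) {
    // (dword index, starting byte within the dword)
    match bit_size {
        8 => (c / 4, c % 4),        // e.g. component 5 -> dword 1, byte 1
        16 => (c / 2, (c * 2) % 4), // e.g. component 3 -> dword 1, byte 2
        _ => (c, 0),                // 1 and 32-bit: one component per dword
    }
}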
@@ -271,7 +280,7 @@ impl<'a> ShaderFromNir<'a> {
         if let Some(base_def) = std::ptr::NonNull::new(addr_offset.base.def) {
             let base_def = unsafe { base_def.as_ref() };
             let base_comp = u8::try_from(addr_offset.base.comp).unwrap();
-            let base = self.get_ssa_comp(base_def, base_comp);
+            let (base, _) = self.get_ssa_comp(base_def, base_comp);
             (base.into(), addr_offset.offset)
         } else {
             (SrcRef::Zero.into(), addr_offset.offset)
@@ -296,53 +305,156 @@ impl<'a> ShaderFromNir<'a> {
     }
 
     fn parse_alu(&mut self, b: &mut impl SSABuilder, alu: &nir_alu_instr) {
-        let mut srcs = Vec::new();
-        for (i, alu_src) in alu.srcs_as_slice().iter().enumerate() {
-            let bit_size = alu_src.src.bit_size();
-            let comps = alu.src_components(i.try_into().unwrap());
-            let alu_src_ssa = self.get_ssa(&alu_src.src.as_def());
-            let mut src_comps = Vec::new();
-            for c in 0..comps {
-                let s = usize::from(alu_src.swizzle[usize::from(c)]);
-                if bit_size == 1 || bit_size == 32 {
-                    src_comps.push(alu_src_ssa[s]);
-                } else if bit_size == 64 {
-                    src_comps.push(alu_src_ssa[s * 2]);
-                    src_comps.push(alu_src_ssa[s * 2 + 1]);
-                } else {
-                    panic!("Unhandled bit size");
-                }
-            }
-            srcs.push(Src::from(SSARef::try_from(src_comps).unwrap()));
-        }
-
-        /* Handle vectors as a special case since they're the only ALU ops
-         * that can produce more than 16B of data.
-         */
+        // Handle vectors and pack ops as a special case since they're the
+        // only ALU ops that can produce more than 16B.  They are also the
+        // only ALU ops which we allow to consume small (8 and 16-bit)
+        // vector data scattered across multiple dwords.
         match alu.op {
-            nir_op_mov | nir_op_vec2 | nir_op_vec3 | nir_op_vec4
+            nir_op_mov
+            | nir_op_pack_32_4x8_split
+            | nir_op_pack_32_2x16_split
+            | nir_op_pack_64_2x32_split
+            | nir_op_vec2 | nir_op_vec3 | nir_op_vec4
             | nir_op_vec5 | nir_op_vec8 | nir_op_vec16 => {
-                let file = if alu.def.bit_size == 1 {
-                    RegFile::Pred
-                } else {
-                    RegFile::GPR
-                };
-                let mut dst_vec = Vec::new();
-                for src in srcs {
-                    for v in src.as_ssa().unwrap().iter() {
-                        let dst = b.alloc_ssa(file, 1)[0];
-                        b.copy_to(dst.into(), (*v).into());
-                        dst_vec.push(dst);
-                    }
-                }
+                let src_bit_size = alu.get_src(0).src.bit_size();
+                let bits = usize::from(alu.def.num_components)
+                    * usize::from(alu.def.bit_size);
+
+                // Collect the sources into a vec with src_bit_size per SSA
+                // value in the vec.  This implicitly makes 64-bit sources
+                // look like two 32-bit values.
+                let mut srcs = Vec::new();
+                if alu.op == nir_op_mov {
+                    let src = alu.get_src(0);
+                    for c in 0..alu.def.num_components {
+                        let s = src.swizzle[usize::from(c)];
+                        let (src, byte) =
+                            self.get_ssa_comp(src.src.as_def(), s);
+                        for ssa in src.iter() {
+                            srcs.push((*ssa, byte));
+                        }
+                    }
+                } else {
+                    for src in alu.srcs_as_slice().iter() {
+                        let s = src.swizzle[0];
+                        let (src, byte) =
+                            self.get_ssa_comp(src.src.as_def(), s);
+                        for ssa in src.iter() {
+                            srcs.push((*ssa, byte));
+                        }
+                    }
+                }
+
+                let mut comps = Vec::new();
+                match src_bit_size {
+                    1 | 32 | 64 => {
+                        for (ssa, _) in srcs {
+                            comps.push(ssa);
+                        }
+                    }
+                    8 => {
+                        for dc in 0..bits.div_ceil(32) {
+                            let mut psrc = [Src::new_zero(); 4];
+                            let mut psel = [0_u8; 4];
+                            for b in 0..4_u8 {
+                                let sc = dc * 4 + usize::from(b);
+                                if sc < srcs.len() {
+                                    let (ssa, byte) = srcs[sc];
+                                    // Find a PRMT source slot that is either
+                                    // unused or already holds this dword,
+                                    // then select byte `byte` of that slot.
+                                    for i in 0..4_u8 {
+                                        let psrc_i = &mut psrc[usize::from(i)];
+                                        if *psrc_i == Src::new_zero() {
+                                            *psrc_i = ssa.into();
+                                        } else if *psrc_i != Src::from(ssa) {
+                                            continue;
+                                        }
+                                        psel[usize::from(b)] = i * 4 + byte;
+                                        break;
+                                    }
+                                }
+                            }
+                            comps.push(b.prmt4(psrc, psel)[0]);
+                        }
+                    }
+                    16 => {
+                        for dc in 0..bits.div_ceil(32) {
+                            let mut psrc = [Src::new_zero(); 2];
+                            let mut psel = [0_u8; 4];
+                            for w in 0..2_u8 {
+                                let sc = dc * 2 + usize::from(w);
+                                if sc < srcs.len() {
+                                    let (ssa, byte) = srcs[sc];
+                                    let w_usize = usize::from(w);
+                                    psrc[w_usize] = ssa.into();
+                                    psel[w_usize * 2 + 0] = (w * 4) + byte;
+                                    psel[w_usize * 2 + 1] = (w * 4) + byte + 1;
+                                }
+                            }
+                            comps.push(b.prmt(psrc[0], psrc[1], psel)[0]);
+                        }
+                    }
+                    _ => panic!("Unknown bit size: {src_bit_size}"),
+                }
+
-                self.set_ssa(&alu.def, dst_vec);
+                self.set_ssa(&alu.def, comps);
                 return;
             }
             _ => (),
         }
 
+        let mut srcs: Vec<Src> = Vec::new();
+        for (i, alu_src) in alu.srcs_as_slice().iter().enumerate() {
+            let bit_size = alu_src.src.bit_size();
+            let comps = alu.src_components(i.try_into().unwrap());
+            let ssa = self.get_ssa(&alu_src.src.as_def());
+
+            match bit_size {
+                1 => {
+                    assert!(comps == 1);
+                    let s = usize::from(alu_src.swizzle[0]);
+                    srcs.push(ssa[s].into());
+                }
+                8 => {
+                    assert!(comps <= 4);
+                    let s = alu_src.swizzle[0];
+                    let dw = ssa[usize::from(s / 4)];
+
+                    let mut prmt = [4_u8; 4];
+                    for c in 0..comps {
+                        let cs = alu_src.swizzle[usize::from(c)];
+                        assert!(s / 4 == cs / 4);
+                        prmt[usize::from(c)] = cs % 4;
+                    }
+                    srcs.push(b.prmt(dw.into(), 0.into(), prmt).into());
+                }
+                16 => {
+                    assert!(comps <= 2);
+                    let s = alu_src.swizzle[0];
+                    let dw = ssa[usize::from(s / 2)];
+
+                    let mut prmt = [0_u8; 4];
+                    for c in 0..comps {
+                        let cs = alu_src.swizzle[usize::from(c)];
+                        assert!(s / 2 == cs / 2);
+                        prmt[usize::from(c) * 2 + 0] = (cs % 2) * 2 + 0;
+                        prmt[usize::from(c) * 2 + 1] = (cs % 2) * 2 + 1;
+                    }
+                    // TODO: Some ops can handle swizzles
+                    srcs.push(b.prmt(dw.into(), 0.into(), prmt).into());
+                }
+                32 => {
+                    assert!(comps == 1);
+                    let s = usize::from(alu_src.swizzle[0]);
+                    srcs.push(ssa[s].into());
+                }
+                64 => {
+                    assert!(comps == 1);
+                    let s = usize::from(alu_src.swizzle[0]);
+                    srcs.push([ssa[s * 2], ssa[s * 2 + 1]].into());
+                }
+                _ => panic!("Invalid bit size: {bit_size}"),
+            }
+        }
+
         let dst: SSARef = match alu.op {
             nir_op_b2b1 => {
                 assert!(alu.get_src(0).bit_size() == 32);
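A worked example of the 8-bit gather path above (not from the change), reusing the prmt_model sketch from earlier: a vec4 of u8 whose components sit at (dword0, byte 2), (dword0, byte 3), (dword1, byte 0), (dword1, byte 1) deduplicates to psrc = [dword0, dword1, zero, zero] and psel = [2, 3, 4, 5], which prmt4 then emits as a single PRMT since every selector is below 8:

#[test]
fn gather_scattered_bytes() {
    let dw0 = 0xddcc_bbaa_u32; // bytes [aa, bb, cc, dd]
    let dw1 = 0x4433_2211_u32; // bytes [11, 22, 33, 44]
    // psel built by the 8-bit loop: i * 4 + byte for each output lane.
    assert_eq!(prmt_model(dw0, dw1, [2, 3, 4, 5]), 0x2211_ddcc);
}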
@@ -813,12 +925,6 @@ impl<'a> ShaderFromNir<'a> {
             nir_op_ixor => {
                 b.lop2(LogicOp::new_lut(&|x, y, _| x ^ y), srcs[0], srcs[1])
             }
-            nir_op_pack_64_2x32_split => {
-                let dst = b.alloc_ssa(RegFile::GPR, 2);
-                b.copy_to(dst[0].into(), srcs[0]);
-                b.copy_to(dst[1].into(), srcs[1]);
-                dst
-            }
             nir_op_pack_half_2x16_split => {
                 assert!(alu.get_src(0).bit_size() == 32);
                 let low = b.alloc_ssa(RegFile::GPR, 1);
@@ -867,6 +973,12 @@ impl<'a> ShaderFromNir<'a> {
             nir_op_ult => {
                 b.isetp(IntCmpType::U32, IntCmpOp::Lt, srcs[0], srcs[1])
             }
+            nir_op_unpack_32_2x16_split_x => {
+                b.prmt(srcs[0], 0.into(), [0, 1, 4, 4])
+            }
+            nir_op_unpack_32_2x16_split_y => {
+                b.prmt(srcs[0], 0.into(), [2, 3, 4, 4])
+            }
             nir_op_unpack_64_2x32_split_x => {
                 let src0_x = srcs[0].as_ssa().unwrap()[0];
                 b.copy(src0_x.into())
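The selectors in the two unpack ops read as: take bytes 0-1 (or 2-3) of the source and fill the high half from source b, which is zero, so the 16-bit half comes out zero-extended to a full dword. Checked against the prmt_model sketch (not a NAK API):

#[test]
fn unpack_16_halves() {
    let v = 0xbeef_f00d_u32;
    assert_eq!(prmt_model(v, 0, [0, 1, 4, 4]), 0x0000_f00d); // split_x
    assert_eq!(prmt_model(v, 0, [2, 3, 4, 4]), 0x0000_beef); // split_y
}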
@@ -2136,22 +2248,49 @@ impl<'a> ShaderFromNir<'a> {
         b: &mut impl SSABuilder,
         load_const: &nir_load_const_instr,
     ) {
-        let mut dst_vec = Vec::new();
-        for c in 0..load_const.def.num_components {
-            if load_const.def.bit_size == 1 {
-                let imm_b1 = unsafe { load_const.values()[c as usize].b };
-                dst_vec.push(b.copy(imm_b1.into())[0]);
-            } else if load_const.def.bit_size == 32 {
-                let imm_u32 = unsafe { load_const.values()[c as usize].u32_ };
-                dst_vec.push(b.copy(imm_u32.into())[0]);
-            } else if load_const.def.bit_size == 64 {
-                let imm_u64 = unsafe { load_const.values()[c as usize].u64_ };
-                dst_vec.push(b.copy((imm_u64 as u32).into())[0]);
-                dst_vec.push(b.copy(((imm_u64 >> 32) as u32).into())[0]);
-            }
-        }
+        let values = &load_const.values();
+
+        let mut dst = Vec::new();
+        match load_const.def.bit_size {
+            1 => for c in 0..load_const.def.num_components {
+                let imm_b1 = unsafe { values[usize::from(c)].b };
+                dst.push(b.copy(imm_b1.into())[0]);
+            }
+            8 => for dw in 0..load_const.def.num_components.div_ceil(4) {
+                let mut imm_u32 = 0;
+                for b in 0..4 {
+                    let c = dw * 4 + b;
+                    if c < load_const.def.num_components {
+                        let imm_u8 = unsafe { values[usize::from(c)].u8_ };
+                        imm_u32 |= u32::from(imm_u8) << (b * 8);
+                    }
+                }
+                dst.push(b.copy(imm_u32.into())[0]);
+            }
+            16 => for dw in 0..load_const.def.num_components.div_ceil(2) {
+                let mut imm_u32 = 0;
+                for w in 0..2 {
+                    let c = dw * 2 + w;
+                    if c < load_const.def.num_components {
+                        let imm_u16 = unsafe { values[usize::from(c)].u16_ };
+                        imm_u32 |= u32::from(imm_u16) << (w * 16);
+                    }
+                }
+                dst.push(b.copy(imm_u32.into())[0]);
+            }
+            32 => for c in 0..load_const.def.num_components {
+                let imm_u32 = unsafe { values[usize::from(c)].u32_ };
+                dst.push(b.copy(imm_u32.into())[0]);
+            }
+            64 => for c in 0..load_const.def.num_components {
+                let imm_u64 = unsafe { values[usize::from(c)].u64_ };
+                dst.push(b.copy((imm_u64 as u32).into())[0]);
+                dst.push(b.copy(((imm_u64 >> 32) as u32).into())[0]);
+            }
+            _ => panic!("Unknown bit size: {}", load_const.def.bit_size),
+        }
-        self.set_ssa(&load_const.def, dst_vec);
+
+        self.set_ssa(&load_const.def, dst);
     }
 
     fn parse_undef(
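For reference (not from the change), the 8-bit constant packing above reduces to this shift-and-or, shown for an imm vec3 of u8; the unused high byte of the dword stays zero:

#[test]
fn pack_imm_vec3_u8() {
    let values = [0x11_u8, 0x22, 0x33];
    let mut imm_u32 = 0_u32;
    for (b, v) in values.iter().enumerate() {
        imm_u32 |= u32::from(*v) << (b * 8);
    }
    assert_eq!(imm_u32, 0x0033_2211);
}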