diff --git a/src/nouveau/compiler/nak/ir.rs b/src/nouveau/compiler/nak/ir.rs index 78febca91c6..f8756d4a80b 100644 --- a/src/nouveau/compiler/nak/ir.rs +++ b/src/nouveau/compiler/nak/ir.rs @@ -6468,6 +6468,17 @@ pub enum OffsetStride { X16 = 4, } +impl OffsetStride { + pub fn shift(&self) -> u32 { + match self { + Self::X1 => 0, + Self::X4 => 2, + Self::X8 => 3, + Self::X16 => 4, + } + } +} + impl fmt::Display for OffsetStride { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let s = match self { diff --git a/src/nouveau/compiler/nak/sm70_encode.rs b/src/nouveau/compiler/nak/sm70_encode.rs index 07736e039b6..1fd267c256e 100644 --- a/src/nouveau/compiler/nak/sm70_encode.rs +++ b/src/nouveau/compiler/nak/sm70_encode.rs @@ -10,6 +10,7 @@ use crate::sm70::ShaderModel70; use bitview::*; use rustc_hash::FxHashMap; +use std::mem; use std::ops::Range; /// A per-op trait that implements Volta+ opcode semantics @@ -774,6 +775,60 @@ fn op_gpr(op: &impl DstsAsSlice) -> RegFile { } } +fn legalize_load_store_address( + b: &mut LegalizeBuilder, + addr: &mut Src, + uniform_addr: &mut Src, + stride: Option<&mut OffsetStride>, +) { + let stride_x1_or_none = matches!(stride, Some(OffsetStride::X1) | None); + if addr.is_ugpr_reg() { + if stride_x1_or_none && uniform_addr.is_zero() { + *uniform_addr = mem::replace(addr, Src::ZERO); + } else { + b.copy_src_if_uniform(addr); + } + } + + if uniform_addr.is_gpr_reg() { + if addr.is_zero() { + assert!(stride_x1_or_none); + *addr = mem::replace(uniform_addr, Src::ZERO); + } else { + let uniform_ssa = uniform_addr.as_ssa().unwrap(); + let mut ssa = addr.as_ssa().unwrap(); + + let addr_comps = ssa.comps(); + if let Some(stride) = stride { + if *stride != OffsetStride::X1 { + assert_eq!(addr_comps, 1); + let shift = stride.shift(); + let shift = b.copy(shift.into()); + *addr = b.shl(addr.clone(), shift.into()).into(); + ssa = addr.as_ssa().unwrap(); + *stride = OffsetStride::X1; + } + } + + if uniform_ssa.comps() == 2 { + // In case 
the non-uniform address is 32 bits and the uniform one is 64, + // we need to convert it to 64 bits. + if uniform_ssa.comps() != addr_comps { + let zero = b.copy(0.into()); + *addr = [ssa[0], zero].into(); + } + *addr = b + .iadd64(addr.clone(), uniform_addr.clone(), Src::ZERO) + .into() + } else { + *addr = + b.iadd(addr.clone(), uniform_addr.clone(), Src::ZERO).into() + } + *uniform_addr = 0.into(); + } +} + // // Implementations of SM70Op for each op we support on Volta+ // @@ -3165,7 +3220,12 @@ impl SM70Op for OpSuAtom { impl SM70Op for OpLd { fn legalize(&mut self, b: &mut LegalizeBuilder) { - b.copy_src_if_uniform(&mut self.addr); + legalize_load_store_address( + b, + &mut self.addr, + &mut self.uniform_addr, + Some(&mut self.stride), + ); b.copy_src_if_uniform(&mut self.pred); } @@ -3327,8 +3387,13 @@ impl SM70Op for OpLdc { impl SM70Op for OpSt { fn legalize(&mut self, b: &mut LegalizeBuilder) { - b.copy_src_if_uniform(&mut self.addr); b.copy_src_if_uniform(&mut self.data); + legalize_load_store_address( + b, + &mut self.addr, + &mut self.uniform_addr, + Some(&mut self.stride), + ); } fn encode(&self, e: &mut SM70Encoder<'_>) { @@ -3461,9 +3526,19 @@ impl SM70Encoder<'_> { impl SM70Op for OpAtom { fn legalize(&mut self, b: &mut LegalizeBuilder) { - b.copy_src_if_uniform(&mut self.addr); - b.copy_src_if_uniform(&mut self.cmpr); b.copy_src_if_uniform(&mut self.data); + + if matches!(self.atom_op, AtomOp::CmpExch(_)) { + b.copy_src_if_uniform(&mut self.addr); + b.copy_src_if_uniform(&mut self.cmpr); + } else { + legalize_load_store_address( + b, + &mut self.addr, + &mut self.uniform_address, + Some(&mut self.addr_stride), + ); + } } fn encode(&self, e: &mut SM70Encoder<'_>) { @@ -4291,7 +4366,12 @@ impl SM70Op for OpHmma { impl SM70Op for OpLdsm { fn legalize(&mut self, b: &mut LegalizeBuilder) { - b.copy_src_if_uniform(&mut self.addr); + legalize_load_store_address( + b, + &mut self.addr, + &mut self.uniform_addr, + None, + ); } fn encode(&self, e: &mut 
SM70Encoder<'_>) {