Merge branch 'nak/opt/load_store_offsets' into 'main'

nak: support UGPR encoding of load/store/atom instructions

See merge request mesa/mesa!39384
Karol Herbst 2026-05-08 00:20:26 +00:00
commit f1e276c5ca
16 changed files with 569 additions and 102 deletions


@ -5792,11 +5792,13 @@ nir_lower_shader_calls(nir_shader *shader,
void *mem_ctx);
int nir_get_io_offset_src_number(const nir_intrinsic_instr *instr);
int nir_get_io_uniform_offset_src_number(const nir_intrinsic_instr *instr);
int nir_get_io_index_src_number(const nir_intrinsic_instr *instr);
int nir_get_io_data_src_number(const nir_intrinsic_instr *instr);
int nir_get_io_arrayed_index_src_number(const nir_intrinsic_instr *instr);
nir_src *nir_get_io_offset_src(nir_intrinsic_instr *instr);
nir_src *nir_get_io_uniform_offset_src(nir_intrinsic_instr *instr);
nir_src *nir_get_io_index_src(nir_intrinsic_instr *instr);
nir_src *nir_get_io_data_src(nir_intrinsic_instr *instr);
nir_src *nir_get_io_arrayed_index_src(nir_intrinsic_instr *instr);
@ -5806,7 +5808,6 @@ static inline unsigned
nir_get_io_base_size_nv(const nir_intrinsic_instr *intr)
{
switch (intr->intrinsic) {
case nir_intrinsic_global_atomic_nv:
case nir_intrinsic_global_atomic_swap_nv:
case nir_intrinsic_shared_atomic_nv:
case nir_intrinsic_shared_atomic_swap_nv:
@ -5819,6 +5820,9 @@ nir_get_io_base_size_nv(const nir_intrinsic_instr *intr)
case nir_intrinsic_store_shared_nv:
case nir_intrinsic_store_shared_unlock_nv:
return 24;
case nir_intrinsic_global_atomic_nv:
/* TODO: SM100+ only has 23 bits for the UGPR + GPR form */
return 23;
case nir_intrinsic_ldc_nv:
case nir_intrinsic_ldcx_nv:
return 16;
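
As an aside, a minimal Rust sketch (not part of this MR; the helper name is made up) of the immediate-offset range these base sizes imply: the base field is unsigned when every address source is a constant zero and signed (one bit narrower) otherwise, which is the same split nir_validate and the offset-folding pass further down work with.

    fn nv_base_range(bits: u32, all_addr_srcs_zero: bool) -> (i64, i64) {
        if all_addr_srcs_zero {
            (0, (1i64 << bits) - 1)             // 24 bits -> 0 ..= 0xff_ffff
        } else {
            let max = (1i64 << (bits - 1)) - 1; // 24 bits -> up to 0x7f_ffff
            (-max - 1, max)                     // and down to -0x80_0000
        }
    }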


@ -941,7 +941,8 @@ intrinsic("load_vulkan_descriptor", src_comp=[-1], dest_comp=0,
# The offset is sign-extended or zero-extended based on the SIGN_EXTEND index.
#
# NV variants all come with a 24 bit base, that is unsigned with a constant 0 address,
# signed otherwise.
# signed otherwise. Non-swap atomics also come with an additional uniform address source
# right after the non-uniform memory address.
#
# PCO global variants use a vec3 for the memory address and data, where component X
# has the low 32 address bits, component Y has the high 32 address bits, and component Z
@ -950,13 +951,13 @@ intrinsic("load_vulkan_descriptor", src_comp=[-1], dest_comp=0,
intrinsic("deref_atomic", src_comp=[-1, 1], dest_comp=1, indices=[ACCESS, ATOMIC_OP])
intrinsic("ssbo_atomic", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS, ATOMIC_OP, OFFSET_SHIFT])
intrinsic("shared_atomic", src_comp=[1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP])
intrinsic("shared_atomic_nv", src_comp=[1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP, OFFSET_SHIFT_NV])
intrinsic("shared_atomic_nv", src_comp=[1, 1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP, OFFSET_SHIFT_NV])
intrinsic("task_payload_atomic", src_comp=[1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP])
intrinsic("global_atomic", src_comp=[1, 1], dest_comp=1, indices=[ATOMIC_OP])
intrinsic("global_atomic_2x32", src_comp=[2, 1], dest_comp=1, indices=[ATOMIC_OP])
intrinsic("global_atomic_amd", src_comp=[1, 1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP])
intrinsic("global_atomic_agx", src_comp=[1, 1, 1], dest_comp=1, indices=[ATOMIC_OP, SIGN_EXTEND])
intrinsic("global_atomic_nv", src_comp=[1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP])
intrinsic("global_atomic_nv", src_comp=[1, 1, 1], dest_comp=1, indices=[BASE, ATOMIC_OP])
intrinsic("global_atomic_pco", src_comp=[3], dest_comp=1, indices=[ATOMIC_OP], bit_sizes=[32])
intrinsic("deref_atomic_swap", src_comp=[-1, 1, 1], dest_comp=1, indices=[ACCESS, ATOMIC_OP])
@ -1946,15 +1947,15 @@ load("global_amd", [1, 1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flag
# src[] = { value, address, unsigned 32-bit offset }.
store("global_amd", [1, 1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET, WRITE_MASK])
# src[] = { address }. BASE is a 24 bit unsigned offset if a constant 0 address is given,
# signed otherwise.
# src[] = { address, uniform_address }. BASE is a 24-bit unsigned offset if a constant 0 address and
# a constant 0 uniform_address are given, signed otherwise.
# load_global_nv has an additional boolean input that makes the load return 0 on false.
load("global_nv", [1, 1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE])
store("global_nv", [1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
load("scratch_nv", [1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE])
store("scratch_nv", [1], indices=[BASE, ALIGN_MUL, ALIGN_OFFSET])
load("shared_nv", [1], indices=[BASE, OFFSET_SHIFT_NV, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE])
store("shared_nv", [1], indices=[BASE, OFFSET_SHIFT_NV, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
load("global_nv", [1, 1, 1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE])
store("global_nv", [1, 1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
load("scratch_nv", [1, 1], indices=[BASE, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE])
store("scratch_nv", [1, 1], indices=[BASE, ALIGN_MUL, ALIGN_OFFSET])
load("shared_nv", [1, 1], indices=[BASE, OFFSET_SHIFT_NV, ACCESS, ALIGN_MUL, ALIGN_OFFSET], flags=[CAN_ELIMINATE])
store("shared_nv", [1, 1], indices=[BASE, OFFSET_SHIFT_NV, ACCESS, ALIGN_MUL, ALIGN_OFFSET])
# Same as shared_atomic_add, but with GDS. src[] = {store_val, gds_addr, m0}
intrinsic("gds_atomic_add_amd", src_comp=[1, 1, 1], dest_comp=1, indices=[BASE])
@ -2965,7 +2966,8 @@ intrinsic("ssa_bar_nv", src_comp=[1])
intrinsic("cmat_muladd_nv", src_comp=[-1, -1, -1], dest_comp=0, bit_sizes=src2,
indices=[FLAGS], flags=[CAN_ELIMINATE])
intrinsic("cmat_load_shared_nv", src_comp=[1], dest_comp=0, indices=[NUM_MATRICES, MATRIX_LAYOUT, BASE], flags=[CAN_ELIMINATE])
# src[] = { address, uniform_address }
intrinsic("cmat_load_shared_nv", src_comp=[1, 1], dest_comp=0, indices=[NUM_MATRICES, MATRIX_LAYOUT, BASE], flags=[CAN_ELIMINATE])
# Moves an 8x8 16-bit matrix with transposition within a subgroup
intrinsic("cmat_mov_transpose_nv", src_comp=[2], dest_comp=2, bit_sizes=[16], flags=[CAN_ELIMINATE, CAN_REORDER, SUBGROUP])


@ -1019,6 +1019,7 @@ nir_get_io_offset_src_number(const nir_intrinsic_instr *instr)
case nir_intrinsic_load_push_data_intel:
case nir_intrinsic_vild_nv:
case nir_intrinsic_load_shader_indirect_data_intel:
case nir_intrinsic_cmat_load_shared_nv:
return 0;
case nir_intrinsic_load_ubo:
case nir_intrinsic_load_ubo_vec4:
@ -1108,6 +1109,39 @@ nir_get_io_offset_src(nir_intrinsic_instr *instr)
case nir_intrinsic_bindless_image_##name: \
case nir_intrinsic_image_heap_##name
/**
* Return the uniform offset source number for a load/store intrinsic or -1 if there's no uniform offset.
*/
int
nir_get_io_uniform_offset_src_number(const nir_intrinsic_instr *instr)
{
switch (instr->intrinsic) {
case nir_intrinsic_cmat_load_shared_nv:
case nir_intrinsic_global_atomic_nv:
case nir_intrinsic_load_global_nv:
case nir_intrinsic_load_scratch_nv:
case nir_intrinsic_load_shared_nv:
case nir_intrinsic_shared_atomic_nv:
return 1;
case nir_intrinsic_store_global_nv:
case nir_intrinsic_store_scratch_nv:
case nir_intrinsic_store_shared_nv:
return 2;
default:
return -1;
}
}
/**
* Return the uniform offset source for a load/store intrinsic.
*/
nir_src *
nir_get_io_uniform_offset_src(nir_intrinsic_instr *instr)
{
const int idx = nir_get_io_uniform_offset_src_number(instr);
return idx >= 0 ? &instr->src[idx] : NULL;
}
/**
* Return the index or handle source number for a load/store intrinsic or -1
* if there's no index or handle.


@ -193,11 +193,12 @@ try_fold_load_store_nv(nir_builder *b,
assert(offset_idx >= 0);
nir_src src = intrin->src[offset_idx];
nir_src *uniform_src = nir_get_io_uniform_offset_src(intrin);
int32_t min = 0;
uint32_t max = BITFIELD_MASK(offset_bits);
if (!nir_src_is_const(src)) {
if (!nir_src_is_const(src) || (uniform_src && !nir_src_is_const(*uniform_src))) {
max >>= 1;
min = ~max;
}
@ -211,6 +212,11 @@ try_fold_load_store_nv(nir_builder *b,
return false;
}
/* We intentionally don't try to fold the offset for the uniform source,
* because we rely on running nir_opt_offsets before moving in the uniform
* source. However, we might run this pass again _after_ that, because we
* can eliminate a u2u64 on the _non-uniform_ source and might therefore be
* able to fold more constants into the base. */
return try_fold_load_store(b, intrin, state, offset_idx, min, max, false);
}
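
For the common 24-bit base this means a non-constant offset (or, with this change, a non-constant uniform offset) limits constant folding to the signed half of the range; in terms of the sketch above:

    assert_eq!(nv_base_range(24, false), (-0x80_0000, 0x7f_ffff));
    assert_eq!(nv_base_range(24, true),  (0, 0xff_ffff));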


@ -761,9 +761,11 @@ validate_intrinsic_instr(nir_intrinsic_instr *instr, validate_state *state)
case nir_intrinsic_vild_nv: {
int base = nir_intrinsic_base(instr);
nir_src src = *nir_get_io_offset_src(instr);
nir_src *uniform_src = nir_get_io_uniform_offset_src(instr);
unsigned const_bits = nir_get_io_base_size_nv(instr);
if (nir_src_is_const(src) && nir_src_as_int(src) == 0) {
if (nir_src_is_const(src) && nir_src_as_int(src) == 0 &&
(!uniform_src || (nir_src_is_const(*uniform_src) && nir_src_as_int(*uniform_src) == 0))) {
validate_assert(state, base >= 0 && base < BITFIELD_MASK(const_bits));
} else {
int32_t max = BITFIELD_MASK(const_bits - 1);
@ -771,8 +773,14 @@ validate_intrinsic_instr(nir_intrinsic_instr *instr, validate_state *state)
validate_assert(state, base >= min && base < max);
}
if (uniform_src) {
validate_assert(state, uniform_src->ssa->bit_size >= src.ssa->bit_size);
if (state->impl->valid_metadata & nir_metadata_divergence)
validate_assert(state, !uniform_src->ssa->divergent);
}
if (instr->intrinsic == nir_intrinsic_load_global_nv) {
validate_assert(state, instr->src[1].ssa->bit_size == 1);
validate_assert(state, instr->src[2].ssa->bit_size == 1);
}
break;


@ -2975,7 +2975,8 @@ impl<'a> ShaderFromNir<'a> {
nir_intrinsic_global_atomic_nv => {
let bit_size = intrin.def.bit_size();
let addr = self.get_src(&srcs[0]);
let data = self.get_src(&srcs[1]);
let uaddr = self.get_src(&srcs[1]);
let data = self.get_src(&srcs[2]);
let atom_type = self.get_atomic_type(intrin);
let atom_op = self.get_atomic_op(intrin, AtomCmpSrc::Separate);
@ -2992,6 +2993,7 @@ impl<'a> ShaderFromNir<'a> {
dst.clone().into()
},
addr: addr,
uniform_address: uaddr,
cmpr: 0.into(),
data: data,
atom_op: atom_op,
@ -3018,6 +3020,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpAtom {
dst: dst.clone().into(),
addr: addr,
uniform_address: Src::ZERO,
cmpr: cmpr,
data: data,
atom_op: AtomOp::CmpExch(AtomCmpSrc::Separate),
@ -3218,12 +3221,14 @@ impl<'a> ShaderFromNir<'a> {
.get_eviction_priority(intrin.access()),
};
let addr = self.get_src(&srcs[0]);
let pred = self.get_src(&srcs[1]);
let uaddr = self.get_src(&srcs[1]);
let pred = self.get_src(&srcs[2]);
let dst = b.alloc_ssa_vec(RegFile::GPR, size_B.div_ceil(4));
b.push_op(OpLd {
dst: dst.clone().into(),
addr: addr,
uniform_addr: uaddr,
pred: pred,
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3330,11 +3335,13 @@ impl<'a> ShaderFromNir<'a> {
eviction_priority: MemEvictionPriority::Normal,
};
let addr = self.get_src(&srcs[0]);
let uaddr = self.get_src(&srcs[1]);
let dst = b.alloc_ssa_vec(RegFile::GPR, size_B.div_ceil(4));
b.push_op(OpLd {
dst: dst.clone().into(),
addr: addr,
uniform_addr: uaddr,
pred: true.into(),
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3353,11 +3360,14 @@ impl<'a> ShaderFromNir<'a> {
eviction_priority: MemEvictionPriority::Normal,
};
let addr = self.get_src(&srcs[0]);
let uaddr = self.get_src(&srcs[1]);
let dst = b.alloc_ssa_vec(RegFile::GPR, size_B.div_ceil(4));
b.push_op(OpLd {
dst: dst.clone().into(),
addr: addr,
uniform_addr: uaddr,
pred: true.into(),
offset: intrin.base(),
stride: intrin.offset_shift_nv().try_into().unwrap(),
@ -3668,7 +3678,8 @@ impl<'a> ShaderFromNir<'a> {
nir_intrinsic_shared_atomic_nv => {
let bit_size = intrin.def.bit_size();
let addr = self.get_src(&srcs[0]);
let data = self.get_src(&srcs[1]);
let uaddr = self.get_src(&srcs[1]);
let data = self.get_src(&srcs[2]);
let atom_type = self.get_atomic_type(intrin);
let atom_op = self.get_atomic_op(intrin, AtomCmpSrc::Separate);
@ -3678,6 +3689,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpAtom {
dst: dst.clone().into(),
addr: addr,
uniform_address: uaddr,
cmpr: 0.into(),
data: data,
atom_op: atom_op,
@ -3704,6 +3716,7 @@ impl<'a> ShaderFromNir<'a> {
b.push_op(OpAtom {
dst: dst.clone().into(),
addr: addr,
uniform_address: Src::ZERO,
cmpr: cmpr,
data: data,
atom_op: AtomOp::CmpExch(AtomCmpSrc::Separate),
@ -3733,9 +3746,11 @@ impl<'a> ShaderFromNir<'a> {
.get_eviction_priority(intrin.access()),
};
let addr = self.get_src(&srcs[1]);
let uaddr = self.get_src(&srcs[2]);
b.push_op(OpSt {
addr: addr,
uniform_addr: uaddr,
data: data,
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3764,9 +3779,11 @@ impl<'a> ShaderFromNir<'a> {
eviction_priority: MemEvictionPriority::Normal,
};
let addr = self.get_src(&srcs[1]);
let uaddr = self.get_src(&srcs[2]);
b.push_op(OpSt {
addr: addr,
uniform_addr: uaddr,
data: data,
offset: intrin.base(),
stride: OffsetStride::X1,
@ -3785,9 +3802,11 @@ impl<'a> ShaderFromNir<'a> {
eviction_priority: MemEvictionPriority::Normal,
};
let addr = self.get_src(&srcs[1]);
let uaddr = self.get_src(&srcs[2]);
b.push_op(OpSt {
addr: addr,
uniform_addr: uaddr,
data: data,
offset: intrin.base(),
stride: intrin.offset_shift_nv().try_into().unwrap(),
@ -3902,11 +3921,13 @@ impl<'a> ShaderFromNir<'a> {
};
let dst = b.alloc_ssa_vec(RegFile::GPR, comps);
let addr = self.get_src(&srcs[0]);
let uaddr = self.get_src(&srcs[1]);
b.push_op(OpLdsm {
dst: dst.clone().into(),
mat_size,
mat_count,
addr,
uniform_addr: uaddr,
offset: intrin.base(),
});
self.set_dst(&intrin.def, dst);


@ -154,6 +154,7 @@ impl<'a> TestShaderBuilder<'a> {
self.push_op(OpLd {
dst: dst.clone().into(),
addr: self.data_addr.clone().into(),
uniform_addr: Src::ZERO,
pred: true.into(),
offset: offset.into(),
access: access,
@ -178,6 +179,7 @@ impl<'a> TestShaderBuilder<'a> {
assert!(data.comps() == comps);
self.push_op(OpSt {
addr: self.data_addr.clone().into(),
uniform_addr: Src::ZERO,
data: data.into(),
offset: offset.into(),
access: access,
@ -1734,6 +1736,7 @@ fn test_op_ldsm() {
let offset = b.imul(lane_id.into(), 16.into());
b.push_op(OpSt {
addr: offset.into(),
uniform_addr: Src::ZERO,
data: input.into(),
offset: 0,
access: MemAccess {
@ -1755,6 +1758,7 @@ fn test_op_ldsm() {
mat_size: LdsmSize::M8N8,
mat_count: 4,
addr: addr.into(),
uniform_addr: Src::ZERO,
offset: 0,
});
b.st_test_data(16, MemType::B128, res);


@ -6468,6 +6468,17 @@ pub enum OffsetStride {
X16 = 4,
}
impl OffsetStride {
pub fn shift(&self) -> u32 {
match self {
Self::X1 => 0,
Self::X4 => 2,
Self::X8 => 3,
Self::X16 => 4,
}
}
}
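
shift() is simply log2 of the stride factor, which is what lets the legalization later in this MR rewrite a strided address as a plain X1 byte address by shifting. A tiny illustrative helper (not part of the MR):

    fn to_byte_addr(addr: u32, stride: OffsetStride) -> u32 {
        // e.g. OffsetStride::X8.shift() == 3, so this multiplies by 8
        addr << stride.shift()
    }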
impl fmt::Display for OffsetStride {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let s = match self {
@ -6502,6 +6513,9 @@ pub struct OpLd {
#[src_type(GPR)]
pub addr: Src,
#[src_type(GPR)]
pub uniform_addr: Src,
/// On false the load returns 0
#[src_type(Pred)]
pub pred: Src,
@ -6513,7 +6527,11 @@ pub struct OpLd {
impl DisplayOp for OpLd {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "ld{} [{}{}", self.access, self.addr, self.stride)?;
write!(
f,
"ld{} [{}{}+{}",
self.access, self.addr, self.stride, self.uniform_addr
)?;
if self.offset > 0 {
write!(f, "+{:#x}", self.offset)?;
}
@ -6602,6 +6620,9 @@ pub struct OpLdsm {
#[src_type(SSA)]
pub addr: Src,
#[src_type(SSA)]
pub uniform_addr: Src,
pub offset: i32,
}
@ -6658,6 +6679,9 @@ pub struct OpSt {
#[src_type(SSA)]
pub data: Src,
#[src_type(GPR)]
pub uniform_addr: Src,
pub offset: i32,
pub stride: OffsetStride,
pub access: MemAccess,
@ -6665,7 +6689,11 @@ pub struct OpSt {
impl DisplayOp for OpSt {
fn fmt_op(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "st{} [{}{}", self.access, self.addr, self.stride)?;
write!(
f,
"st{} [{}{}+{}",
self.access, self.addr, self.stride, self.uniform_addr
)?;
if self.offset > 0 {
write!(f, "+{:#x}", self.offset)?;
}
@ -6711,6 +6739,9 @@ pub struct OpAtom {
#[src_type(GPR)]
pub addr: Src,
#[src_type(GPR)]
pub uniform_address: Src,
#[src_type(GPR)]
pub cmpr: Src,
@ -6743,10 +6774,16 @@ impl DisplayOp for OpAtom {
if !self.addr.is_zero() {
write!(f, "{}{}", self.addr, self.addr_stride)?;
}
if self.addr_offset > 0 {
if !self.uniform_address.is_zero() {
if !self.addr.is_zero() {
write!(f, "+")?;
}
write!(f, "{}", self.uniform_address)?;
}
if self.addr_offset > 0 {
if !self.addr.is_zero() || !self.uniform_address.is_zero() {
write!(f, "+")?;
}
write!(f, "{:#x}", self.addr_offset)?;
}
write!(f, "]")?;
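
With the extra operand, the printed form of these ops gains a second address term after the GPR and its stride. Roughly what this looks like on an SM75-era part (shapes taken from the encoder tests below; <offset> is just a placeholder, and the exact modifiers depend on the access flags and SM):

    ldg.e.ef.strong.cta r0, [r1.64+ur2+<offset>], p4;
    stg.e.ef.strong.cta [r1.64+ur2+<offset>], r2;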


@ -95,6 +95,7 @@ impl LowerCopySwap {
b.push_op(OpLd {
dst: copy.dst,
addr: Src::ZERO,
uniform_addr: Src::ZERO,
pred: true.into(),
offset: addr.try_into().unwrap(),
stride: OffsetStride::X1,
@ -175,6 +176,7 @@ impl LowerCopySwap {
self.slm_size = max(self.slm_size, addr + 4);
b.push_op(OpSt {
addr: Src::ZERO,
uniform_addr: Src::ZERO,
data: copy.src,
offset: addr.try_into().unwrap(),
stride: OffsetStride::X1,


@ -288,10 +288,11 @@ pub fn test_ldc() {
#[test]
pub fn test_ld_st_atom() {
let r0 = RegRef::new(RegFile::GPR, 0, 1);
let r1 = RegRef::new(RegFile::GPR, 1, 1);
let r1 = RegRef::new(RegFile::GPR, 1, 2);
let r2 = RegRef::new(RegFile::GPR, 2, 1);
let r3 = RegRef::new(RegFile::GPR, 3, 1);
let p4 = RegRef::new(RegFile::Pred, 4, 1);
let ur2 = RegRef::new(RegFile::UGPR, 2, 2);
let order = MemOrder::Strong(MemScope::CTA);
@ -318,6 +319,18 @@ pub fn test_ld_st_atom() {
{
for addr_stride in [OffsetStride::X1, OffsetStride::X8] {
let cta = if sm >= 80 { "sm" } else { "cta" };
let r1_str =
if sm >= 75 && matches!(space, MemSpace::Global(_)) {
"r1.64"
} else {
"r1"
};
let urz = if sm >= 73 {
SrcRef::Reg(ur2).into()
} else {
Src::ZERO
};
let urz_str = if sm >= 73 { "+ur2" } else { "" };
let pri = match space {
MemSpace::Global(_) => MemEvictionPriority::First,
@ -339,6 +352,7 @@ pub fn test_ld_st_atom() {
let instr = OpLd {
dst: Dst::Reg(r0),
addr: SrcRef::Reg(r1).into(),
uniform_addr: urz.clone(),
pred: if matches!(space, MemSpace::Global(_))
&& sm >= 73
{
@ -353,7 +367,7 @@ pub fn test_ld_st_atom() {
let expected = match space {
MemSpace::Global(_) if sm >= 73 => {
format!(
"ldg.e.ef.strong.{cta} r0, [r1+{addr_offset_str}], p4;"
"ldg.e.ef.strong.{cta} r0, [{r1_str}{urz_str}+{addr_offset_str}], p4;"
)
}
MemSpace::Global(_) => {
@ -363,17 +377,20 @@ pub fn test_ld_st_atom() {
}
MemSpace::Shared => {
format!(
"lds r0, [r1{addr_stride}+{addr_offset_str}];"
"lds r0, [{r1_str}{addr_stride}{urz_str}+{addr_offset_str}];"
)
}
MemSpace::Local => {
format!("ldl r0, [r1+{addr_offset_str}];")
format!(
"ldl r0, [{r1_str}{urz_str}+{addr_offset_str}];"
)
}
};
c.push(instr, expected);
let instr = OpSt {
addr: SrcRef::Reg(r1).into(),
uniform_addr: urz.clone(),
data: SrcRef::Reg(r2).into(),
offset: addr_offset,
access: access.clone(),
@ -382,16 +399,18 @@ pub fn test_ld_st_atom() {
let expected = match space {
MemSpace::Global(_) => {
format!(
"stg.e.ef.strong.{cta} [r1+{addr_offset_str}], r2;"
"stg.e.ef.strong.{cta} [{r1_str}{urz_str}+{addr_offset_str}], r2;"
)
}
MemSpace::Shared => {
format!(
"sts [r1{addr_stride}+{addr_offset_str}], r2;"
"sts [{r1_str}{addr_stride}{urz_str}+{addr_offset_str}], r2;"
)
}
MemSpace::Local => {
format!("stl [r1+{addr_offset_str}], r2;")
format!(
"stl [{r1_str}{urz_str}+{addr_offset_str}], r2;"
)
}
};
c.push(instr, expected);
@ -405,6 +424,7 @@ pub fn test_ld_st_atom() {
Dst::None
},
addr: SrcRef::Reg(r1).into(),
uniform_address: urz.clone(),
data: SrcRef::Reg(r2).into(),
atom_op: AtomOp::Add,
cmpr: SrcRef::Reg(r3).into(),
@ -429,7 +449,7 @@ pub fn test_ld_st_atom() {
};
let dst =
if use_dst { "pt, r0, " } else { "" };
format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[r1+{addr_offset_str}], r2;")
format!("{op}.e.add.ef{atom_type_str}.strong.{cta} {dst}[{r1_str}{urz_str}+{addr_offset_str}], r2;")
}
MemSpace::Shared => {
if atom_type.is_float() {
@ -439,7 +459,7 @@ pub fn test_ld_st_atom() {
continue;
}
let dst = if use_dst { "r0" } else { "rz" };
format!("atoms.add{atom_type_str} {dst}, [r1{addr_stride}+{addr_offset_str}], r2;")
format!("atoms.add{atom_type_str} {dst}, [{r1_str}{addr_stride}{urz_str}+{addr_offset_str}], r2;")
}
MemSpace::Local => continue,
};


@ -10,6 +10,7 @@ use crate::sm70::ShaderModel70;
use bitview::*;
use rustc_hash::FxHashMap;
use std::mem;
use std::ops::Range;
/// A per-op trait that implements Volta+ opcode semantics
@ -108,6 +109,29 @@ impl SM70Encoder<'_> {
}
}
fn set_reg_addr(
&mut self,
range: Range<usize>,
src: &Src,
size_bit: usize,
) {
assert!(src.is_unmodified());
match src.src_ref {
SrcRef::Zero => {
self.set_reg(range, self.zero_reg(RegFile::GPR));
// We always treat a zero GPR as 32 bits, so the UGPR source
// can be 32 bits.
self.set_bit(size_bit, false);
}
SrcRef::Reg(reg) => {
self.set_reg(range, reg);
assert!(reg.comps() <= 2);
self.set_bit(size_bit, reg.comps() == 2);
}
_ => panic!("Not a register"),
}
}
fn set_ureg_src(&mut self, start: usize, src: &Src) {
assert!(src.src_mod.is_none());
match src.src_ref {
@ -117,6 +141,24 @@ impl SM70Encoder<'_> {
}
}
fn set_ureg_addr(&mut self, start: usize, src: &Src, size_bit: usize) {
assert!(src.src_mod.is_none());
match src.src_ref {
SrcRef::Zero => {
self.set_ureg(start, self.zero_reg(RegFile::UGPR));
// We always treat a zero UGPR as 64 bits, so the GPR source
// can be 64 bits.
self.set_bit(size_bit, true);
}
SrcRef::Reg(reg) => {
self.set_ureg(start, reg);
assert!(reg.comps() <= 2);
self.set_bit(size_bit, reg.comps() == 2);
}
_ => panic!("Not a register"),
}
}
fn set_pred_dst(&mut self, range: Range<usize>, dst: &Dst) {
match dst {
Dst::None => self.set_pred_reg(range, self.true_reg(RegFile::Pred)),
@ -733,6 +775,60 @@ fn op_gpr(op: &impl DstsAsSlice) -> RegFile {
}
}
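/// Legalize the (GPR, UGPR) address pair of a load/store/atomic op:
///
/// * A UGPR in `addr` is moved into an empty `uniform_addr` slot when the
///   stride is X1 (or there is no stride); otherwise it is copied to a GPR.
/// * A GPR in `uniform_addr` is moved back into an empty `addr` slot, or
///   added into `addr` (folding any stride into the address via a shift and
///   widening it to 64 bits if needed), leaving `uniform_addr` zero.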
fn legalize_load_store_address(
b: &mut LegalizeBuilder,
addr: &mut Src,
uniform_addr: &mut Src,
stride: Option<&mut OffsetStride>,
) {
let stride_x1_or_none = matches!(stride, Some(OffsetStride::X1) | None);
if addr.is_ugpr_reg() {
if stride_x1_or_none && uniform_addr.is_zero() {
*uniform_addr = mem::replace(addr, Src::ZERO);
} else {
b.copy_src_if_uniform(addr);
}
}
if uniform_addr.is_gpr_reg() {
if addr.is_zero() {
assert!(stride_x1_or_none);
*addr = mem::replace(uniform_addr, Src::ZERO);
} else {
let uniform_ssa = uniform_addr.as_ssa().unwrap();
let mut ssa = addr.as_ssa().unwrap();
let addr_comps = ssa.comps();
if let Some(stride) = stride {
if *stride != OffsetStride::X1 {
assert_eq!(addr_comps, 1);
let shift = stride.shift();
let shift = b.copy(shift.into());
*addr = b.shl(addr.clone(), shift.into()).into();
ssa = addr.as_ssa().unwrap();
*stride = OffsetStride::X1;
}
}
if uniform_ssa.comps() == 2 {
// If the non-uniform address is 32 bits and the uniform one is 64 bits,
// we need to convert it to 64 bits.
if uniform_ssa.comps() != addr_comps {
let zero = b.copy(0.into());
*addr = [ssa[0], zero].into();
}
*addr = b
.iadd64(addr.clone(), uniform_addr.clone(), Src::ZERO)
.into()
} else {
*addr =
b.iadd(addr.clone(), uniform_addr.clone(), Src::ZERO).into()
}
*uniform_addr = 0.into();
}
}
}
//
// Implementations of SM70Op for each op we support on Volta+
//
@ -3009,13 +3105,6 @@ impl SM70Encoder<'_> {
}
fn set_mem_access(&mut self, access: &MemAccess) {
self.set_field(
72..73,
match access.space.addr_type() {
MemAddrType::A32 => 0_u8,
MemAddrType::A64 => 1_u8,
},
);
self.set_mem_type(73..76, access.mem_type);
self.set_mem_order(&access.order);
self.set_eviction_priority(&access.eviction_priority);
@ -3131,20 +3220,31 @@ impl SM70Op for OpSuAtom {
impl SM70Op for OpLd {
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_addr,
Some(&mut self.stride),
);
b.copy_src_if_uniform(&mut self.pred);
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
let has_ugpr = e.sm >= 73;
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0x381);
assert_eq!(self.stride, OffsetStride::X1);
if e.sm >= 73 {
if has_ugpr {
e.set_opcode(0x981);
e.set_reg_addr(24..32, &self.addr, 90);
e.set_ureg_addr(32, &self.uniform_addr, 72);
e.set_rev_pred_src(64..67, 67, &self.pred);
} else {
assert!(self.pred.is_true());
e.set_opcode(0x381);
e.set_reg_addr(24..32, &self.addr, 72);
}
e.set_pred_dst(81..84, &Dst::None);
e.set_mem_access(&self.access);
}
@ -3152,6 +3252,10 @@ impl SM70Op for OpLd {
assert!(self.pred.is_true());
assert_eq!(self.stride, OffsetStride::X1);
e.set_opcode(0x983);
e.set_reg_src(24..32, &self.addr);
if has_ugpr {
e.set_ureg_src(32, &self.uniform_addr);
}
e.set_field(84..87, 1_u8);
e.set_mem_type(73..76, self.access.mem_type);
@ -3165,6 +3269,10 @@ impl SM70Op for OpLd {
e.set_opcode(0x984);
assert!(self.pred.is_true());
e.set_reg_src(24..32, &self.addr);
if has_ugpr {
e.set_ureg_src(32, &self.uniform_addr);
}
e.set_mem_type(73..76, self.access.mem_type);
assert!(self.access.order == MemOrder::Strong(MemScope::CTA));
assert!(
@ -3179,8 +3287,11 @@ impl SM70Op for OpLd {
}
e.set_dst(&self.dst);
e.set_reg_src(24..32, &self.addr);
e.set_field(40..64, self.offset);
// We always enable UGPR mode, because the .E bit changes which source
// it applies to depending on this mode bit. This way it always applies
// to the UGPR.
e.set_bit(91, has_ugpr);
}
}
@ -3276,20 +3387,40 @@ impl SM70Op for OpLdc {
impl SM70Op for OpSt {
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
b.copy_src_if_uniform(&mut self.data);
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_addr,
Some(&mut self.stride),
);
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
let has_ugpr = e.sm >= 75;
match self.access.space {
MemSpace::Global(_) => {
e.set_opcode(0x386);
assert_eq!(self.stride, OffsetStride::X1);
if has_ugpr {
e.set_opcode(0x986);
e.set_reg_addr(24..32, &self.addr, 90);
e.set_ureg_addr(64, &self.uniform_addr, 72);
} else {
e.set_opcode(0x386);
e.set_reg_addr(24..32, &self.addr, 72);
}
e.set_mem_access(&self.access);
}
MemSpace::Local => {
e.set_opcode(0x387);
assert_eq!(self.stride, OffsetStride::X1);
if has_ugpr {
e.set_opcode(0x987);
e.set_reg_src(24..32, &self.addr);
e.set_ureg_src(64, &self.uniform_addr);
} else {
e.set_opcode(0x387);
e.set_reg_src(24..32, &self.addr);
}
e.set_field(84..87, 1_u8);
e.set_mem_type(73..76, self.access.mem_type);
@ -3300,7 +3431,14 @@ impl SM70Op for OpSt {
);
}
MemSpace::Shared => {
e.set_opcode(0x388);
if has_ugpr {
e.set_opcode(0x988);
e.set_reg_src(24..32, &self.addr);
e.set_ureg_src(64, &self.uniform_addr);
} else {
e.set_opcode(0x388);
e.set_reg_src(24..32, &self.addr);
}
e.set_mem_type(73..76, self.access.mem_type);
assert!(self.access.order == MemOrder::Strong(MemScope::CTA));
@ -3314,9 +3452,12 @@ impl SM70Op for OpSt {
}
}
e.set_reg_src(24..32, &self.addr);
e.set_reg_src(32..40, &self.data);
e.set_field(40..64, self.offset);
// We always enable UGPR mode, because the .E bit changes which source
// it applies to depending on this mode bit. This way it always applies
// to the UGPR.
e.set_bit(91, has_ugpr);
}
}
@ -3385,12 +3526,23 @@ impl SM70Encoder<'_> {
impl SM70Op for OpAtom {
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
b.copy_src_if_uniform(&mut self.cmpr);
b.copy_src_if_uniform(&mut self.data);
if matches!(self.atom_op, AtomOp::CmpExch(_)) {
b.copy_src_if_uniform(&mut self.addr);
b.copy_src_if_uniform(&mut self.cmpr);
} else {
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_address,
Some(&mut self.addr_stride),
);
}
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
let has_ugpr = e.sm >= 75;
match self.mem_space {
MemSpace::Global(_) => {
if self.dst.is_none() {
@ -3401,34 +3553,58 @@ impl SM70Op for OpAtom {
}
e.set_reg_src(32..40, &self.data);
e.set_field(40..64, self.addr_offset);
e.set_atom_op(87..90, self.atom_op);
if has_ugpr {
e.set_reg_addr(24..32, &self.addr, 90);
e.set_ureg_addr(64, &self.uniform_address, 72);
e.set_bit(91, true);
} else {
e.set_reg_addr(24..32, &self.addr, 72);
assert!(self.uniform_address.is_zero());
}
} else if let AtomOp::CmpExch(cmp_src) = self.atom_op {
e.set_opcode(0x3a9);
assert!(cmp_src == AtomCmpSrc::Separate);
assert!(self.uniform_address.is_zero());
e.set_reg_addr(24..32, &self.addr, 72);
e.set_reg_src(32..40, &self.cmpr);
e.set_field(40..64, self.addr_offset);
e.set_reg_src(64..72, &self.data);
e.set_pred_dst(81..84, &Dst::None);
} else {
if e.sm >= 90 && self.atom_type.is_float() {
e.set_opcode(0x3a3);
e.set_opcode(0x9a3);
} else if has_ugpr {
e.set_opcode(0x9a8);
} else {
e.set_opcode(0x3a8);
}
if e.sm >= 100 {
e.set_reg_addr(24..32, &self.addr, 63);
e.set_ureg_addr(64, &self.uniform_address, 72);
} else if has_ugpr {
e.set_reg_addr(24..32, &self.addr, 70);
e.set_ureg_addr(64, &self.uniform_address, 72);
} else {
e.set_reg_addr(24..32, &self.addr, 72);
assert!(self.uniform_address.is_zero());
};
if e.sm >= 100 {
e.set_field(40..63, self.addr_offset);
} else {
e.set_field(40..64, self.addr_offset);
};
e.set_reg_src(32..40, &self.data);
e.set_pred_dst(81..84, &Dst::None);
e.set_atom_op(87..91, self.atom_op);
e.set_bit(91, has_ugpr);
}
e.set_field(
72..73,
match self.mem_space.addr_type() {
MemAddrType::A32 => 0_u8,
MemAddrType::A64 => 1_u8,
},
);
e.set_mem_order(&self.mem_order);
e.set_eviction_priority(&self.mem_eviction_priority);
assert_eq!(self.addr_stride, OffsetStride::X1);
@ -3439,10 +3615,17 @@ impl SM70Op for OpAtom {
e.set_opcode(0x38d);
assert!(cmp_src == AtomCmpSrc::Separate);
assert!(self.uniform_address.is_zero());
e.set_reg_src(32..40, &self.cmpr);
e.set_reg_src(64..72, &self.data);
} else {
e.set_opcode(0x38c);
if has_ugpr {
e.set_opcode(0x98c);
e.set_ureg_src(64, &self.uniform_address);
e.set_bit(91, true);
} else {
e.set_opcode(0x38c);
}
e.set_reg_src(32..40, &self.data);
assert!(
@ -3457,6 +3640,8 @@ impl SM70Op for OpAtom {
e.set_atom_op(87..91, self.atom_op);
}
e.set_reg_src(24..32, &self.addr);
e.set_field(40..64, self.addr_offset);
assert!(e.sm >= 75 || self.addr_stride == OffsetStride::X1);
e.set_field(78..80, self.addr_stride.encode_sm75());
@ -3468,8 +3653,6 @@ impl SM70Op for OpAtom {
}
e.set_dst(&self.dst);
e.set_reg_src(24..32, &self.addr);
e.set_field(40..64, self.addr_offset);
e.set_atom_type(self.atom_type, false);
}
}
@ -4183,7 +4366,12 @@ impl SM70Op for OpHmma {
impl SM70Op for OpLdsm {
fn legalize(&mut self, b: &mut LegalizeBuilder) {
b.copy_src_if_uniform(&mut self.addr);
legalize_load_store_address(
b,
&mut self.addr,
&mut self.uniform_addr,
None,
);
}
fn encode(&self, e: &mut SM70Encoder<'_>) {
@ -4192,6 +4380,7 @@ impl SM70Op for OpLdsm {
e.set_opcode(0x83b);
e.set_dst(&self.dst);
e.set_reg_src(24..32, &self.addr);
e.set_ureg_src(32, &self.uniform_addr);
e.set_field(40..64, self.offset);
e.set_field(
72..74,
@ -4212,6 +4401,7 @@ impl SM70Op for OpLdsm {
// LdsmSize::M8N32 => 3,
},
);
e.set_bit(91, !self.uniform_addr.is_zero());
}
}


@ -1019,8 +1019,23 @@ nak_nir_lower_load_store(nir_shader *nir, const struct nak_compiler *nak)
continue;
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
nir_src *addr;
switch (intr->intrinsic) {
case nir_intrinsic_load_global_bounded:
case nir_intrinsic_load_global_constant_bounded: {
addr = &intr->src[0];
break;
}
default:
addr = nir_get_io_offset_src(intr);
break;
}
if (!addr)
continue;
b.cursor = nir_before_instr(instr);
nir_src *addr = nir_get_io_offset_src(intr);
nir_def *uaddr = nir_imm_zero(&b, 1, addr->ssa->bit_size);
nir_def *res = NULL;
nir_intrinsic_instr *new = NULL;
@ -1028,7 +1043,7 @@ nak_nir_lower_load_store(nir_shader *nir, const struct nak_compiler *nak)
case nir_intrinsic_load_global:
case nir_intrinsic_load_global_constant: {
nir_def *nir_true = nir_imm_bool(&b, true);
res = nir_load_global_nv(&b, intr->def.num_components, intr->def.bit_size, addr->ssa, nir_true);
res = nir_load_global_nv(&b, intr->def.num_components, intr->def.bit_size, addr->ssa, uaddr, nir_true);
break;
}
case nir_intrinsic_load_global_bounded:
@ -1044,32 +1059,32 @@ nak_nir_lower_load_store(nir_shader *nir, const struct nak_compiler *nak)
nir_def *addr = nir_iadd(&b, base->ssa, nir_u2u64(&b, offset->ssa));
nir_def *last_byte = nir_iadd_imm(&b, offset->ssa, load_size - 1);
nir_def *cond = nir_ult(&b, last_byte, size->ssa);
res = nir_load_global_nv(&b, intr->def.num_components, intr->def.bit_size, addr, cond);
res = nir_load_global_nv(&b, intr->def.num_components, intr->def.bit_size, addr, uaddr, cond);
break;
}
case nir_intrinsic_load_scratch:
res = nir_load_scratch_nv(&b, intr->def.num_components, intr->def.bit_size, addr->ssa);
res = nir_load_scratch_nv(&b, intr->def.num_components, intr->def.bit_size, addr->ssa, uaddr);
break;
case nir_intrinsic_load_shared:
res = nir_load_shared_nv(&b, intr->def.num_components, intr->def.bit_size, addr->ssa);
res = nir_load_shared_nv(&b, intr->def.num_components, intr->def.bit_size, addr->ssa, uaddr);
break;
case nir_intrinsic_store_global:
new = nir_store_global_nv(&b, intr->src[0].ssa, addr->ssa);
new = nir_store_global_nv(&b, intr->src[0].ssa, addr->ssa, uaddr);
break;
case nir_intrinsic_store_scratch:
new = nir_store_scratch_nv(&b, intr->src[0].ssa, addr->ssa);
new = nir_store_scratch_nv(&b, intr->src[0].ssa, addr->ssa, uaddr);
break;
case nir_intrinsic_store_shared:
new = nir_store_shared_nv(&b, intr->src[0].ssa, addr->ssa);
new = nir_store_shared_nv(&b, intr->src[0].ssa, addr->ssa, uaddr);
break;
case nir_intrinsic_global_atomic:
res = nir_global_atomic_nv(&b, intr->def.bit_size, addr->ssa, intr->src[1].ssa);
res = nir_global_atomic_nv(&b, intr->def.bit_size, addr->ssa, uaddr, intr->src[1].ssa);
break;
case nir_intrinsic_global_atomic_swap:
res = nir_global_atomic_swap_nv(&b, intr->def.bit_size, addr->ssa, intr->src[1].ssa, intr->src[2].ssa);
break;
case nir_intrinsic_shared_atomic:
res = nir_shared_atomic_nv(&b, intr->def.bit_size, addr->ssa, intr->src[1].ssa);
res = nir_shared_atomic_nv(&b, intr->def.bit_size, addr->ssa, uaddr, intr->src[1].ssa);
break;
case nir_intrinsic_shared_atomic_swap:
res = nir_shared_atomic_swap_nv(&b, intr->def.bit_size, addr->ssa, intr->src[1].ssa, intr->src[2].ssa);
@ -1115,6 +1130,113 @@ nak_nir_lower_load_store(nir_shader *nir, const struct nak_compiler *nak)
return progress;
}
static bool
is_divergent_phi(nir_instr *instr)
{
if (instr->type != nir_instr_type_phi)
return false;
nir_phi_instr *phi = nir_instr_as_phi(instr);
return nak_nir_phi_is_divergent(phi);
}
static bool
nak_nir_opt_uniform_address_impl(struct nir_builder *b,
nir_intrinsic_instr *intr, void *cb_data)
{
switch (intr->intrinsic) {
case nir_intrinsic_cmat_load_shared_nv:
case nir_intrinsic_global_atomic_nv:
case nir_intrinsic_load_global_nv:
case nir_intrinsic_load_scratch_nv:
case nir_intrinsic_load_shared_nv:
case nir_intrinsic_shared_atomic_nv:
case nir_intrinsic_store_global_nv:
case nir_intrinsic_store_scratch_nv:
case nir_intrinsic_store_shared_nv: {
nir_src *offset_src = nir_get_io_offset_src(intr);
nir_def *offset = offset_src->ssa;
nir_src *uniform_offset_src = nir_get_io_uniform_offset_src(intr);
nir_def *uniform_offset = uniform_offset_src->ssa;
nir_block *use_block = intr->instr.block;
assert(nir_src_as_uint(*uniform_offset_src) == 0);
/* NAK can't collect vectors in non-uniform control flow, so don't
* even try. */
if (offset->bit_size == 64 && nak_block_is_divergent(use_block))
return false;
/* We ignore any constant offset */
if (nir_src_is_const(*offset_src))
return false;
/* If the source is already uniform, just swap them as the uniform slot
* should be 0 */
if (!nir_def_is_divergent_at_use_block(offset, use_block)) {
if (is_divergent_phi(nir_def_instr(offset)))
return false;
nir_src_rewrite(uniform_offset_src, offset);
nir_src_rewrite(offset_src, uniform_offset);
return true;
}
nir_alu_instr *iadd = nir_def_as_alu_or_null(offset_src->ssa);
if (!iadd || iadd->op != nir_op_iadd)
return false;
unsigned src0_div = nir_def_is_divergent_at_use_block(iadd->src[0].src.ssa, use_block);
unsigned src1_div = nir_def_is_divergent_at_use_block(iadd->src[1].src.ssa, use_block);
if (src0_div && src1_div)
return false;
b->cursor = nir_before_instr(&intr->instr);
nir_def *addr, *uaddr;
if (src0_div) {
assert(!src1_div);
addr = nir_ssa_for_alu_src(b, iadd, 0);
uaddr = nir_ssa_for_alu_src(b, iadd, 1);
} else {
assert(src1_div);
addr = nir_ssa_for_alu_src(b, iadd, 1);
uaddr = nir_ssa_for_alu_src(b, iadd, 0);
}
if (is_divergent_phi(nir_def_instr(uaddr)))
return false;
/* We can remove a u2u64 on the non-uniform src */
if (addr->bit_size == 64) {
nir_alu_instr *u2u64 = nir_def_as_alu_or_null(addr);
if (u2u64 && u2u64->op == nir_op_u2u64)
addr = nir_ssa_for_alu_src(b, u2u64, 0);
}
nir_src_rewrite(offset_src, addr);
nir_src_rewrite(uniform_offset_src, uaddr);
return true;
}
default:
return false;
}
}
/** This pass assumes it is run after nir_opt_offsets */
static bool
nak_nir_opt_uniform_address(nir_shader *nir)
{
if (nak_debug_no_ugpr())
return false;
nir_divergence_analysis(nir);
return nir_shader_intrinsics_pass(
nir,
nak_nir_opt_uniform_address_impl,
nir_metadata_control_flow,
NULL
);
}
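
Illustratively (the SSA names here are made up), the effect of nak_nir_opt_uniform_address on a typical global access is to pull the uniform half of an address computation into the new uniform slot, dropping a u2u64 on the divergent half when there is one:

    before:  %addr = iadd %base (uniform, 64-bit), u2u64 %off (divergent, 32-bit)
             %val  = @load_global_nv(%addr, 0x0 /* uniform_address */, %true /* pred */)
    after:   %val  = @load_global_nv(%off, %base, %true)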
static bool
nak_nir_opt_offset_shift_nv_impl(struct nir_builder *b,
nir_intrinsic_instr *intrin, void *data)
@ -1318,6 +1440,12 @@ nak_postprocess_nir(nir_shader *nir,
.cb_data = nak,
};
OPT(nir, nir_opt_offsets, &nak_offset_options);
if (nak->sm >= 73) {
OPT(nir, nak_nir_opt_uniform_address);
/* TODO: as we eliminate u2u64s we could fold more offsets in; however,
* this would require us to verify it doesn't overflow, which we can't. */
/* OPT(nir, nir_opt_offsets, &nak_offset_options); */
}
/* Should run after nir_opt_offsets, because nir_opt_algebraic will move
* iadds down the chain */


@ -376,6 +376,39 @@ lower_cf_list(nir_builder *b, nir_def *esc_reg, struct scope *parent_scope,
}
}
bool
nak_nir_phi_is_divergent(nir_phi_instr *phi)
{
bool divergent = false;
nir_foreach_phi_src(phi_src, phi) {
/* There is a tricky case we need to care about here where a
* convergent block has a divergent dominator. This can happen
* if, for instance, you have the following loop:
*
* loop {
* if (div) {
* %20 = load_ubo(0, 0);
* } else {
* terminate;
* }
* }
* use(%20);
*
* In this case, the load_ubo() dominates the use() even though
* the load_ubo() exists in divergent control-flow. In this
* case, we simply flag the whole phi divergent because we
* don't want to deal with inserting a r2ur somewhere.
*/
if (phi_src->pred->divergent || phi_src->src.ssa->divergent ||
nir_def_block(phi_src->src.ssa)->divergent) {
divergent = true;
break;
}
}
return divergent;
}
static void
recompute_phi_divergence_impl(nir_function_impl *impl)
{
@ -388,33 +421,7 @@ recompute_phi_divergence_impl(nir_function_impl *impl)
break;
nir_phi_instr *phi = nir_instr_as_phi(instr);
bool divergent = false;
nir_foreach_phi_src(phi_src, phi) {
/* There is a tricky case we need to care about here where a
* convergent block has a divergent dominator. This can happen
* if, for instance, you have the following loop:
*
* loop {
* if (div) {
* %20 = load_ubo(0, 0);
* } else {
* terminate;
* }
* }
* use(%20);
*
* In this case, the load_ubo() dominates the use() even though
* the load_ubo() exists in divergent control-flow. In this
* case, we simply flag the whole phi divergent because we
* don't want to deal with inserting a r2ur somewhere.
*/
if (phi_src->pred->divergent || phi_src->src.ssa->divergent ||
nir_def_block(phi_src->src.ssa)->divergent) {
divergent = true;
break;
}
}
bool divergent = nak_nir_phi_is_divergent(phi);
if (divergent != phi->def.divergent) {
phi->def.divergent = divergent;


@ -723,6 +723,7 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
nir_def *base = intr->src[1].ssa;
offset = nir_u2uN(b, offset, base->bit_size);
nir_def *addr = nir_iadd(b, base, offset);
nir_def *zero = nir_imm_zero(b, addr->num_components, addr->bit_size);
/* flip the layout for B matrices */
if (desc.use == GLSL_CMAT_USE_B) {
@ -734,7 +735,7 @@ try_lower_cmat_load_to_ldsm(nir_builder *b, nir_intrinsic_instr *intr)
/* Each thread loads 32 bits per matrix */
assert(length * bit_size == 32 * ldsm_count);
return nir_cmat_load_shared_nv(b, length, bit_size, addr,
return nir_cmat_load_shared_nv(b, length, bit_size, addr, zero,
.num_matrices = ldsm_count,
.matrix_layout = layout);
}


@ -56,10 +56,12 @@ lower_ldcx_to_global(nir_builder *b, nir_intrinsic_instr *load,
* simple less-than check here.
*/
nir_def *cond = nir_ilt(b, offset, size);
nir_def *zero_addr = nir_imm_zero(b, addr->num_components,
addr->bit_size);
nir_def *val = nir_load_global_nv(b,
load->def.num_components, load->def.bit_size,
nir_iadd(b, addr, nir_u2u64(b, offset)),
cond,
zero_addr, cond,
.align_mul = nir_intrinsic_align_mul(load),
.align_offset = nir_intrinsic_align_offset(load),
.access = ACCESS_CAN_REORDER,


@ -370,6 +370,7 @@ bool nak_nir_lower_cmat(nir_shader *shader, const struct nak_compiler *nak);
* writing uregs from these blocks.
*/
bool nak_block_is_divergent(const nir_block *block);
bool nak_nir_phi_is_divergent(nir_phi_instr *phi);
void nak_optimize_nir(nir_shader *nir, const struct nak_compiler *nak);