nak/calc_instr_deps: Rewrite calc_delays() again

This time we take into account WaR and WaW dependencies and not just RaW
dependencies.  The NVIDIA ISA is actually quite dynamic and not
everything is nicely pipelined such that writes always happen at
consistent cycles.  There are exact rules, of course, but we don't know
what those are so we need to make some worst-case assumptions.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29591>
This commit is contained in:
Faith Ekstrand 2024-06-03 18:30:04 -05:00 committed by Marge Bot
parent 434af5b98b
commit 2d4e445099
2 changed files with 174 additions and 81 deletions

View file

@ -17,18 +17,6 @@ struct RegTracker<T> {
carry: [T; 1], carry: [T; 1],
} }
impl<T: Copy> RegTracker<T> {
pub fn new(v: T) -> Self {
Self {
reg: [v; 255],
ureg: [v; 63],
pred: [v; 7],
upred: [v; 7],
carry: [v; 1],
}
}
}
fn new_array_with<T, const N: usize>(f: &impl Fn() -> T) -> [T; N] { fn new_array_with<T, const N: usize>(f: &impl Fn() -> T) -> [T; N] {
let mut v = Vec::new(); let mut v = Vec::new();
for _ in 0..N { for _ in 0..N {
@ -64,12 +52,12 @@ impl<T> RegTracker<T> {
pub fn for_each_instr_src_mut( pub fn for_each_instr_src_mut(
&mut self, &mut self,
instr: &Instr, instr: &Instr,
mut f: impl FnMut(&mut T), mut f: impl FnMut(usize, &mut T),
) { ) {
for src in instr.srcs() { for (i, src) in instr.srcs().iter().enumerate() {
if let SrcRef::Reg(reg) = &src.src_ref { if let SrcRef::Reg(reg) = &src.src_ref {
for i in &mut self[*reg] { for t in &mut self[*reg] {
f(i); f(i, t);
} }
} }
} }
@ -78,12 +66,12 @@ impl<T> RegTracker<T> {
pub fn for_each_instr_dst_mut( pub fn for_each_instr_dst_mut(
&mut self, &mut self,
instr: &Instr, instr: &Instr,
mut f: impl FnMut(&mut T), mut f: impl FnMut(usize, &mut T),
) { ) {
for dst in instr.dsts() { for (i, dst) in instr.dsts().iter().enumerate() {
if let Dst::Reg(reg) = dst { if let Dst::Reg(reg) = dst {
for i in &mut self[*reg] { for t in &mut self[*reg] {
f(i); f(i, t);
} }
} }
} }
@ -133,14 +121,14 @@ impl<T> IndexMut<RegRef> for RegTracker<T> {
} }
#[derive(Clone)] #[derive(Clone)]
enum RegUse { enum RegUse<T: Clone> {
None, None,
Write(usize), Write(T),
Reads(Vec<usize>), Reads(Vec<T>),
} }
impl RegUse { impl<T: Clone> RegUse<T> {
pub fn deps(&self) -> &[usize] { pub fn deps(&self) -> &[T] {
match self { match self {
RegUse::None => &[], RegUse::None => &[],
RegUse::Write(dep) => slice::from_ref(dep), RegUse::Write(dep) => slice::from_ref(dep),
@ -148,11 +136,11 @@ impl RegUse {
} }
} }
pub fn clear(&mut self) -> RegUse { pub fn clear(&mut self) -> Self {
std::mem::replace(self, RegUse::None) std::mem::replace(self, RegUse::None)
} }
pub fn clear_write(&mut self) -> RegUse { pub fn clear_write(&mut self) -> Self {
if matches!(self, RegUse::Write(_)) { if matches!(self, RegUse::Write(_)) {
std::mem::replace(self, RegUse::None) std::mem::replace(self, RegUse::None)
} else { } else {
@ -160,7 +148,7 @@ impl RegUse {
} }
} }
pub fn add_read(&mut self, dep: usize) -> RegUse { pub fn add_read(&mut self, dep: T) -> Self {
match self { match self {
RegUse::None => { RegUse::None => {
*self = RegUse::Reads(vec![dep]); *self = RegUse::Reads(vec![dep]);
@ -176,7 +164,7 @@ impl RegUse {
} }
} }
pub fn set_write(&mut self, dep: usize) -> RegUse { pub fn set_write(&mut self, dep: T) -> Self {
std::mem::replace(self, RegUse::Write(dep)) std::mem::replace(self, RegUse::Write(dep))
} }
} }
@ -383,17 +371,17 @@ fn assign_barriers(f: &mut Function, sm: u8) {
if instr.has_fixed_latency(sm) { if instr.has_fixed_latency(sm) {
// Delays will cover us here. We just need to make sure // Delays will cover us here. We just need to make sure
// that we wait on any uses that we consume. // that we wait on any uses that we consume.
uses.for_each_instr_src_mut(instr, |u| { uses.for_each_instr_src_mut(instr, |_, u| {
let u = u.clear_write(); let u = u.clear_write();
waits.extend_from_slice(u.deps()); waits.extend_from_slice(u.deps());
}); });
uses.for_each_instr_dst_mut(instr, |u| { uses.for_each_instr_dst_mut(instr, |_, u| {
let u = u.clear(); let u = u.clear();
waits.extend_from_slice(u.deps()); waits.extend_from_slice(u.deps());
}); });
} else { } else {
let (rd, wr) = deps.add_instr(bi, ip); let (rd, wr) = deps.add_instr(bi, ip);
uses.for_each_instr_src_mut(instr, |u| { uses.for_each_instr_src_mut(instr, |_, u| {
// Only mark a dep as signaled if we actually have // Only mark a dep as signaled if we actually have
// something that shows up in the register file as // something that shows up in the register file as
// needing scoreboarding // needing scoreboarding
@ -401,7 +389,7 @@ fn assign_barriers(f: &mut Function, sm: u8) {
let u = u.add_read(rd); let u = u.add_read(rd);
waits.extend_from_slice(u.deps()); waits.extend_from_slice(u.deps());
}); });
uses.for_each_instr_dst_mut(instr, |u| { uses.for_each_instr_dst_mut(instr, |_, u| {
// Only mark a dep as signaled if we actually have // Only mark a dep as signaled if we actually have
// something that shows up in the register file as // something that shows up in the register file as
// needing scoreboarding // needing scoreboarding
@ -464,30 +452,160 @@ fn assign_barriers(f: &mut Function, sm: u8) {
} }
} }
/// Minimum latency, in cycles, before another instruction can execute
/// after `op` issues.
fn exec_latency(sm: u8, op: &Op) -> u32 {
    match op {
        Op::Bar(_) | Op::MemBar(_) => {
            // Barriers need an extra cycle on SM80 and later
            if sm < 80 {
                5
            } else {
                6
            }
        }
        // CCTL.C needs 8, CCTL.I needs 11
        Op::CCtl(_op) => 11,
        // Op::DepBar(_) => 4,
        _ => 1, // TODO: co-issue
    }
}
/// Worst-case latency, in cycles, before the destination at `dst_idx` of
/// `op` is available in the register file.
fn instr_latency(op: &Op, dst_idx: usize) -> u32 {
    let file = match &op.dsts_as_slice()[dst_idx] {
        // A null destination never needs waiting on
        Dst::None => return 0,
        Dst::SSA(vec) => vec.file().unwrap(),
        Dst::Reg(reg) => reg.file(),
    };

    // This is BS and we know it: the real per-instruction numbers are
    // unknown so these are worst-case guesses.
    match file.is_predicate() {
        true => 13,
        false => 6,
    }
}
/// Read-after-write latency
///
/// Number of cycles a consumer must wait after `write` issues before it
/// can safely read the destination at `dst_idx`.
fn raw_latency(
    _sm: u8,
    write: &Op,
    dst_idx: usize,
    _read: &Op,
    _src_idx: usize,
) -> u32 {
    // Nothing consumer-specific is modeled yet; a RaW dependency just
    // waits out the producer's worst-case destination latency.
    instr_latency(write, dst_idx)
}
/// Write-after-read latency
///
/// Number of cycles a writer must wait after `read` issues before it can
/// safely overwrite one of `read`'s sources.
fn war_latency(
    _sm: u8,
    _read: &Op,
    _src_idx: usize,
    _write: &Op,
    _dst_idx: usize,
) -> u32 {
    // We assume the reader consumes its sources within its first few
    // cycles while we have no idea how early the writer's result might
    // land.  This is all a guess.
    const SRC_READ_CYCLES: u32 = 4;
    SRC_READ_CYCLES
}
/// Write-after-write latency
///
/// Number of cycles `b` must wait after `a` issues before it can safely
/// write the same register.
fn waw_latency(
    _sm: u8,
    a: &Op,
    a_dst_idx: usize,
    _b: &Op,
    _b_dst_idx: usize,
) -> u32 {
    // We know our latencies are wrong so assume the write could happen
    // anywhere between 0 and instr_latency(a) cycles after a issues and
    // make the second write wait out the whole window.
    instr_latency(a, a_dst_idx)
}
/// Predicate read-after-write latency
fn paw_latency(_sm: u8, _write: &Op, _dst_idx: usize) -> u32 {
    // Predicate destinations always get the full worst-case latency.
    const PRED_WRITE_LATENCY: u32 = 13;
    PRED_WRITE_LATENCY
}
fn calc_delays(f: &mut Function, sm: u8) { fn calc_delays(f: &mut Function, sm: u8) {
for b in f.blocks.iter_mut().rev() { for b in f.blocks.iter_mut().rev() {
let mut cycle = 0_u32; let mut cycle = 0_u32;
let mut reads = RegTracker::new(0_u32);
// Vector mapping IP to start cycle
let mut instr_cycle = Vec::new();
instr_cycle.resize(b.instrs.len(), 0_u32);
// Maps registers to RegUse<ip, src_dst_idx>. Predicates are
// represented by src_idx = usize::MAX.
let mut uses: RegTracker<RegUse<(usize, usize)>> =
RegTracker::new_with(&|| RegUse::None);
// Map from barrier to last waited cycle
let mut bars = [0_u32; 6]; let mut bars = [0_u32; 6];
for instr in b.instrs.iter_mut().rev() {
// TODO: co-issue for ip in (0..b.instrs.len()).rev() {
let mut min_start = cycle + instr.get_exec_latency(sm); let instr = &b.instrs[ip];
let mut min_start = cycle + exec_latency(sm, &instr.op);
if let Some(bar) = instr.deps.rd_bar() { if let Some(bar) = instr.deps.rd_bar() {
min_start = max(min_start, bars[usize::from(bar)] + 2); min_start = max(min_start, bars[usize::from(bar)] + 2);
} }
if let Some(bar) = instr.deps.wr_bar() { if let Some(bar) = instr.deps.wr_bar() {
min_start = max(min_start, bars[usize::from(bar)] + 2); min_start = max(min_start, bars[usize::from(bar)] + 2);
} }
if instr.has_fixed_latency(sm) { uses.for_each_instr_dst_mut(instr, |i, u| match u {
for (idx, dst) in instr.dsts().iter().enumerate() { RegUse::None => {
if let Dst::Reg(reg) = dst { // We don't know how it will be used but it may be used in
let latency = instr.get_dst_latency(sm, idx); // the next block so we need at least assume the maximum
for c in &reads[*reg] { // destination latency from the end of the block.
min_start = max(min_start, *c + latency); let s = instr_latency(&instr.op, i);
} min_start = max(min_start, s);
}
RegUse::Write((w_ip, w_dst_idx)) => {
let s = instr_cycle[*w_ip]
+ waw_latency(
sm,
&instr.op,
i,
&b.instrs[*w_ip].op,
*w_dst_idx,
);
min_start = max(min_start, s);
}
RegUse::Reads(reads) => {
for (r_ip, r_src_idx) in reads {
let c = instr_cycle[*r_ip];
let s = if *r_src_idx == usize::MAX {
c + paw_latency(sm, &instr.op, i)
} else {
c + raw_latency(
sm,
&instr.op,
i,
&b.instrs[*r_ip].op,
*r_src_idx,
)
};
min_start = max(min_start, s);
} }
} }
} });
uses.for_each_instr_src_mut(instr, |i, u| match u {
RegUse::None => (),
RegUse::Write((w_ip, w_dst_idx)) => {
let s = instr_cycle[*w_ip]
+ war_latency(
sm,
&instr.op,
i,
&b.instrs[*w_ip].op,
*w_dst_idx,
);
min_start = max(min_start, s);
}
RegUse::Reads(_) => (),
});
let instr = &mut b.instrs[ip];
let delay = min_start - cycle; let delay = min_start - cycle;
let delay = delay let delay = delay
@ -496,8 +614,16 @@ fn calc_delays(f: &mut Function, sm: u8) {
.unwrap(); .unwrap();
instr.deps.set_delay(delay); instr.deps.set_delay(delay);
reads.for_each_instr_pred_mut(instr, |c| *c = min_start); instr_cycle[ip] = min_start;
reads.for_each_instr_src_mut(instr, |c| *c = min_start); uses.for_each_instr_pred_mut(instr, |c| {
c.add_read((ip, usize::MAX));
});
uses.for_each_instr_src_mut(instr, |i, c| {
c.add_read((ip, i));
});
uses.for_each_instr_dst_mut(instr, |i, c| {
c.set_write((ip, i));
});
for (bar, c) in bars.iter_mut().enumerate() { for (bar, c) in bars.iter_mut().enumerate() {
if instr.deps.wt_bar_mask & (1 << bar) != 0 { if instr.deps.wt_bar_mask & (1 << bar) != 0 {
*c = min_start; *c = min_start;
@ -516,7 +642,7 @@ fn calc_delays(f: &mut Function, sm: u8) {
if matches!(instr.op, Op::SrcBar(_)) { if matches!(instr.op, Op::SrcBar(_)) {
instr.op = Op::Nop(OpNop { label: None }); instr.op = Op::Nop(OpNop { label: None });
MappedInstrs::One(instr) MappedInstrs::One(instr)
} else if instr.get_exec_latency(sm) > 1 { } else if exec_latency(sm, &instr.op) > 1 {
let mut nop = Instr::new_boxed(OpNop { label: None }); let mut nop = Instr::new_boxed(OpNop { label: None });
nop.deps.set_delay(2); nop.deps.set_delay(2);
MappedInstrs::Many(vec![instr, nop]) MappedInstrs::Many(vec![instr, nop])

View file

@ -5832,39 +5832,6 @@ impl Instr {
} }
} }
/// Minimum latency before another instruction can execute
pub fn get_exec_latency(&self, sm: u8) -> u32 {
match &self.op {
Op::Bar(_) | Op::MemBar(_) => {
if sm >= 80 {
6
} else {
5
}
}
Op::CCtl(_op) => {
// CCTL.C needs 8, CCTL.I needs 11
11
}
// Op::DepBar(_) => 4,
_ => 1, // TODO: co-issue
}
}
pub fn get_dst_latency(&self, sm: u8, dst_idx: usize) -> u32 {
debug_assert!(self.has_fixed_latency(sm));
let file = match self.dsts()[dst_idx] {
Dst::None => return 0,
Dst::SSA(vec) => vec.file().unwrap(),
Dst::Reg(reg) => reg.file(),
};
if file.is_predicate() {
13
} else {
6
}
}
pub fn needs_yield(&self) -> bool { pub fn needs_yield(&self) -> bool {
matches!(&self.op, Op::Bar(_) | Op::BSync(_)) matches!(&self.op, Op::Bar(_) | Op::BSync(_))
} }