nak: Add a non-trivial register allocator

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24998>
Faith Ekstrand 2023-04-10 17:23:27 -05:00 committed by Marge Bot
parent 40fbf6bed2
commit af752f73dc
3 changed files with 831 additions and 2 deletions


@@ -332,7 +332,9 @@ pub extern "C" fn nak_compile_shader(
     println!("NAK IR:\n{}", &s);
 
-    s.assign_regs_trivial();
+    s.assign_regs();
+    //s.assign_regs_trivial();
+    println!("NAK IR:\n{}", &s);
 
     s.lower_vec_split();
     s.lower_par_copies();
     s.lower_swap();


@@ -5,10 +5,809 @@
#![allow(unstable_name_collisions)]
use crate::bitset::BitSet;
use crate::nak_ir::*;
use crate::nak_liveness::{BlockLiveness, Liveness};
use crate::util::NextMultipleOf;
use std::cmp::{max, Ordering};
use std::collections::{HashMap, HashSet};
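/* A set of SSA values which also remembers insertion order.  HashSet
 * iteration order depends on randomized hash state, so iterating the Vec
 * instead keeps everything downstream of the kill set deterministic.
 */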
struct KillSet {
set: HashSet<SSAValue>,
vec: Vec<SSAValue>,
}
impl KillSet {
pub fn new() -> KillSet {
KillSet {
set: HashSet::new(),
vec: Vec::new(),
}
}
pub fn clear(&mut self) {
self.set.clear();
self.vec.clear();
}
pub fn insert(&mut self, ssa: SSAValue) {
if self.set.insert(ssa) {
self.vec.push(ssa);
}
}
pub fn contains(&self, ssa: &SSAValue) -> bool {
self.set.contains(ssa)
}
pub fn iter(&self) -> std::slice::Iter<'_, SSAValue> {
self.vec.iter()
}
}
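/* A single component of a phi.  The phi index and component are packed
 * into an SSAValue (with the comps field stored as comp + 1, presumably
 * because a zero-component value is invalid) so we get Copy, Eq, and
 * Hash for free.
 */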
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct PhiComp {
v: SSAValue,
}
impl PhiComp {
pub fn new(idx: u32, comp: u8) -> PhiComp {
PhiComp {
v: SSAValue::new(RegFile::GPR, idx, comp + 1),
}
}
pub fn idx(&self) -> u32 {
self.v.idx()
}
pub fn comp(&self) -> u8 {
self.v.comps() - 1
}
}
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
enum LiveRef {
SSA(SSAComp),
Phi(PhiComp),
}
#[derive(Clone, Copy, Eq, Hash, PartialEq)]
struct LiveValue {
pub live_ref: LiveRef,
pub reg_ref: RegRef,
}
/* We need a stable ordering of live values so that RA is deterministic */
impl Ord for LiveValue {
fn cmp(&self, other: &Self) -> Ordering {
let s_file = u8::from(self.reg_ref.file());
let o_file = u8::from(other.reg_ref.file());
match s_file.cmp(&o_file) {
Ordering::Equal => {
let s_idx = self.reg_ref.base_idx();
let o_idx = other.reg_ref.base_idx();
s_idx.cmp(&o_idx)
}
ord => ord,
}
}
}
impl PartialOrd for LiveValue {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
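/* Allocation state for a single register file: a two-way mapping between
 * registers and SSA components (reg_ssa/ssa_reg) plus bitsets tracking
 * which registers are in use and which are pinned for the instruction
 * currently being processed.
 */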
#[derive(Clone)]
struct RegFileAllocation {
file: RegFile,
max_reg: u8,
used: BitSet,
pinned: BitSet,
reg_ssa: Vec<SSAComp>,
ssa_reg: HashMap<SSAComp, u8>,
}
impl RegFileAllocation {
pub fn new(file: RegFile, sm: u8) -> Self {
Self {
file: file,
max_reg: file.num_regs(sm) - 1,
used: BitSet::new(),
pinned: BitSet::new(),
reg_ssa: Vec::new(),
ssa_reg: HashMap::new(),
}
}
fn file(&self) -> RegFile {
self.file
}
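/* Pinning is per-instruction: begin_alloc() clears the pin set and any
 * register assigned or read while processing an instruction gets pinned
 * so that later allocations for the same instruction can't clobber it.
 */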
pub fn begin_alloc(&mut self) {
self.pinned.clear();
}
pub fn end_alloc(&mut self) {}
fn is_reg_in_bounds(&self, reg: u8, comps: u8) -> bool {
if let Some(max_reg) = reg.checked_add(comps - 1) {
max_reg <= self.max_reg
} else {
false
}
}
pub fn get_reg_comp(&self, ssa: SSAComp) -> u8 {
*self.ssa_reg.get(&ssa).unwrap()
}
pub fn get_ssa_comp(&self, reg: u8) -> Option<SSAComp> {
if self.used.get(reg.into()) {
Some(self.reg_ssa[usize::from(reg)])
} else {
None
}
}
pub fn try_get_reg(&self, ssa: SSAValue) -> Option<u8> {
let align = ssa.comps().next_power_of_two();
let reg = self.get_reg_comp(ssa.comp(0));
if reg % align == 0 {
for i in 1..ssa.comps() {
if self.get_reg_comp(ssa.comp(i)) != reg + i {
return None;
}
}
Some(reg)
} else {
None
}
}
pub fn free_ssa_comp(&mut self, ssa: SSAComp) -> u8 {
assert!(ssa.file() == self.file);
let reg = self.ssa_reg.remove(&ssa).unwrap();
assert!(self.used.get(reg.into()));
self.used.remove(reg.into());
reg
}
pub fn free_ssa(&mut self, ssa: SSAValue) {
for i in 0..ssa.comps() {
self.free_ssa_comp(ssa.comp(i));
}
}
pub fn free_killed(&mut self, killed: &KillSet) {
for ssa in killed.iter() {
if ssa.file() == self.file {
self.free_ssa(*ssa);
}
}
}
pub fn assign_reg_comp(&mut self, ssa: SSAComp, reg: u8) -> RegRef {
assert!(ssa.file() == self.file);
assert!(reg <= self.max_reg);
assert!(!self.used.get(reg.into()));
if usize::from(reg) >= self.reg_ssa.len() {
self.reg_ssa
.resize(usize::from(reg) + 1, SSAComp::new(RegFile::GPR, 0, 0));
}
self.reg_ssa[usize::from(reg)] = ssa;
self.ssa_reg.insert(ssa, reg);
self.used.insert(reg.into());
self.pinned.insert(reg.into());
RegRef::new(self.file, reg, 1)
}
pub fn assign_reg(&mut self, ssa: SSAValue, reg: u8) -> RegRef {
for i in 0..ssa.comps() {
self.assign_reg_comp(ssa.comp(i), reg + i);
}
RegRef::new(self.file, reg, ssa.comps())
}
pub fn try_assign_reg(&mut self, ssa: SSAValue, reg: u8) -> Option<RegRef> {
if ssa.file() != self.file() {
return None;
}
if !self.is_reg_in_bounds(reg, ssa.comps()) {
return None;
}
for c in 0..ssa.comps() {
if self.used.get((reg + c).into()) {
return None;
}
}
Some(self.assign_reg(ssa, reg))
}
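/* Scan the used bitset one 32-bit word at a time, looking for comps
 * consecutive free registers aligned to the next power of two of comps.
 * Registers pinned by the current instruction count as used.  If the
 * bitset contains no suitable range, fall back to the first register
 * past the allocated words, if it's in bounds.
 */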
pub fn try_find_unused_reg(&self, comps: u8) -> Option<u8> {
assert!(comps > 0);
let comps_mask = u32::MAX >> (32 - comps);
let align = comps.next_power_of_two();
for (w, word) in self.used.words().iter().enumerate() {
let mut avail = !word;
if w < self.pinned.words().len() {
avail &= !self.pinned.words()[w];
}
while avail != 0 {
let bit = u8::try_from(avail.trailing_zeros()).unwrap();
/* Ensure we're properly aligned */
if bit & (align - 1) != 0 {
avail &= !(1 << bit);
continue;
}
let mask = comps_mask << bit;
if avail & mask == mask {
let reg = u8::try_from(w * 32).unwrap() + bit;
if self.is_reg_in_bounds(reg, comps) {
return Some(reg);
} else {
return None;
}
}
avail &= !mask;
}
}
if let Ok(reg) = u8::try_from(self.used.words().len() * 32) {
if self.is_reg_in_bounds(reg, comps) {
Some(reg)
} else {
None
}
} else {
None
}
}
fn get_reg_near_reg(&self, reg: u8, comps: u8) -> u8 {
let align = comps.next_power_of_two();
/* Pick something properly aligned near component 0 */
let mut reg = reg & !(align - 1);
if !self.is_reg_in_bounds(reg, comps) {
reg -= align;
}
reg
}
pub fn get_reg_near_ssa(&self, ssa: SSAValue) -> u8 {
/* Get something near component 0 */
self.get_reg_near_reg(self.get_reg_comp(ssa.comp(0)), ssa.comps())
}
pub fn get_any_reg(&self, comps: u8) -> u8 {
let mut pick_comps = comps;
while pick_comps > 0 {
if let Some(reg) = self.try_find_unused_reg(pick_comps) {
return self.get_reg_near_reg(reg, comps);
}
pick_comps = pick_comps >> 1;
}
panic!("Failed to find any free registers");
}
pub fn get_scalar(&mut self, ssa: SSAComp) -> RegRef {
assert!(ssa.file() == self.file);
let reg = self.get_reg_comp(ssa);
self.pinned.insert(reg.into());
RegRef::new(self.file, reg, 1)
}
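/* Move ssa into the contiguous range reg..reg+comps, emitting entries in
 * the given parallel copy and swapping out anything already living there.
 * For example, if a one-component ssa lives in r7 and r2 currently holds
 * some value x, move_to_reg(pcopy, ssa, 2) appends the parallel copies
 * r2 -> r7 and r7 -> r2, leaving ssa in r2 and x in r7.
 */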
pub fn move_to_reg(
&mut self,
pcopy: &mut OpParCopy,
ssa: SSAValue,
reg: u8,
) -> RegRef {
for c in 0..ssa.comps() {
let old_reg = self.get_reg_comp(ssa.comp(c));
if old_reg == reg + c {
continue;
}
self.free_ssa_comp(ssa.comp(c));
/* If something already exists in the destination, swap it to the
* source.
*/
if let Some(evicted) = self.get_ssa_comp(reg + c) {
self.free_ssa_comp(evicted);
pcopy.srcs.push(RegRef::new(self.file, reg + c, 1).into());
pcopy.dsts.push(RegRef::new(self.file, old_reg, 1).into());
self.assign_reg_comp(evicted, old_reg);
}
pcopy.srcs.push(RegRef::new(self.file, old_reg, 1).into());
pcopy.dsts.push(RegRef::new(self.file, reg + c, 1).into());
self.assign_reg_comp(ssa.comp(c), reg + c);
}
RegRef::new(self.file, reg, ssa.comps())
}
pub fn get_vector(
&mut self,
pcopy: &mut OpParCopy,
ssa: SSAValue,
) -> RegRef {
let reg = if let Some(reg) = self.try_get_reg(ssa) {
reg
} else if let Some(reg) = self.try_find_unused_reg(ssa.comps()) {
reg
} else {
self.get_reg_near_ssa(ssa)
};
self.move_to_reg(pcopy, ssa, reg)
}
pub fn alloc_scalar(&mut self, ssa: SSAComp) -> RegRef {
let reg = self.try_find_unused_reg(1).unwrap();
self.assign_reg_comp(ssa, reg)
}
}
fn instr_remap_srcs_file(
instr: &mut Instr,
pcopy: &mut OpParCopy,
ra: &mut RegFileAllocation,
) {
if let Pred::SSA(pred) = instr.pred {
if pred.file() == ra.file() {
instr.pred = ra.get_scalar(pred.as_comp()).into();
}
}
for src in instr.srcs_mut() {
if let SrcRef::SSA(ssa) = src.src_ref {
if ssa.file() == ra.file() {
src.src_ref = ra.get_vector(pcopy, ssa).into();
}
}
}
}
fn instr_alloc_scalar_dsts_file(instr: &mut Instr, ra: &mut RegFileAllocation) {
for dst in instr.dsts_mut() {
if let Dst::SSA(ssa) = dst {
if ssa.file() == ra.file() {
*dst = ra.alloc_scalar(ssa.as_comp()).into();
}
}
}
}
fn instr_assign_regs_file(
instr: &mut Instr,
killed: &KillSet,
pcopy: &mut OpParCopy,
ra: &mut RegFileAllocation,
) {
let mut max_killed_vec_comps = 0;
let mut total_killed_vec_comps = 0;
for ssa in killed.iter() {
if ssa.file() == ra.file() && ssa.comps() > 1 {
max_killed_vec_comps = max(max_killed_vec_comps, ssa.comps());
total_killed_vec_comps += ssa.comps();
}
}
let mut vec_dst = None;
for (i, dst) in instr.dsts().iter().enumerate() {
if let Dst::SSA(ssa) = dst {
if ssa.file() == ra.file() {
if ssa.comps() > 1 {
assert!(vec_dst.is_none());
vec_dst = Some(i)
}
}
}
}
ra.begin_alloc();
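/* Three strategies for a vector destination, in decreasing order of
 * preference:
 *  1. A free, aligned range is available: just take it.
 *  2. Some killed vector source is at least as large: re-use its
 *     registers for the destination.
 *  3. Worst case: shuffle the killed vector sources into a range near
 *     free space and write the destination over them once they die.
 */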
if let Some(vec_dst) = vec_dst {
/* Predicates can't be vectors. This lets us ignore instr.pred in our
* analysis in the cases below. We only have to handle it in the final
* simple case for scalar destinations.
*/
assert!(!ra.file().is_predicate());
let vec_dst_comps = instr.dsts()[vec_dst].as_ssa().unwrap().comps();
if let Some(reg) = ra.try_find_unused_reg(vec_dst_comps) {
let vec_dst = &mut instr.dsts_mut()[vec_dst];
let vec_ssa = *vec_dst.as_ssa().unwrap();
*vec_dst = ra.assign_reg(vec_ssa, reg).into();
instr_remap_srcs_file(instr, pcopy, ra);
ra.free_killed(killed);
instr_alloc_scalar_dsts_file(instr, ra);
} else if vec_dst_comps <= max_killed_vec_comps {
instr_remap_srcs_file(instr, pcopy, ra);
let mut reg = None;
for ssa in killed.iter() {
if ssa.file() == ra.file() {
if ssa.comps() >= vec_dst_comps {
reg = Some(ra.try_get_reg(*ssa).unwrap());
}
ra.free_ssa(*ssa);
}
}
assert!(reg.is_some());
let vec_dst = &mut instr.dsts_mut()[vec_dst];
let vec_ssa = *vec_dst.as_ssa().unwrap();
*vec_dst = ra.assign_reg(vec_ssa, reg.unwrap()).into();
instr_alloc_scalar_dsts_file(instr, ra);
} else {
let vec_comps = max(max_killed_vec_comps, vec_dst_comps);
let vec_reg = ra.get_any_reg(vec_comps);
let mut ssa_reg = HashMap::new();
let mut src_vec_reg = vec_reg;
for src in instr.srcs_mut() {
if let SrcRef::SSA(ssa) = src.src_ref {
if ssa.file() == ra.file() {
if killed.contains(&ssa) && ssa.comps() > 1 {
let reg =
*ssa_reg.entry(ssa).or_insert_with(|| {
let align = ssa.comps().next_power_of_two();
let reg = src_vec_reg;
src_vec_reg += ssa.comps();
assert!(reg % align == 0);
ra.move_to_reg(pcopy, ssa, reg)
});
src.src_ref = reg.into();
}
}
}
}
/* Handle the scalar and non-killed sources */
instr_remap_srcs_file(instr, pcopy, ra);
ra.free_killed(killed);
let vec_dst = &mut instr.dsts_mut()[vec_dst];
let vec_ssa = *vec_dst.as_ssa().unwrap();
*vec_dst = ra.assign_reg(vec_ssa, vec_reg).into();
instr_alloc_scalar_dsts_file(instr, ra);
}
} else {
instr_remap_srcs_file(instr, pcopy, ra);
ra.free_killed(killed);
instr_alloc_scalar_dsts_file(instr, ra);
}
ra.end_alloc();
}
#[derive(Clone)]
struct RegAllocation {
files: [RegFileAllocation; 4],
phi_ssa: HashMap<u32, SSAValue>,
}
impl RegAllocation {
pub fn new(sm: u8) -> Self {
Self {
files: [
RegFileAllocation::new(RegFile::GPR, sm),
RegFileAllocation::new(RegFile::UGPR, sm),
RegFileAllocation::new(RegFile::Pred, sm),
RegFileAllocation::new(RegFile::UPred, sm),
],
phi_ssa: HashMap::new(),
}
}
pub fn file(&self, file: RegFile) -> &RegFileAllocation {
for f in &self.files {
if f.file() == file {
return f;
}
}
panic!("Unknown register file");
}
pub fn file_mut(&mut self, file: RegFile) -> &mut RegFileAllocation {
for f in &mut self.files {
if f.file() == file {
return f;
}
}
panic!("Unknown register file");
}
pub fn free_ssa(&mut self, ssa: SSAValue) {
self.file_mut(ssa.file()).free_ssa(ssa);
}
pub fn free_killed(&mut self, killed: &KillSet) {
for ssa in killed.iter() {
self.free_ssa(*ssa);
}
}
pub fn get_scalar(&mut self, ssa: SSAComp) -> RegRef {
self.file_mut(ssa.file()).get_scalar(ssa)
}
pub fn alloc_scalar(&mut self, ssa: SSAComp) -> RegRef {
self.file_mut(ssa.file()).alloc_scalar(ssa)
}
}
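/* Per-block allocation state.  live_in records which register each value
 * occupies on entry to the block and phi_out records where each phi
 * source ended up; second_pass() uses both to emit the parallel copies
 * which shuffle values into place along CFG edges.
 */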
struct AssignRegsBlock {
ra: RegAllocation,
live_in: Vec<LiveValue>,
phi_out: HashMap<PhiComp, SrcRef>,
}
impl AssignRegsBlock {
fn new(ra: RegAllocation) -> AssignRegsBlock {
AssignRegsBlock {
ra: ra,
live_in: Vec::new(),
phi_out: HashMap::new(),
}
}
fn assign_regs_split(
&mut self,
split: &OpSplit,
killed: &KillSet,
pcopy: &mut OpParCopy,
) {
let src = split.src.src_ref.as_ssa().unwrap();
let comps = src.comps();
assert!(usize::from(comps) == split.dsts.len());
let mut coalesced = BitSet::new();
if killed.contains(src) {
for c in 0..comps {
/* Free the component regardless of any dest checks */
let src_ra = self.ra.file_mut(src.file());
let reg = src_ra.free_ssa_comp(src.comp(c));
let src_ref = RegRef::new(src.file(), reg, 1);
/* If we have an OpSplit which kills its source, we can coalesce
* on the spot into the destinations.
*/
if let Dst::SSA(dst) = &split.dsts[usize::from(c)] {
if dst.file() == src.file() {
/* Assign destinations to source components when the
* register files match.
*/
let dst_ra = src_ra;
dst_ra.assign_reg_comp(dst.as_comp(), reg);
coalesced.insert(c.into());
} else {
/* Otherwise, they come from different files so
* allocating a destination register won't affect the
* source and it's okay to alloc before we've finished
* freeing the source.
*/
let dst_ra = self.ra.file_mut(dst.file());
let dst_ref = dst_ra.alloc_scalar(dst.as_comp());
pcopy.srcs.push(src_ref.into());
pcopy.dsts.push(dst_ref.into());
}
}
}
} else {
for c in 0..comps {
if let Dst::SSA(dst) = &split.dsts[usize::from(c)] {
pcopy.srcs.push(self.ra.get_scalar(src.comp(c)).into());
pcopy.dsts.push(self.ra.alloc_scalar(dst.as_comp()).into());
}
}
}
}
fn assign_regs_instr(
&mut self,
mut instr: Instr,
killed: &KillSet,
pcopy: &mut OpParCopy,
) -> Option<Instr> {
match &instr.op {
Op::Split(split) => {
assert!(instr.pred.is_none());
assert!(split.src.src_mod.is_none());
self.assign_regs_split(split, killed, pcopy);
None
}
Op::PhiSrcs(phi) => {
for (id, src) in phi.iter() {
assert!(src.src_mod.is_none());
if let SrcRef::SSA(ssa) = src.src_ref {
for c in 0..ssa.comps() {
let src = self.ra.get_scalar(ssa.comp(c)).into();
self.phi_out.insert(PhiComp::new(*id, c), src);
}
} else {
self.phi_out.insert(PhiComp::new(*id, 0), src.src_ref);
}
}
self.ra.free_killed(killed);
None
}
Op::PhiDsts(phi) => {
assert!(instr.pred.is_none());
for (id, dst) in phi.iter() {
if let Dst::SSA(ssa) = dst {
for c in 0..ssa.comps() {
self.live_in.push(LiveValue {
live_ref: LiveRef::Phi(PhiComp::new(*id, c)),
reg_ref: self.ra.alloc_scalar(ssa.comp(c)),
});
}
}
}
None
}
_ => {
for file in &mut self.ra.files {
instr_assign_regs_file(&mut instr, killed, pcopy, file);
}
Some(instr)
}
}
}
fn first_pass(&mut self, b: &mut BasicBlock, bl: &BlockLiveness) {
/* Populate live-in from the allocation state we were handed. We'll add
 * more live-in values when we process the OpPhiDsts, if any.
 */
for raf in &self.ra.files {
for (comp, reg) in &raf.ssa_reg {
self.live_in.push(LiveValue {
live_ref: LiveRef::SSA(*comp),
reg_ref: RegRef::new(raf.file(), *reg, 1),
});
}
}
let mut instrs = Vec::new();
let mut killed = KillSet::new();
for (ip, instr) in b.instrs.drain(..).enumerate() {
/* Build up the kill set */
killed.clear();
if let Pred::SSA(ssa) = &instr.pred {
if !bl.is_live_after(ssa, ip) {
killed.insert(*ssa);
}
}
for src in instr.srcs() {
if let SrcRef::SSA(ssa) = &src.src_ref {
if !bl.is_live_after(ssa, ip) {
killed.insert(*ssa);
}
}
}
let mut pcopy = OpParCopy::new();
let instr = self.assign_regs_instr(instr, &killed, &mut pcopy);
if !pcopy.is_empty() {
instrs.push(Instr::new(Op::ParCopy(pcopy)));
}
if let Some(instr) = instr {
instrs.push(instr);
}
}
/* Sort live-in to maintain determinism */
self.live_in.sort();
b.instrs = instrs;
}
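/* For every value live into the target block, copy it from wherever
 * this block left it to wherever the target expects it.  All the moves
 * go in one OpParCopy, inserted just before the branch if there is one.
 */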
fn second_pass(&self, target: &AssignRegsBlock, b: &mut BasicBlock) {
let mut pcopy = OpParCopy::new();
for lv in &target.live_in {
let src = match lv.live_ref {
LiveRef::SSA(ssa) => {
let reg = self.ra.file(ssa.file()).get_reg_comp(ssa);
SrcRef::from(RegRef::new(ssa.file(), reg, 1))
}
LiveRef::Phi(phi) => *self.phi_out.get(&phi).unwrap(),
};
let dst = lv.reg_ref;
if let SrcRef::Reg(src_reg) = src {
if dst == src_reg {
continue;
}
}
pcopy.srcs.push(src.into());
pcopy.dsts.push(dst.into());
}
let pcopy = Instr::new(Op::ParCopy(pcopy));
if b.branch().is_some() {
b.instrs.insert(b.instrs.len() - 1, pcopy);
} else {
b.instrs.push(pcopy);
};
}
}
struct AssignRegs {
sm: u8,
blocks: HashMap<u32, AssignRegsBlock>,
}
impl AssignRegs {
pub fn new(sm: u8) -> Self {
Self {
sm: sm,
blocks: HashMap::new(),
}
}
pub fn run(&mut self, f: &mut Function) {
let live = Liveness::for_function(f);
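/* First pass: visit blocks in order, seeding each block's allocation
 * state from its first predecessor (or a fresh one for entry blocks),
 * and assign registers to every instruction.  The second pass below
 * then fixes up CFG edges with parallel copies.
 */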
for b in &mut f.blocks {
let bl = live.block(&b);
let ra = if bl.predecessors.is_empty() {
RegAllocation::new(self.sm)
} else {
/* Start from the first predecessor's allocation state. */
self.blocks.get(&bl.predecessors[0]).unwrap().ra.clone()
};
let mut arb = AssignRegsBlock::new(ra);
arb.first_pass(b, bl);
self.blocks.insert(b.id, arb);
}
for b in &mut f.blocks {
let bl = live.block(&b);
let arb = self.blocks.get(&b.id).unwrap();
for succ in bl.successors {
if let Some(succ) = succ {
let target = self.blocks.get(&succ).unwrap();
arb.second_pass(target, b);
}
}
}
}
}
impl Shader {
pub fn assign_regs(&mut self) {
for f in &mut self.functions {
AssignRegs::new(self.sm).run(f);
}
}
}
struct TrivialRegAlloc {
next_reg: u8,


@@ -34,6 +34,27 @@ impl RegFile {
RegFile::Pred | RegFile::UPred => true,
}
}
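/* Number of allocatable registers in each file.  The counts appear to
 * leave the top register in each file for the hardware zero/true
 * registers (RZ, PT and, on Turing+, URZ and UPT), and the uniform
 * files only exist on SM75 (Turing) and later.
 */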
pub fn num_regs(&self, sm: u8) -> u8 {
match self {
RegFile::GPR => 255,
RegFile::UGPR => {
if sm >= 75 {
63
} else {
0
}
}
RegFile::Pred => 7,
RegFile::UPred => {
if sm >= 75 {
7
} else {
0
}
}
}
}
}
impl From<RegFile> for u8 {
@@ -1430,6 +1451,13 @@ pub struct OpPhiSrcs {
}
impl OpPhiSrcs {
pub fn new() -> OpPhiSrcs {
OpPhiSrcs {
srcs: Vec::new(),
ids: Vec::new(),
}
}
pub fn is_empty(&self) -> bool {
assert!(self.ids.len() == self.srcs.len());
self.ids.is_empty()