nak: Implement basic control-flow

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24998>
This commit is contained in:
Faith Ekstrand 2023-01-30 20:53:19 -06:00 committed by Marge Bot
parent 1b3382b861
commit 2bf9cafbe7
4 changed files with 172 additions and 26 deletions

View file

@ -6,6 +6,7 @@
use crate::bitset::*;
use crate::nak_ir::*;
use std::collections::HashMap;
use std::ops::Range;
struct ALURegRef {
@ -199,7 +200,7 @@ impl SM75Instr {
}
fn set_pred(&mut self, pred: &Pred, pred_inv: bool) {
assert!(pred.is_none() || !pred_inv);
assert!(!pred.is_none() || !pred_inv);
self.set_pred_reg(
12..15,
match pred {
@ -666,6 +667,28 @@ impl SM75Instr {
assert!(!op.access.out_load);
}
fn encode_bra(
&mut self,
op: &OpBra,
ip: usize,
block_offsets: &HashMap<u32, usize>,
) {
let ip = u64::try_from(ip).unwrap();
assert!(ip < i64::MAX as u64);
let ip = ip as i64;
let target_ip = *block_offsets.get(&op.target).unwrap();
let target_ip = u64::try_from(target_ip).unwrap();
assert!(target_ip < i64::MAX as u64);
let target_ip = target_ip as i64;
let rel_offset = target_ip - ip - 4;
self.set_opcode(0x947);
self.set_field(34..82, rel_offset);
self.set_field(87..90, 0x7_u8); /* TODO: Pred? */
}
fn encode_exit(&mut self, op: &OpExit) {
self.set_opcode(0x94d);
@ -682,7 +705,12 @@ impl SM75Instr {
self.set_field(72..80, op.idx);
}
pub fn encode(instr: &Instr, sm: u8) -> [u32; 4] {
pub fn encode(
instr: &Instr,
sm: u8,
ip: usize,
block_offsets: &HashMap<u32, usize>,
) -> [u32; 4] {
assert!(sm >= 75);
let mut si = SM75Instr {
@ -706,6 +734,7 @@ impl SM75Instr {
Op::St(op) => si.encode_st(&op),
Op::ALd(op) => si.encode_ald(&op),
Op::ASt(op) => si.encode_ast(&op),
Op::Bra(op) => si.encode_bra(&op, ip, block_offsets),
Op::Exit(op) => si.encode_exit(&op),
Op::S2R(op) => si.encode_s2r(&op),
_ => panic!("Unhandled instruction"),
@ -722,9 +751,22 @@ pub fn encode_shader(shader: &Shader) -> Vec<u32> {
let mut encoded = Vec::new();
assert!(shader.functions.len() == 1);
let func = &shader.functions[0];
let mut num_instrs = 0_usize;
let mut block_offsets = HashMap::new();
for b in &func.blocks {
block_offsets.insert(b.id, num_instrs);
num_instrs += b.instrs.len() * 4;
}
for b in &func.blocks {
for instr in &b.instrs {
let e = SM75Instr::encode(instr, shader.sm);
let e = SM75Instr::encode(
instr,
shader.sm,
encoded.len(),
&block_offsets,
);
encoded.extend_from_slice(&e[..]);
}
}

View file

@ -13,8 +13,10 @@ use nak_bindings::*;
struct ShaderFromNir<'a> {
nir: &'a nir_shader,
func: Option<Function>,
blocks: Vec<BasicBlock>,
instrs: Vec<Instr>,
fs_out_regs: Vec<Src>,
end_block_id: u32,
}
impl<'a> ShaderFromNir<'a> {
@ -28,8 +30,10 @@ impl<'a> ShaderFromNir<'a> {
Self {
nir: nir,
func: None,
blocks: Vec::new(),
instrs: Vec::new(),
fs_out_regs: fs_out_regs,
end_block_id: 0,
}
}
@ -320,9 +324,7 @@ impl<'a> ShaderFromNir<'a> {
}
fn parse_jump(&mut self, jump: &nir_jump_instr) {
match jump.type_ {
_ => panic!("Unsupported jump instruction"),
}
/* Nothing to do */
}
fn parse_tex(&mut self, _tex: &nir_tex_instr) {
@ -450,7 +452,7 @@ impl<'a> ShaderFromNir<'a> {
panic!("SSA undef not implemented yet");
}
fn parse_basic_block(&mut self, nb: &nir_block) -> BasicBlock {
fn parse_block(&mut self, nb: &nir_block) {
for ni in nb.iter_instr_list() {
match ni.type_ {
nir_instr_type_alu => self.parse_alu(ni.as_alu().unwrap()),
@ -468,32 +470,78 @@ impl<'a> ShaderFromNir<'a> {
_ => panic!("Unsupported instruction type"),
}
}
let mut b = BasicBlock::new(0 /* TODO: Block indices */);
let succ = nb.successors();
let s0 = succ[0].unwrap();
if let Some(s1) = succ[1] {
/* Jump to the else. We'll come back and fix up the predicate as
* part of our handling of nir_if.
*/
self.instrs.push(Instr::new_bra(s1.index));
} else if s0.index == self.end_block_id {
self.instrs.push(Instr::new_exit());
} else {
self.instrs.push(Instr::new_bra(s0.index));
}
let mut b = BasicBlock::new(nb.index);
b.instrs.append(&mut self.instrs);
b
self.blocks.push(b);
}
fn parse_if(&mut self, ni: &nir_if) {
let cond = self.get_ssa(&ni.condition.as_def());
let if_bra = self.blocks.last_mut().unwrap().branch_mut().unwrap();
if_bra.pred = cond.into();
/* This is the branch to jump to the else */
if_bra.pred_inv = true;
self.parse_cf_list(ni.iter_then_list());
self.parse_cf_list(ni.iter_else_list());
}
fn parse_loop(&mut self, nl: &nir_loop) {
self.parse_cf_list(nl.iter_body());
}
fn parse_cf_list(&mut self, list: ExecListIter<nir_cf_node>) {
for node in list {
match node.type_ {
nir_cf_node_block => {
self.parse_block(node.as_block().unwrap());
}
nir_cf_node_if => {
self.parse_if(node.as_if().unwrap());
}
nir_cf_node_loop => {
self.parse_loop(node.as_loop().unwrap());
}
_ => panic!("Invalid inner CF node type"),
}
}
}
pub fn parse_function_impl(&mut self, nfi: &nir_function_impl) -> Function {
self.func = Some(Function::new(0, nfi.ssa_alloc));
for node in nfi.iter_body() {
/* TODO: Control-flow */
let b = self.parse_basic_block(node.as_block().unwrap());
self.func.as_mut().unwrap().blocks.push(b);
}
self.end_block_id = nfi.end_block().index;
let end_block = self.func.as_mut().unwrap().blocks.last_mut().unwrap();
self.parse_cf_list(nfi.iter_body());
let end_block = self.blocks.last_mut().unwrap();
if self.nir.info.stage() == MESA_SHADER_FRAGMENT
&& nfi.function().is_entrypoint
{
let fs_out_regs =
std::mem::replace(&mut self.fs_out_regs, Vec::new());
end_block.instrs.push(Instr::new_fs_out(&fs_out_regs));
let fs_out = Instr::new_fs_out(&fs_out_regs);
end_block.instrs.insert(end_block.instrs.len() - 1, fs_out);
}
end_block.instrs.push(Instr::new_exit());
self.func.take().unwrap()
let mut f = self.func.take().unwrap();
f.blocks.append(&mut self.blocks);
f
}
pub fn parse_shader(&mut self, sm: u8) -> Shader {

View file

@ -1194,6 +1194,18 @@ impl fmt::Display for OpASt {
}
}
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpBra {
pub target: u32,
}
impl fmt::Display for OpBra {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "BRA B{}", self.target)
}
}
#[repr(C)]
#[derive(SrcsAsSlice, DstsAsSlice)]
pub struct OpExit {}
@ -1345,6 +1357,7 @@ pub enum Op {
St(OpSt),
ALd(OpALd),
ASt(OpASt),
Bra(OpBra),
Exit(OpExit),
S2R(OpS2R),
FMov(OpFMov),
@ -1383,6 +1396,18 @@ impl Pred {
}
}
impl From<RegRef> for Pred {
fn from(reg: RegRef) -> Pred {
Pred::Reg(reg)
}
}
impl From<SSAValue> for Pred {
fn from(ssa: SSAValue) -> Pred {
Pred::SSA(ssa)
}
}
impl fmt::Display for Pred {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
@ -1656,6 +1681,10 @@ impl Instr {
}))
}
pub fn new_bra(block: u32) -> Instr {
Instr::new(Op::Bra(OpBra { target: block }))
}
pub fn new_exit() -> Instr {
Instr::new(Op::Exit(OpExit {}))
}
@ -1700,9 +1729,20 @@ impl Instr {
self.op.srcs_as_mut_slice()
}
pub fn is_branch(&self) -> bool {
match self.op {
Op::Bra(_) => true,
_ => false,
}
}
pub fn can_eliminate(&self) -> bool {
match self.op {
Op::ASt(_) | Op::St(_) | Op::Exit(_) | Op::FSOut(_) => false,
Op::ASt(_)
| Op::St(_)
| Op::Bra(_)
| Op::Exit(_)
| Op::FSOut(_) => false,
_ => true,
}
}
@ -1724,7 +1764,7 @@ impl Instr {
Op::ASt(_) => Some(15),
Op::Ld(_) => None,
Op::St(_) => None,
Op::Exit(_) => Some(15),
Op::Bra(_) | Op::Exit(_) => Some(15),
Op::FMov(_)
| Op::IMov(_)
| Op::Vec(_)
@ -1740,9 +1780,9 @@ impl fmt::Display for Instr {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.pred.is_none() {
if self.pred_inv {
write!(f, "@!{}", self.pred)?;
write!(f, "@!{} ", self.pred)?;
} else {
write!(f, "@{}", self.pred)?;
write!(f, "@{} ", self.pred)?;
}
}
write!(f, "{} {}", self.op, self.deps)
@ -1750,7 +1790,7 @@ impl fmt::Display for Instr {
}
pub struct BasicBlock {
id: u32,
pub id: u32,
pub instrs: Vec<Instr>,
}
@ -1769,6 +1809,18 @@ impl BasicBlock {
}
self.instrs = instrs;
}
pub fn branch_mut(&mut self) -> Option<&mut Instr> {
if let Some(i) = self.instrs.last_mut() {
if i.is_branch() {
Some(i)
} else {
None
}
} else {
None
}
}
}
impl fmt::Display for BasicBlock {

View file

@ -473,10 +473,14 @@ nak_postprocess_nir(nir_shader *nir, const struct nak_compiler *nak)
nak_optimize_nir(nir, nak);
nir_divergence_analysis(nir);
/* Compact SSA defs because we'll use them to index arrays */
/* Re-index blocks and compact SSA defs because we'll use them to index
* arrays
*/
nir_foreach_function(func, nir) {
if (func->impl)
if (func->impl) {
nir_index_blocks(func->impl);
nir_index_ssa_defs(func->impl);
}
}
nir_print_shader(nir, stderr);