nouveau/headers: Add a Rust struct for each method

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28255>
This commit is contained in:
Daniel Almeida 2024-05-10 14:05:28 -03:00 committed by Marge Bot
parent 591b5da49b
commit 63770a163a
3 changed files with 228 additions and 0 deletions

View file

@ -6,6 +6,8 @@
import argparse
import os.path
import sys
import re
import subprocess
from mako.template import Template
@ -230,6 +232,126 @@ pub const ${version[0]}: u16 = ${version[1]};
% endif
""")
TEMPLATE_RS_MTHD = Template("""\
// parsed class ${nvcl}
## Write out the methods in Rust
%for mthd_name, mthd in mthddict.items():
## Identify the field type.
<%
for field_name, field_value in mthd.field_defs.items():
if field_name == 'V' and len(field_value) > 0:
mthd.field_rs_types[field_name] = to_camel(mthd_name) + 'V'
mthd.field_is_rs_enum[field_name] = True
elif len(field_value) > 0:
assert(field_name != "")
mthd.field_rs_types[field_name] = to_camel(mthd_name) + to_camel(field_name)
mthd.field_is_rs_enum[field_name] = True
elif mthd.is_float:
mthd.field_rs_types[field_name] = "f32"
mthd.field_is_rs_enum[field_name] = False
else:
mthd.field_rs_types[field_name] = "u32"
mthd.field_is_rs_enum[field_name] = False
# TRUE and FALSE are special cases.
if len(field_value) == 2:
for enumerant in field_value:
if enumerant.lower() == 'true' or enumerant.lower() == 'false':
mthd.field_rs_types[field_name] = "bool"
mthd.field_is_rs_enum[field_name] = False
break
%>
## If there are a range of values for a field, we define an enum.
%for field_name in mthd.field_defs:
%if mthd.field_is_rs_enum[field_name]:
#[repr(u16)]
#[derive(Copy, Clone, Debug, PartialEq)]
pub enum ${mthd.field_rs_types[field_name]} {
%for field_name, field_value in mthd.field_defs[field_name].items():
${to_camel(rs_field_name(field_name))} = ${field_value.lower()},
%endfor
}
%endif
%endfor
## We also define a struct with the fields for the mthd.
#[derive(Copy, Clone, Debug, PartialEq)]
pub struct ${to_camel(mthd_name)} {
%for field_name in mthddict[mthd_name].field_name_start:
pub ${rs_field_name(field_name.lower())}: ${mthd.field_rs_types[field_name]},
%endfor
}
## Notice that the "to_bits" implementation is identical, so the first brace is
## not closed.
% if not mthd.is_array:
## This trait lays out how the conversion to u32 happens
impl Mthd for ${to_camel(mthd_name)} {
const ADDR: u16 = ${mthd.addr.replace('(', '').replace(')', '')};
const CLASS: u16 = ${version[1].lower() if version is not None else nvcl.lower().replace("nv", "0x")};
%else:
impl ArrayMthd for ${to_camel(mthd_name)} {
const CLASS: u16 = ${version[1].lower() if version is not None else nvcl.lower().replace("nv", "0x")};
fn addr(i: usize) -> u16 {
<% assert not ('i' in mthd.addr and 'j' in mthd.addr) %>
(${mthd.addr.replace('j', 'i').replace('(', '').replace(')', '')}).try_into().unwrap()
}
%endif
#[inline]
fn to_bits(self) -> u32 {
let mut val = 0;
%for field_name in mthddict[mthd_name].field_name_start:
<%
field_start = int(mthd.field_name_start[field_name])
field_end = int(mthd.field_name_end[field_name])
field_width = field_end - field_start + 1
field = rs_field_name(field_name.lower()) if mthd.field_rs_types[field_name] == "u32" else f"{rs_field_name(field_name)} as u32"
%>
%if field_width == 32:
val |= self.${field};
%else:
%if "as u32" in field:
assert!((self.${field}) < (1 << ${field_width}));
val |= (self.${field}) << ${field_start};
%else:
assert!(self.${field} < (1 << ${field_width}));
val |= self.${field} << ${field_start};
%endif
%endif
%endfor
val
}
## Close the first brace.
}
%endfor
""")
## A mere convenience to convert snake_case to CamelCase. Numbers are prefixed
## with "_".
def to_camel(snake_str):
result = ''.join(word.title() for word in snake_str.split('_'))
return result if not result[0].isdigit() else '_' + result
def rs_field_name(name):
name = name.lower()
# Fix up some Rust keywords
if name == 'type':
return 'type_'
elif name == 'override':
return 'override_'
elif name[0].isdigit():
return '_' + name
else:
return name
def glob_match(glob, name):
if glob.endswith('*'):
return name.startswith(glob[:-1])
@ -338,6 +460,8 @@ def parse_header(nvcl, f):
x.field_name_start = {}
x.field_name_end = {}
x.field_defs = {}
x.field_rs_types = {}
x.field_is_rs_enum = {}
mthddict[x.name] = x
curmthd = x
@ -345,11 +469,72 @@ def parse_header(nvcl, f):
return (version, mthddict)
def convert_to_rust_constants(filename):
with open(filename, 'r') as file:
lines = file.readlines()
rust_items = []
processed_constants = {}
file_prefix = "NV" + os.path.splitext(os.path.basename(filename))[0].upper() + "_"
file_prefix = file_prefix.replace('CL', '')
for line in lines:
match = re.match(r'#define\s+(\w+)\((\w+)\)\s+(.+)', line.strip())
if match:
name, arg, expr = match.groups()
if name in processed_constants:
processed_constants[name] += 1
name += f"_{processed_constants[name]}"
else:
processed_constants[name] = 0
name = name.replace(file_prefix, '')
# convert to snake case
name = re.sub(r'(?<=[a-z])(?=[A-Z])', '_', name).lower()
rust_items.append(f"#[inline]\npub fn {name} ({arg}: u32) -> u32 {{ {expr.replace('(', '').replace(')', '')} }} ")
else:
match = re.match(r'#define\s+(\w+)\s+(?:MW\()?(\d+):(\d+)\)?', line.strip())
if match:
name, high, low = match.groups()
high = int(high) + 1 # Convert to exclusive range
if name in processed_constants:
processed_constants[name] += 1
name += f"_{processed_constants[name]}"
else:
processed_constants[name] = 0
# name = name.replace('__', '_').replace(file_prefix, '')
name = name.replace(file_prefix, '')
rust_items.append(f"pub const {name}: Range<u32> = {low}..{high};")
else:
match = re.match(r'#define\s+(\w+)\s+\(?0x(\w+)\)?', line.strip())
if match:
name, value = match.groups()
if name in processed_constants:
processed_constants[name] += 1
name += f"_{processed_constants[name]}"
else:
processed_constants[name] = 0
name = name.replace(file_prefix, '')
rust_items.append(f"pub const {name}: u32 = 0x{value};")
else:
match = re.match(r'#define\s+(\w+)\s+\(?(\d+)\)?', line.strip())
if match:
name, value = match.groups()
if name in processed_constants:
processed_constants[name] += 1
name += f"_{processed_constants[name]}"
else:
processed_constants[name] = 0
name = name.replace(file_prefix, '')
rust_items.append(f"pub const {name}: u32 = {value};")
return '\n'.join(rust_items)
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--out-h', required=False, help='Output C header.')
parser.add_argument('--out-c', required=False, help='Output C file.')
parser.add_argument('--out-rs', required=False, help='Output Rust file.')
parser.add_argument('--out-rs-mthd', required=False,
help='Output Rust file for methods.')
parser.add_argument('--in-h',
help='Input class header file.',
required=True)
@ -370,6 +555,8 @@ def main():
'nvcl': nvcl,
'version': version,
'mthddict': mthddict,
'rs_field_name': rs_field_name,
'to_camel': to_camel,
'bs': '\\'
}
@ -384,6 +571,18 @@ def main():
if args.out_rs is not None:
with open(args.out_rs, 'w', encoding='utf-8') as f:
f.write(TEMPLATE_RS.render(**environment))
if args.out_rs_mthd is not None:
with open(args.out_rs_mthd, 'w', encoding='utf-8') as f:
f.write("#![allow(non_camel_case_types)]\n")
f.write("#![allow(non_snake_case)]\n")
f.write("#![allow(non_upper_case_globals)]\n\n")
f.write("use std::ops::Range;\n")
f.write("use crate::Mthd;\n")
f.write("use crate::ArrayMthd;\n")
f.write("\n")
f.write(convert_to_rust_constants(args.in_h))
f.write("\n")
f.write(TEMPLATE_RS_MTHD.render(**environment))
except Exception:
# In the event there's an error, this imports some helpers from mako

View file

@ -9,6 +9,7 @@
import argparse
import os.path
import re
import subprocess
import sys
from mako.template import Template
@ -48,6 +49,26 @@ ${decl_mod(m.children[name], path + [name])}
}
% endif
</%def>
/// Converts a method to its raw representation.
pub trait Mthd {
/// The hardware address of the method.
const ADDR: u16;
/// The class of the method.
const CLASS: u16;
/// Converts the method to its raw representation.
fn to_bits(self) -> u32;
}
pub trait ArrayMthd {
/// The class of the method.
const CLASS: u16;
/// The hardware address of the method for the given index.
fn addr(i: usize) -> u16;
/// Converts the method to its raw representation.
fn to_bits(self) -> u32;
}
${decl_mod(root, [])}
""")

View file

@ -83,6 +83,14 @@ if with_nouveau_vk
'--out-rs', '@OUTPUT0@'],
)
cl_rs_generated += custom_target(
'nvh_' + cl + '_mthd.rs',
input : ['class_parser.py', 'nvidia/classes/'+cl+'.h'],
output : ['nvh_classes_'+cl+'_mthd.rs'],
command : [prog_python, '@INPUT0@', '--in-h', '@INPUT1@',
'--out-rs-mthd', '@OUTPUT0@'],
)
fs = import('fs')
if cl.endswith('c0') and fs.is_file('nvidia/classes/'+cl+'qmd.h')
cl_rs_generated += custom_target(