rusticl/kernel: add optimized Kernel variant

By default we have to account for the application setting global work
offsets, or for a single kernel launch not fitting into one hw dispatch.

To mitigate the overhead this causes at kernel runtime, and because those
cases are mostly irrelevant in practice, we compile an optimized kernel
variant that makes a few assumptions.

We also make use of the workgroup_size_hint as an additional
optimization.

This should speed up relatively small kernels significantly, as it can cut
their instruction count in half.
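
A condensed sketch of the per-launch selection this adds (see the dispatch
code in the diff below; fits_hw_grid is not a real variable, it stands in
for the per-axis hw_max_grid checks):

    let variant = if offsets == [0; 3]
        && fits_hw_grid
        && block == kernel_info.work_group_size_hint
    {
        NirKernelVariant::Optimized
    } else {
        NirKernelVariant::Default
    };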

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30152>
Karol Herbst 2024-08-12 14:53:32 +02:00 committed by Marge Bot
parent 59f63381d4
commit f098620c21
2 changed files with 173 additions and 38 deletions

src/gallium/frontends/rusticl/core/kernel.rs

@@ -22,6 +22,7 @@ use spirv::SpirvKernelInfo;
use std::cmp;
use std::collections::HashMap;
use std::convert::TryInto;
use std::ops::Index;
use std::os::raw::c_void;
use std::ptr;
use std::slice;
@@ -318,6 +319,7 @@ pub struct KernelInfo {
pub args: Vec<KernelArg>,
pub attributes_string: String,
work_group_size: [usize; 3],
work_group_size_hint: [u32; 3],
subgroup_size: usize,
num_subgroups: usize,
}
@@ -355,7 +357,56 @@ enum KernelDevStateVariant {
Nir(NirShader),
}
pub struct NirKernelBuild {
#[derive(Debug, PartialEq)]
enum NirKernelVariant {
/// Can be used under any circumstance.
Default,
/// Optimized variant making the following assumptions:
/// - global_id_offsets are 0
/// - workgroup_offsets are 0
/// - local_size is info.local_size_hint
Optimized,
}
pub struct NirKernelBuilds {
default_build: NirKernelBuild,
optimized: Option<NirKernelBuild>,
/// merged info with worst case values
info: pipe_compute_state_object_info,
}
impl Index<NirKernelVariant> for NirKernelBuilds {
type Output = NirKernelBuild;
fn index(&self, index: NirKernelVariant) -> &Self::Output {
match index {
NirKernelVariant::Default => &self.default_build,
NirKernelVariant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
}
}
}
impl NirKernelBuilds {
fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
let mut info = default_build.info;
if let Some(build) = &optimized {
info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
info.simd_sizes &= build.info.simd_sizes;
info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
info.preferred_simd_size =
cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
}
Self {
default_build: default_build,
optimized: optimized,
info: info,
}
}
}
struct NirKernelBuild {
nir_or_cso: KernelDevStateVariant,
constant_buffer: Option<Arc<PipeResource>>,
info: pipe_compute_state_object_info,
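
The Index impl above makes the fallback transparent to call sites: indexing
with Optimized yields the default build whenever no optimized variant was
compiled, so callers never touch the Option. A minimal standalone sketch of
the pattern (payload reduced to a String purely for illustration):

    use std::ops::Index;

    enum Variant {
        Default,
        Optimized,
    }

    struct Builds {
        default_build: String,
        optimized: Option<String>,
    }

    impl Index<Variant> for Builds {
        type Output = String;

        fn index(&self, index: Variant) -> &Self::Output {
            match index {
                Variant::Default => &self.default_build,
                // fall back to the default build if no variant exists
                Variant::Optimized => self.optimized.as_ref().unwrap_or(&self.default_build),
            }
        }
    }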
@@ -420,7 +471,7 @@ pub struct Kernel {
pub prog: Arc<Program>,
pub name: String,
values: Mutex<Vec<Option<KernelArgValue>>>,
builds: HashMap<&'static Device, Arc<NirKernelBuild>>,
builds: HashMap<&'static Device, Arc<NirKernelBuilds>>,
pub kernel_info: Arc<KernelInfo>,
}
@@ -439,6 +490,7 @@
Ok(res)
}
#[derive(Clone)]
struct CompilationResult {
nir: NirShader,
compiled_args: Vec<CompiledKernelArg>,
@@ -666,7 +718,12 @@ fn compile_nir_prepare_for_variants(
nir.gather_info();
}
fn compile_nir_variant(res: &mut CompilationResult, dev: &Device, args: &[KernelArg]) {
fn compile_nir_variant(
res: &mut CompilationResult,
dev: &Device,
variant: NirKernelVariant,
args: &[KernelArg],
) {
let mut lower_state = rusticl_lower_state::default();
let compiled_args = &mut res.compiled_args;
let nir = &mut res.nir;
@@ -694,10 +751,19 @@ fn compile_nir_variant(res: &mut CompilationResult, dev: &Device, args: &[Kernel
.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE)
};
if variant == NirKernelVariant::Optimized {
let wgsh = nir.workgroup_size_hint();
if wgsh != [0; 3] {
nir.set_workgroup_size(wgsh);
}
}
let mut compute_options = nir_lower_compute_system_values_options::default();
compute_options.set_has_base_global_invocation_id(true);
compute_options.set_has_base_workgroup_id(true);
compute_options.set_has_global_size(true);
if variant != NirKernelVariant::Optimized {
compute_options.set_has_base_global_invocation_id(true);
compute_options.set_has_base_workgroup_id(true);
}
nir_pass!(nir, nir_lower_compute_system_values, &compute_options);
nir.gather_info();
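
Not advertising the base system values is what shrinks the optimized
variant: nir_lower_compute_system_values then has no base terms to load and
add when lowering the id system values. Roughly, in pseudocode (not actual
NIR):

    // Default variant: bases are uploaded as kernel inputs
    //   workgroup_id = hw_workgroup_id + base_workgroup_id
    //   global_id    = base_global_id + workgroup_id * local_size + local_id
    //
    // Optimized variant: both bases are assumed to be 0, so the input
    // loads and the additions fold away
    //   workgroup_id = hw_workgroup_id
    //   global_id    = workgroup_id * local_size + local_id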
@@ -721,6 +787,7 @@ fn compile_nir_variant(res: &mut CompilationResult, dev: &Device, args: &[Kernel
};
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
debug_assert_ne!(variant, NirKernelVariant::Optimized);
add_var(
nir,
&mut lower_state.base_global_invoc_id_loc,
@@ -741,6 +808,7 @@ fn compile_nir_variant(res: &mut CompilationResult, dev: &Device, args: &[Kernel
}
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
debug_assert_ne!(variant, NirKernelVariant::Optimized);
add_var(
nir,
&mut lower_state.base_workgroup_id_loc,
@@ -896,7 +964,7 @@ fn compile_nir_remaining(
dev: &Device,
mut nir: NirShader,
args: &[KernelArg],
) -> CompilationResult {
) -> (CompilationResult, Option<CompilationResult>) {
// add all API kernel args
let mut compiled_args: Vec<_> = (0..args.len())
.map(|idx| CompiledKernelArg {
@@ -912,14 +980,27 @@
compiled_args: compiled_args,
};
compile_nir_variant(&mut default_build, dev, args);
// check if we even want to compile a variant before cloning the compilation state
let has_wgs_hint = default_build.nir.workgroup_size_variable()
&& default_build.nir.workgroup_size_hint() != [0; 3];
let has_offsets = default_build
.nir
.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_INVOCATION_ID);
default_build
let mut optimized = (!Platform::dbg().no_variants && (has_offsets || has_wgs_hint))
.then(|| default_build.clone());
compile_nir_variant(&mut default_build, dev, NirKernelVariant::Default, args);
if let Some(optimized) = &mut optimized {
compile_nir_variant(optimized, dev, NirKernelVariant::Optimized, args);
}
(default_build, optimized)
}
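
Note the ordering above: the clone must happen before compile_nir_variant
runs, because variant lowering mutates the CompilationResult in place, and
bool::then keeps the clone lazy, so kernels that cannot benefit never pay
for duplicating the NIR. A minimal illustration of that laziness:

    fn expensive() -> Vec<u8> {
        vec![0; 1 << 20] // stands in for cloning a NirShader
    }

    let skipped = false.then(expensive); // None; the closure never runs
    let taken = true.then(expensive); // Some(buf); the closure runs once
    assert!(skipped.is_none() && taken.is_some());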
pub struct SPIRVToNirResult {
pub kernel_info: KernelInfo,
pub nir_kernel_build: NirKernelBuild,
pub nir_kernel_builds: NirKernelBuilds,
}
impl SPIRVToNirResult {
@@ -928,20 +1009,30 @@
kernel_info: &clc_kernel_info,
args: Vec<KernelArg>,
default_build: CompilationResult,
optimized: Option<CompilationResult>,
) -> Self {
// TODO: we _should_ be able to parse them out of the SPIR-V, but clc doesn't handle
// indirections yet.
let nir = &default_build.nir;
let wgs = nir.workgroup_size();
let subgroup_size = nir.subgroup_size();
let num_subgroups = nir.num_subgroups();
let default_build = NirKernelBuild::new(dev, default_build);
let optimized = optimized.map(|opt| NirKernelBuild::new(dev, opt));
let kernel_info = KernelInfo {
args: args,
attributes_string: kernel_info.attribute_str(),
work_group_size: [wgs[0] as usize, wgs[1] as usize, wgs[2] as usize],
subgroup_size: nir.subgroup_size() as usize,
num_subgroups: nir.num_subgroups() as usize,
work_group_size_hint: kernel_info.local_size_hint,
subgroup_size: subgroup_size as usize,
num_subgroups: num_subgroups as usize,
};
Self {
kernel_info: kernel_info,
nir_kernel_build: NirKernelBuild::new(dev, default_build),
nir_kernel_builds: NirKernelBuilds::new(default_build, optimized),
}
}
@@ -954,14 +1045,39 @@ impl SPIRVToNirResult {
let args = KernelArg::deserialize(&mut reader)?;
let default_build = CompilationResult::deserialize(&mut reader, d)?;
Some(SPIRVToNirResult::new(d, kernel_info, args, default_build))
let optimized = match unsafe { blob_read_uint8(&mut reader) } {
0 => None,
_ => Some(CompilationResult::deserialize(&mut reader, d)?),
};
Some(SPIRVToNirResult::new(
d,
kernel_info,
args,
default_build,
optimized,
))
}
// we can't use Self here as the nir shader might be compiled to a cso already and we can't
// cache that.
fn serialize(blob: &mut blob, args: &[KernelArg], default_build: &CompilationResult) {
fn serialize(
blob: &mut blob,
args: &[KernelArg],
default_build: &CompilationResult,
optimized: &Option<CompilationResult>,
) {
KernelArg::serialize(args, blob);
default_build.serialize(blob);
match optimized {
Some(variant) => {
unsafe { blob_write_uint8(blob, 1) };
variant.serialize(blob);
}
None => unsafe {
blob_write_uint8(blob, 0);
},
}
}
}
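
The optional variant is stored in the disk-cache blob behind a one-byte
presence tag, matching the blob_read_uint8 call on the deserialize side.
The same layout expressed with a plain byte buffer instead of Mesa's blob
helpers (illustrative only):

    fn write_optional(out: &mut Vec<u8>, opt: Option<&[u8]>) {
        match opt {
            Some(bytes) => {
                out.push(1); // tag: optimized variant present
                out.extend_from_slice(bytes);
            }
            None => out.push(0), // tag: no optimized variant
        }
    }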
@@ -982,11 +1098,17 @@ pub(super) fn convert_spirv_to_nir(
.unwrap_or_else(|| {
let nir = build.to_nir(name, dev);
let (mut args, nir) = compile_nir_to_args(dev, nir, args, &dev.lib_clc);
let default_build = compile_nir_remaining(dev, nir, &args);
let (default_build, optimized) = compile_nir_remaining(dev, nir, &args);
for arg in &default_build.compiled_args {
if let CompiledKernelArgType::APIArg(idx) = arg.kind {
args[idx as usize].dead &= arg.dead;
for build in [Some(&default_build), optimized.as_ref()].into_iter() {
let Some(build) = build else {
continue;
};
for arg in &build.compiled_args {
if let CompiledKernelArgType::APIArg(idx) = arg.kind {
args[idx as usize].dead &= arg.dead;
}
}
}
@@ -994,14 +1116,14 @@
let mut blob = blob::default();
unsafe {
blob_init(&mut blob);
SPIRVToNirResult::serialize(&mut blob, &args, &default_build);
SPIRVToNirResult::serialize(&mut blob, &args, &default_build, &optimized);
let bin = slice::from_raw_parts(blob.data, blob.size);
cache.put(bin, &mut key.unwrap());
blob_finish(&mut blob);
}
}
SPIRVToNirResult::new(dev, spirv_info, args, default_build)
SPIRVToNirResult::new(dev, spirv_info, args, default_build, optimized)
})
}
@@ -1103,7 +1225,7 @@ impl Kernel {
// Clone all the data we need to execute this kernel
let kernel_info = Arc::clone(&self.kernel_info);
let arg_values = self.arg_values().clone();
let nir_kernel_build = Arc::clone(&self.builds[q.device]);
let nir_kernel_builds = Arc::clone(&self.builds[q.device]);
// operations we want to report errors to the clients
let mut block = create_kernel_arr::<u32>(block, 1)?;
@@ -1115,6 +1237,27 @@
self.optimize_local_size(q.device, &mut grid, &mut block);
Ok(Box::new(move |q, ctx| {
let hw_max_grid: Vec<usize> = q
.device
.max_grid_size()
.into_iter()
.map(|val| val.try_into().unwrap_or(usize::MAX))
// clamped as pipe_launch_grid::grid is only u32
.map(|val| cmp::min(val, u32::MAX as usize))
.collect();
let variant = if offsets == [0; 3]
&& grid[0] <= hw_max_grid[0]
&& grid[1] <= hw_max_grid[1]
&& grid[2] <= hw_max_grid[2]
&& block == kernel_info.work_group_size_hint
{
NirKernelVariant::Optimized
} else {
NirKernelVariant::Default
};
let nir_kernel_build = &nir_kernel_builds[variant];
let mut workgroup_id_offset_loc = None;
let mut input = Vec::new();
// Set it once so we get the alignment padding right
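
The per-axis grid checks above are what keep the base_workgroup_id
assumption safe: a grid exceeding the hardware limit on any axis is split
into multiple sub-launches further down, and every sub-launch after the
first runs with a nonzero workgroup id offset that only the Default variant
reads. Condensed from the div_ceil loops below:

    for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
        // zero only for the first chunk; the Default variant consumes it
        // as a kernel input while the Optimized variant assumes it is 0
        let base_wg_x = x * hw_max_grid[0];
        // ... launch the sub-grid with workgroup offset base_wg_x ...
    }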
@@ -1359,15 +1502,6 @@
ctx.set_shader_images(&iviews);
ctx.set_global_binding(resources.as_slice(), &mut globals);
let hw_max_grid: Vec<usize> = q
.device
.max_grid_size()
.into_iter()
.map(|val| val.try_into().unwrap_or(usize::MAX))
// clamped as pipe_launch_grid::grid is only u32
.map(|val| cmp::min(val, u32::MAX as usize))
.collect();
for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
for y in 0..grid[1].div_ceil(hw_max_grid[1]) {
for x in 0..grid[0].div_ceil(hw_max_grid[0]) {
@@ -1535,7 +1669,8 @@ impl Kernel {
pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
// TODO include args
self.builds.get(dev).unwrap().shared_size as cl_ulong
// this is purely informational so it shouldn't even matter
self.builds.get(dev).unwrap()[NirKernelVariant::Default].shared_size as cl_ulong
}
pub fn has_svm_devs(&self) -> bool {
@ -1574,7 +1709,8 @@ impl Kernel {
*block.get(2).unwrap_or(&1) as u32,
];
match &self.builds.get(dev).unwrap().nir_or_cso {
// TODO: this _might_ bite us somewhere, but I think it probably doesn't matter
match &self.builds.get(dev).unwrap()[NirKernelVariant::Default].nir_or_cso {
KernelDevStateVariant::Cso(cso) => {
dev.helper_ctx()
.compute_state_subgroup_size(cso.cso_ptr, &block) as usize

src/gallium/frontends/rusticl/core/program.rs

@@ -122,11 +122,10 @@ impl ProgramBuild {
let build_result = convert_spirv_to_nir(self, kernel_name, &args, dev);
kernel_info_set.insert(build_result.kernel_info);
self.builds
.get_mut(dev)
.unwrap()
.kernels
.insert(kernel_name.clone(), Arc::new(build_result.nir_kernel_build));
self.builds.get_mut(dev).unwrap().kernels.insert(
kernel_name.clone(),
Arc::new(build_result.nir_kernel_builds),
);
}
// we want the same (internal) args for every compiled kernel, for now
@@ -229,7 +228,7 @@ pub struct ProgramDevBuild {
options: String,
log: String,
bin_type: cl_program_binary_type,
pub kernels: HashMap<String, Arc<NirKernelBuild>>,
pub kernels: HashMap<String, Arc<NirKernelBuilds>>,
}
fn prepare_options(options: &str, dev: &Device) -> Vec<CString> {