rusticl/kernel: delay calculation of CSO info until kernel creation
Some checks are pending
macOS-CI / macOS-CI (dri) (push) Waiting to run
macOS-CI / macOS-CI (xlib) (push) Waiting to run

Reviewed-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37036>
This commit is contained in:
Seán de Búrca 2025-08-27 15:29:11 -07:00
parent ba292ac34a
commit 70794de792

View file

@ -39,6 +39,7 @@ use std::slice;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::MutexGuard;
use std::sync::OnceLock;
use std::sync::Weak;
// According to the CL spec we are not allowed to let any cl_kernel object hold any references on
@ -399,6 +400,23 @@ pub enum KernelDevStateVariant {
Nir(NirShader),
}
impl KernelDevStateVariant {
fn calculate_compute_state_info(
&self,
device: &'static Device,
) -> pipe_compute_state_object_info {
match self {
KernelDevStateVariant::Cso(cso) => cso.get_cso_info(),
KernelDevStateVariant::Nir(nir) => {
// SAFETY: We never execute the compute state; we simply extract
// information and discard the CSO.
let cso = unsafe { SharedCSOWrapper::new(device, nir) };
cso.get_cso_info()
}
}
}
}
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NirKernelVariant {
/// Can be used under any circumstance.
@ -421,8 +439,9 @@ impl Display for NirKernelVariant {
pub struct NirKernelBuilds {
default_build: NirKernelBuild,
optimized: Option<NirKernelBuild>,
/// merged info with worst case values
info: pipe_compute_state_object_info,
/// Compute state info with the worst-case values between available builds.
info: OnceLock<pipe_compute_state_object_info>,
}
impl Index<NirKernelVariant> for NirKernelBuilds {
@ -438,27 +457,41 @@ impl Index<NirKernelVariant> for NirKernelBuilds {
impl NirKernelBuilds {
fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
let mut info = default_build.info;
if let Some(build) = &optimized {
info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
info.simd_sizes &= build.info.simd_sizes;
info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
info.preferred_simd_size =
cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
}
Self {
default_build: default_build,
optimized: optimized,
info: info,
info: OnceLock::new(),
}
}
/// Calculates the worst-case compute state info between the available
/// builds and stores the result.
fn calculate_compute_state_info(&self, device: &'static Device) {
self.info.get_or_init(|| {
let mut info = self
.default_build
.nir_or_cso
.calculate_compute_state_info(device);
if let Some(optimized) = &self.optimized {
let optimized_info = optimized.nir_or_cso.calculate_compute_state_info(device);
// Calculate worst-case values for each parameter.
info.max_threads = cmp::min(info.max_threads, optimized_info.max_threads);
info.simd_sizes &= optimized_info.simd_sizes;
info.private_memory = cmp::max(info.private_memory, optimized_info.private_memory);
info.preferred_simd_size =
cmp::max(info.preferred_simd_size, optimized_info.preferred_simd_size);
}
info
});
}
}
pub struct NirKernelBuild {
nir_or_cso: KernelDevStateVariant,
constant_buffer: Option<PipeResourceOwned>,
info: pipe_compute_state_object_info,
shared_size: u64,
input_size: u32,
printf_info: Option<NirPrintfInfo>,
@ -472,24 +505,23 @@ unsafe impl Sync for NirKernelBuild {}
impl NirKernelBuild {
fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
// SAFETY: we only use the cso when dev supports shareable shaders, otherwise we just
// extract some info and throw it away, which is safe.
let cso = unsafe { SharedCSOWrapper::new(dev, &out.nir) };
let info = cso.get_cso_info();
let cb = Self::create_nir_constant_buffer(dev, &out.nir);
let shared_size = out.nir.shared_size() as u64;
let printf_info = out.nir.take_printf_info();
let nir_or_cso = if !dev.shareable_shaders() {
KernelDevStateVariant::Nir(out.nir)
} else {
let nir_or_cso = if dev.shareable_shaders() {
// SAFETY: The device supports shareable shaders, upholding the
// safety requirements of `SharedCSOWrapper`.
let cso = unsafe { SharedCSOWrapper::new(dev, &out.nir) };
KernelDevStateVariant::Cso(cso)
} else {
KernelDevStateVariant::Nir(out.nir)
};
NirKernelBuild {
nir_or_cso: nir_or_cso,
constant_buffer: cb,
info: info,
shared_size: shared_size,
input_size: out.input_size,
printf_info: printf_info,
@ -1349,7 +1381,16 @@ impl Kernel {
let builds = prog_build
.builds_by_device
.iter()
.filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, Arc::clone(k))))
.filter_map(|(&dev, b)| {
b.kernels.get(&name).map(|k| {
// Force initialization of the compute_state_object_info.
// This is delayed to prevent stalling the build queue
// while waiting on external processes.
k.calculate_compute_state_info(dev);
(dev, Arc::clone(k))
})
})
.collect();
let values = vec![None; kernel_info.args.len()];
@ -1866,16 +1907,20 @@ impl Kernel {
type_name.is_empty().not().then_some(type_name)
}
fn compute_state_object_info(&self, dev: &Device) -> &pipe_compute_state_object_info {
self.builds.get(dev).unwrap().info.get().unwrap()
}
pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
self.builds.get(dev).unwrap().info.private_memory as cl_ulong
self.compute_state_object_info(dev).private_memory as cl_ulong
}
pub fn max_threads_per_block(&self, dev: &Device) -> usize {
self.builds.get(dev).unwrap().info.max_threads as usize
self.compute_state_object_info(dev).max_threads as usize
}
pub fn preferred_simd_size(&self, dev: &Device) -> usize {
self.builds.get(dev).unwrap().info.preferred_simd_size as usize
self.compute_state_object_info(dev).preferred_simd_size as usize
}
pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
@ -1902,7 +1947,7 @@ impl Kernel {
}
pub fn subgroup_sizes(&self, dev: &Device) -> impl ExactSizeIterator<Item = usize> + use<> {
SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes).map(|bit| 1 << bit)
SetBitIndices::from_msb(self.compute_state_object_info(dev).simd_sizes).map(|bit| 1 << bit)
}
pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {