From 70794de7929fbe2f9bfffd4e686ee39d19c0225d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Se=C3=A1n=20de=20B=C3=BArca?=
Date: Wed, 27 Aug 2025 15:29:11 -0700
Subject: [PATCH] rusticl/kernel: delay calculation of CSO info until kernel
 creation

Reviewed-by: Karol Herbst
Part-of:
---
 src/gallium/frontends/rusticl/core/kernel.rs | 97 ++++++++++++++------
 1 file changed, 71 insertions(+), 26 deletions(-)

diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs
index c3633cc77c3..0f858e6ad1e 100644
--- a/src/gallium/frontends/rusticl/core/kernel.rs
+++ b/src/gallium/frontends/rusticl/core/kernel.rs
@@ -39,6 +39,7 @@ use std::slice;
 use std::sync::Arc;
 use std::sync::Mutex;
 use std::sync::MutexGuard;
+use std::sync::OnceLock;
 use std::sync::Weak;
 
 // According to the CL spec we are not allowed to let any cl_kernel object hold any references on
@@ -399,6 +400,23 @@ pub enum KernelDevStateVariant {
     Nir(NirShader),
 }
 
+impl KernelDevStateVariant {
+    fn calculate_compute_state_info(
+        &self,
+        device: &'static Device,
+    ) -> pipe_compute_state_object_info {
+        match self {
+            KernelDevStateVariant::Cso(cso) => cso.get_cso_info(),
+            KernelDevStateVariant::Nir(nir) => {
+                // SAFETY: We never execute the compute state; we simply extract
+                // information and discard the CSO.
+                let cso = unsafe { SharedCSOWrapper::new(device, nir) };
+                cso.get_cso_info()
+            }
+        }
+    }
+}
+
 #[derive(Clone, Copy, Debug, PartialEq)]
 pub enum NirKernelVariant {
     /// Can be used under any circumstance.
@@ -421,8 +439,9 @@ impl Display for NirKernelVariant {
 pub struct NirKernelBuilds {
     default_build: NirKernelBuild,
     optimized: Option<NirKernelBuild>,
-    /// merged info with worst case values
-    info: pipe_compute_state_object_info,
+
+    /// Compute state info with the worst-case values between available builds.
+    info: OnceLock<pipe_compute_state_object_info>,
 }
 
 impl Index<NirKernelVariant> for NirKernelBuilds {
@@ -438,27 +457,41 @@
 
 impl NirKernelBuilds {
     fn new(default_build: NirKernelBuild, optimized: Option<NirKernelBuild>) -> Self {
-        let mut info = default_build.info;
-        if let Some(build) = &optimized {
-            info.max_threads = cmp::min(info.max_threads, build.info.max_threads);
-            info.simd_sizes &= build.info.simd_sizes;
-            info.private_memory = cmp::max(info.private_memory, build.info.private_memory);
-            info.preferred_simd_size =
-                cmp::max(info.preferred_simd_size, build.info.preferred_simd_size);
-        }
-
         Self {
             default_build: default_build,
             optimized: optimized,
-            info: info,
+            info: OnceLock::new(),
         }
     }
+
+    /// Calculates the worst-case compute state info between the available
+    /// builds and stores the result.
+    fn calculate_compute_state_info(&self, device: &'static Device) {
+        self.info.get_or_init(|| {
+            let mut info = self
+                .default_build
+                .nir_or_cso
+                .calculate_compute_state_info(device);
+
+            if let Some(optimized) = &self.optimized {
+                let optimized_info = optimized.nir_or_cso.calculate_compute_state_info(device);
+
+                // Calculate worst-case values for each parameter.
+                info.max_threads = cmp::min(info.max_threads, optimized_info.max_threads);
+                info.simd_sizes &= optimized_info.simd_sizes;
+                info.private_memory = cmp::max(info.private_memory, optimized_info.private_memory);
+                info.preferred_simd_size =
+                    cmp::max(info.preferred_simd_size, optimized_info.preferred_simd_size);
+            }
+
+            info
+        });
+    }
 }
 
 pub struct NirKernelBuild {
     nir_or_cso: KernelDevStateVariant,
     constant_buffer: Option<Arc<PipeResource>>,
-    info: pipe_compute_state_object_info,
     shared_size: u64,
     input_size: u32,
     printf_info: Option<NirPrintfInfo>,
@@ -472,24 +505,23 @@ unsafe impl Sync for NirKernelBuild {}
 
 impl NirKernelBuild {
     fn new(dev: &'static Device, mut out: CompilationResult) -> Self {
-        // SAFETY: we only use the cso when dev supports shareable shaders, otherwise we just
-        // extract some info and throw it away, which is safe.
-        let cso = unsafe { SharedCSOWrapper::new(dev, &out.nir) };
-        let info = cso.get_cso_info();
         let cb = Self::create_nir_constant_buffer(dev, &out.nir);
         let shared_size = out.nir.shared_size() as u64;
         let printf_info = out.nir.take_printf_info();
 
-        let nir_or_cso = if !dev.shareable_shaders() {
-            KernelDevStateVariant::Nir(out.nir)
-        } else {
+        let nir_or_cso = if dev.shareable_shaders() {
+            // SAFETY: The device supports shareable shaders, upholding the
+            // safety requirements of `SharedCSOWrapper`.
+            let cso = unsafe { SharedCSOWrapper::new(dev, &out.nir) };
+
             KernelDevStateVariant::Cso(cso)
+        } else {
+            KernelDevStateVariant::Nir(out.nir)
         };
 
         NirKernelBuild {
             nir_or_cso: nir_or_cso,
             constant_buffer: cb,
-            info: info,
             shared_size: shared_size,
             input_size: out.input_size,
             printf_info: printf_info,
@@ -1349,7 +1381,16 @@ impl Kernel {
         let builds = prog_build
             .builds_by_device
             .iter()
-            .filter_map(|(&dev, b)| b.kernels.get(&name).map(|k| (dev, Arc::clone(k))))
+            .filter_map(|(&dev, b)| {
+                b.kernels.get(&name).map(|k| {
+                    // Force initialization of the compute_state_object_info.
+                    // This is delayed to prevent stalling the build queue
+                    // while waiting on external processes.
+                    k.calculate_compute_state_info(dev);
+
+                    (dev, Arc::clone(k))
+                })
+            })
             .collect();
         let values = vec![None; kernel_info.args.len()];
 
@@ -1866,16 +1907,20 @@ impl Kernel {
         type_name.is_empty().not().then_some(type_name)
     }
 
+    fn compute_state_object_info(&self, dev: &Device) -> &pipe_compute_state_object_info {
+        self.builds.get(dev).unwrap().info.get().unwrap()
+    }
+
     pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {
-        self.builds.get(dev).unwrap().info.private_memory as cl_ulong
+        self.compute_state_object_info(dev).private_memory as cl_ulong
     }
 
     pub fn max_threads_per_block(&self, dev: &Device) -> usize {
-        self.builds.get(dev).unwrap().info.max_threads as usize
+        self.compute_state_object_info(dev).max_threads as usize
     }
 
     pub fn preferred_simd_size(&self, dev: &Device) -> usize {
-        self.builds.get(dev).unwrap().info.preferred_simd_size as usize
+        self.compute_state_object_info(dev).preferred_simd_size as usize
     }
 
     pub fn local_mem_size(&self, dev: &Device) -> cl_ulong {
@@ -1902,7 +1947,7 @@ impl Kernel {
     }
 
     pub fn subgroup_sizes(&self, dev: &Device) -> impl ExactSizeIterator<Item = usize> + use<> {
-        SetBitIndices::from_msb(self.builds.get(dev).unwrap().info.simd_sizes).map(|bit| 1 << bit)
+        SetBitIndices::from_msb(self.compute_state_object_info(dev).simd_sizes).map(|bit| 1 << bit)
     }
 
     pub fn subgroups_for_block(&self, dev: &Device, block: &[usize]) -> usize {
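The core change above swaps an eagerly computed pipe_compute_state_object_info
for a lazily initialized OnceLock. For readers unfamiliar with that pattern,
below is a minimal standalone sketch of the get_or_init flow; the Info and
Builds types are hypothetical stand-ins for pipe_compute_state_object_info and
NirKernelBuilds, not the real rusticl definitions. The closure passed to
get_or_init runs at most once, on the first call, and every later call returns
the cached reference, which is what lets the patch move CSO creation out of the
build queue and defer it until kernel creation.

use std::sync::OnceLock;

// Hypothetical stand-ins for pipe_compute_state_object_info and
// NirKernelBuilds; the real fields live in rusticl's kernel.rs.
#[derive(Clone, Copy, Debug)]
struct Info {
    max_threads: u32,
    private_memory: u32,
}

struct Builds {
    default_info: Info,
    optimized_info: Option<Info>,
    // Lazily initialized merged info; written at most once.
    info: OnceLock<Info>,
}

impl Builds {
    // Mirrors the patch's pattern: the closure runs only on the first
    // call, and every later call returns the cached reference.
    fn calculate_info(&self) -> &Info {
        self.info.get_or_init(|| {
            let mut info = self.default_info;
            if let Some(opt) = self.optimized_info {
                // Worst-case merge, as in NirKernelBuilds: the smallest
                // thread cap and the largest private-memory footprint win.
                info.max_threads = info.max_threads.min(opt.max_threads);
                info.private_memory = info.private_memory.max(opt.private_memory);
            }
            info
        })
    }
}

fn main() {
    let builds = Builds {
        default_info: Info { max_threads: 1024, private_memory: 64 },
        optimized_info: Some(Info { max_threads: 256, private_memory: 128 }),
        info: OnceLock::new(),
    };
    // The first call computes and caches; the second reuses the cache.
    println!("{:?}", builds.calculate_info());
    println!("{:?}", builds.calculate_info());
}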
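The subgroup_sizes() accessor decodes simd_sizes, a bitmask in which a set
bit n means a SIMD width of 2^n is supported. The loop below is a hedged
stand-in for rusticl's SetBitIndices::from_msb(mask).map(|bit| 1 << bit)
chain, assuming a u32 mask and most-significant-bit-first order; it is
illustrative only, not the library implementation.

// Illustrative decoding of the simd_sizes bitmask used by subgroup_sizes().
fn subgroup_sizes(mut simd_sizes: u32) -> Vec<u32> {
    let mut sizes = Vec::new();
    while simd_sizes != 0 {
        let bit = 31 - simd_sizes.leading_zeros(); // index of highest set bit
        sizes.push(1 << bit);                      // bit index -> SIMD width
        simd_sizes &= !(1 << bit);                 // clear it and continue
    }
    sizes
}

fn main() {
    // Bits 3, 4 and 5 set: widths 32, 16 and 8, largest first.
    assert_eq!(subgroup_sizes(0b11_1000), vec![32, 16, 8]);
}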