diff --git a/src/gallium/frontends/rusticl/core/device.rs b/src/gallium/frontends/rusticl/core/device.rs index 88c21e95000..b8f1a2d9dff 100644 --- a/src/gallium/frontends/rusticl/core/device.rs +++ b/src/gallium/frontends/rusticl/core/device.rs @@ -73,6 +73,9 @@ pub trait HelperContextWrapper { fn texture_map_coherent(&self, res: &PipeResource, bx: &pipe_box, rw: RWFlags) -> PipeTransfer; + fn create_compute_state(&self, nir: &NirShader, static_local_mem: u32) -> *mut c_void; + fn delete_compute_state(&self, cso: *mut c_void); + fn unmap(&self, tx: PipeTransfer); } @@ -148,6 +151,14 @@ impl<'a> HelperContextWrapper for HelperContext<'a> { .texture_map(res, bx, rw, ResourceMapType::Coherent) } + fn create_compute_state(&self, nir: &NirShader, static_local_mem: u32) -> *mut c_void { + self.lock.create_compute_state(nir, static_local_mem) + } + + fn delete_compute_state(&self, cso: *mut c_void) { + self.lock.delete_compute_state(cso) + } + fn unmap(&self, tx: PipeTransfer) { tx.with_ctx(&self.lock); } @@ -727,6 +738,10 @@ impl Device { id as u32 } + pub fn shareable_shaders(&self) -> bool { + self.screen.param(pipe_cap::PIPE_CAP_SHAREABLE_SHADERS) == 1 + } + pub fn helper_ctx(&self) -> impl HelperContextWrapper + '_ { HelperContext { lock: self.helper_ctx.lock().unwrap(), diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index c767709267b..251e773c11b 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -245,6 +245,48 @@ impl InternalKernelArg { } } +struct KernelDevStateInner { + nir: NirShader, + cso: *mut c_void, +} + +struct KernelDevState { + states: HashMap, KernelDevStateInner>, +} + +impl Drop for KernelDevState { + fn drop(&mut self) { + self.states.iter().for_each(|(dev, dev_state)| { + if !dev_state.cso.is_null() { + dev.helper_ctx().delete_compute_state(dev_state.cso); + } + }) + } +} + +impl KernelDevState { + fn new(nirs: HashMap, NirShader>) -> Arc { + let states = nirs + .into_iter() + .map(|(dev, nir)| { + let cso = if dev.shareable_shaders() { + dev.helper_ctx() + .create_compute_state(&nir, nir.shared_size()) + } else { + ptr::null_mut() + }; + (dev, KernelDevStateInner { nir: nir, cso: cso }) + }) + .collect(); + + Arc::new(Self { states: states }) + } + + fn get(&self, dev: &Device) -> &KernelDevStateInner { + self.states.get(dev).unwrap() + } +} + #[repr(C)] pub struct Kernel { pub base: CLObjectBase, @@ -255,7 +297,7 @@ pub struct Kernel { pub work_group_size: [usize; 3], pub attributes_string: String, internal_args: Vec, - nirs: HashMap, NirShader>, + dev_state: Arc, } impl_cl_type_trait!(cl_kernel, Kernel, CL_INVALID_KERNEL); @@ -790,8 +832,7 @@ impl Kernel { attributes_string: attributes_string, values: values, internal_args: internal_args, - // caller has to verify all kernels have the same sig - nirs: nirs, + dev_state: KernelDevState::new(nirs), }) } @@ -805,7 +846,7 @@ impl Kernel { grid: &[usize], offsets: &[usize], ) -> CLResult { - let nir = self.nirs.get(&q.device).unwrap(); + let nir = &self.dev_state.get(&q.device).nir; let mut block = create_kernel_arr::(block, 1); let mut grid = create_kernel_arr::(grid, 1); let offsets = create_kernel_arr::(offsets, 0); @@ -969,11 +1010,11 @@ impl Kernel { let k = Arc::clone(self); Ok(Box::new(move |q, ctx| { - let nir = k.nirs.get(&q.device).unwrap(); + let dev_state = k.dev_state.get(&q.device); let mut input = input.clone(); let mut resources = Vec::with_capacity(resource_info.len()); let mut globals: Vec<*mut u32> = Vec::new(); - let printf_format = nir.printf_format(); + let printf_format = dev_state.nir.printf_format(); let mut sviews: Vec<_> = sviews .iter() @@ -998,7 +1039,12 @@ impl Kernel { init_data.len() as u32, ); } - let cso = ctx.create_compute_state(nir, static_local_size as u32); + + let cso = if dev_state.cso.is_null() { + ctx.create_compute_state(&dev_state.nir, static_local_size as u32) + } else { + dev_state.cso + }; ctx.bind_compute_state(cso); ctx.bind_sampler_states(&samplers); @@ -1013,8 +1059,12 @@ impl Kernel { ctx.clear_shader_images(iviews.len() as u32); ctx.clear_sampler_views(sviews.len() as u32); ctx.clear_sampler_states(samplers.len() as u32); + ctx.bind_compute_state(ptr::null_mut()); - ctx.delete_compute_state(cso); + if dev_state.cso.is_null() { + ctx.delete_compute_state(cso); + } + ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER); samplers.iter().for_each(|s| ctx.delete_sampler_state(*s)); @@ -1115,12 +1165,12 @@ impl Kernel { } pub fn priv_mem_size(&self, dev: &Arc) -> cl_ulong { - self.nirs.get(dev).unwrap().scratch_size() as cl_ulong + self.dev_state.get(dev).nir.scratch_size() as cl_ulong } pub fn local_mem_size(&self, dev: &Arc) -> cl_ulong { // TODO include args - self.nirs.get(dev).unwrap().shared_size() as cl_ulong + self.dev_state.get(dev).nir.shared_size() as cl_ulong } } @@ -1135,7 +1185,7 @@ impl Clone for Kernel { work_group_size: self.work_group_size, attributes_string: self.attributes_string.clone(), internal_args: self.internal_args.clone(), - nirs: self.nirs.clone(), + dev_state: self.dev_state.clone(), } } }