rusticl/queue: cache bound CSO

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33775>
This commit is contained in:
Karol Herbst 2025-03-04 00:15:44 +01:00 committed by Marge Bot
parent 00e3d75a58
commit 7d94fe8c5f
3 changed files with 110 additions and 19 deletions

View file

@ -10,6 +10,7 @@ use crate::impl_cl_type_trait;
use mesa_rust::compiler::clc::*;
use mesa_rust::compiler::nir::*;
use mesa_rust::nir_pass;
use mesa_rust::pipe::context::PipeContext;
use mesa_rust::pipe::context::RWFlags;
use mesa_rust::pipe::resource::*;
use mesa_rust::pipe::screen::ResourceType;
@ -343,7 +344,7 @@ pub struct KernelInfo {
}
/// Wraps around a compute state object which is safe to share between pipe_contexts.
struct SharedCSOWrapper {
pub struct SharedCSOWrapper {
cso_ptr: *mut c_void,
dev: &'static Device,
}
@ -364,6 +365,16 @@ impl SharedCSOWrapper {
}
}
/// Binds this CSO as the active compute state on `ctx`.
///
/// # Safety
///
/// `self` needs to live until another CSOWrapper is bound to `ctx`, because the
/// driver keeps using the raw `cso_ptr` until a new state is bound.
pub unsafe fn bind_to_ctx(&self, ctx: &PipeContext) {
// SAFETY: We make it the caller's responsibility to uphold the safety requirements.
unsafe {
ctx.bind_compute_state(self.cso_ptr);
}
}
/// Queries the driver for information about this compute state object, using
/// the device's helper context rather than any queue context.
fn get_cso_info(&self) -> pipe_compute_state_object_info {
self.dev.helper_ctx().compute_state_info(self.cso_ptr)
}
@ -375,13 +386,13 @@ impl Drop for SharedCSOWrapper {
}
}
enum KernelDevStateVariant {
pub enum KernelDevStateVariant {
Cso(SharedCSOWrapper),
Nir(NirShader),
}
#[derive(Debug, PartialEq)]
enum NirKernelVariant {
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum NirKernelVariant {
/// Can be used under any circumstance.
Default,
@ -436,7 +447,7 @@ impl NirKernelBuilds {
}
}
struct NirKernelBuild {
pub struct NirKernelBuild {
nir_or_cso: KernelDevStateVariant,
constant_buffer: Option<Arc<PipeResource>>,
info: pipe_compute_state_object_info,
@ -496,6 +507,10 @@ impl NirKernelBuild {
None
}
}
/// Borrows the wrapped compute state: either a shareable CSO or the NIR shader
/// it still needs to be compiled from.
pub fn nir_or_cso(&self) -> &KernelDevStateVariant {
&self.nir_or_cso
}
}
pub struct Kernel {
@ -1558,18 +1573,8 @@ impl Kernel {
globals.push(unsafe { input.as_mut_ptr().byte_add(offset) }.cast());
}
let temp_cso;
let cso = match &nir_kernel_build.nir_or_cso {
KernelDevStateVariant::Cso(cso) => cso,
KernelDevStateVariant::Nir(nir) => {
// SAFETY: this isn't safe at all, but we'll fix this in a later commit.
temp_cso = unsafe { SharedCSOWrapper::new(q.device, nir) };
&temp_cso
}
};
let sviews_len = sviews.len();
ctx.bind_compute_state(cso.cso_ptr);
ctx.bind_kernel(&nir_kernel_builds, variant)?;
ctx.bind_sampler_states(&samplers);
ctx.set_sampler_views(sviews);
ctx.set_shader_images(&iviews);
@ -1613,8 +1618,6 @@ impl Kernel {
ctx.clear_sampler_views(sviews_len as u32);
ctx.clear_sampler_states(samplers.len() as u32);
ctx.bind_compute_state(ptr::null_mut());
ctx.memory_barrier(PIPE_BARRIER_GLOBAL_BUFFER);
samplers.iter().for_each(|s| ctx.delete_sampler_state(*s));

View file

@ -2,18 +2,24 @@ use crate::api::icd::*;
use crate::core::context::*;
use crate::core::device::*;
use crate::core::event::*;
use crate::core::kernel::*;
use crate::core::platform::*;
use crate::impl_cl_type_trait;
use mesa_rust::compiler::nir::NirShader;
use mesa_rust::pipe::context::PipeContext;
use mesa_rust_gen::*;
use mesa_rust_util::properties::*;
use rusticl_opencl_gen::*;
use std::cell::RefCell;
use std::cmp;
use std::ffi::c_void;
use std::mem;
use std::mem::ManuallyDrop;
use std::ops::Deref;
use std::ptr;
use std::ptr::NonNull;
use std::sync::mpsc;
use std::sync::Arc;
use std::sync::Mutex;
@ -21,6 +27,32 @@ use std::sync::Weak;
use std::thread;
use std::thread::JoinHandle;
/// A context-local compute state object: created on a specific [PipeContext]
/// and deleted on that same context when dropped.
struct CSOWrapper<'a> {
// Context the CSO was created on; also used to delete it in `Drop`.
ctx: &'a PipeContext,
// Non-null driver handle to the compute state object.
cso: NonNull<c_void>,
}
impl<'a> CSOWrapper<'a> {
/// Compiles `nir` into a compute state object on the queue's context.
///
/// Returns `None` when the driver fails to create the state object, i.e.
/// `create_compute_state` returns a null pointer.
fn new(ctx: &QueueContext<'a>, nir: &NirShader) -> Option<CSOWrapper<'a>> {
Some(Self {
ctx: ctx.ctx,
cso: NonNull::new(ctx.create_compute_state(nir, nir.shared_size()))?,
})
}
}
impl Drop for CSOWrapper<'_> {
fn drop(&mut self) {
// Destroy the CSO on the same context it was created on.
self.ctx.delete_compute_state(self.cso.as_ptr());
}
}
/// Tracks the kernel compute state currently bound to a [QueueContext] so
/// repeated launches of the same kernel can skip redundant rebinds.
struct QueueKernelState<'a> {
// Builds of the most recently bound kernel; `None` until the first bind.
builds: Option<Arc<NirKernelBuilds>>,
// Variant within `builds` that is currently bound.
variant: NirKernelVariant,
// Context-local CSO created from NIR; must stay alive while it is bound.
cso: Option<CSOWrapper<'a>>,
}
/// State tracking wrapper for [PipeContext]
///
/// Used for tracking bound GPU state to lower CPU overhead and centralize state tracking
@ -28,9 +60,50 @@ pub struct QueueContext<'a> {
ctx: &'a PipeContext,
pub dev: &'static Device,
use_stream: bool,
kernel_state: RefCell<QueueKernelState<'a>>,
}
impl QueueContext<'_> {
// TODO: figure out how to make it &mut self without causing tons of borrowing issues.
/// Binds the compute state for `variant` of `builds` to this context.
///
/// The binding is cached in `kernel_state`: if the same builds/variant pair is
/// already bound, this is a no-op.
///
/// # Errors
///
/// Returns `CL_OUT_OF_HOST_MEMORY` when creating a CSO from the NIR fails.
// TODO: figure out how to make it &mut self without causing tons of borrowing issues.
pub fn bind_kernel(
&self,
builds: &Arc<NirKernelBuilds>,
variant: NirKernelVariant,
) -> CLResult<()> {
// this should never panic, but you never know.
let mut state = self.kernel_state.borrow_mut();
// If we already set the CSO then we don't have to bind again.
if let Some(stored_builds) = &state.builds {
if Arc::ptr_eq(stored_builds, builds) && state.variant == variant {
return Ok(());
}
}
let nir_kernel_build = &builds[variant];
match nir_kernel_build.nir_or_cso() {
// SAFETY: We keep the cso alive until a new one is set.
KernelDevStateVariant::Cso(cso) => unsafe {
cso.bind_to_ctx(self);
},
// TODO: We could cache the cso here.
KernelDevStateVariant::Nir(nir) => {
let cso = CSOWrapper::new(self, nir).ok_or(CL_OUT_OF_HOST_MEMORY)?;
// SAFETY: `cso` is stored in `state.cso` right below and is only
// dropped once a new compute state gets bound in its place.
unsafe {
self.bind_compute_state(cso.cso.as_ptr());
}
state.cso.replace(cso);
}
};
// We can only store the new builds after we bound the new cso, otherwise we might drop it
// too early.
state.builds = Some(Arc::clone(builds));
state.variant = variant;
Ok(())
}
pub fn update_cb0(&self, data: &[u8]) -> CLResult<()> {
// only update if we actually bind data
if !data.is_empty() {
@ -58,6 +131,13 @@ impl Deref for QueueContext<'_> {
impl Drop for QueueContext<'_> {
fn drop(&mut self) {
// Reset constant buffer 0 so the context no longer references our data.
self.set_constant_buffer(0, &[]);
// Only unbind if a kernel was ever bound; otherwise there is nothing to do.
if self.kernel_state.get_mut().builds.is_some() {
// SAFETY: We simply unbind here. The bound cso will only be dropped at the end of this
// drop handler.
unsafe {
self.ctx.bind_compute_state(ptr::null_mut());
}
}
}
}
@ -81,6 +161,11 @@ impl SendableQueueContext {
fn ctx(&self) -> QueueContext {
QueueContext {
ctx: &self.ctx,
kernel_state: RefCell::new(QueueKernelState {
builds: None,
variant: NirKernelVariant::Default,
cso: None,
}),
dev: self.dev,
use_stream: self.dev.prefers_real_buffer_in_cb0(),
}

View file

@ -309,7 +309,10 @@ impl PipeContext {
unsafe { self.pipe.as_ref().create_compute_state.unwrap()(self.pipe.as_ptr(), &state) }
}
pub fn bind_compute_state(&self, state: *mut c_void) {
/// Binds `state` as the active compute state of this context.
///
/// # Safety
///
/// The state pointer needs to point to valid memory until a new one is set.
pub unsafe fn bind_compute_state(&self, state: *mut c_void) {
// SAFETY: the caller upholds the lifetime requirement documented above.
unsafe { self.pipe.as_ref().bind_compute_state.unwrap()(self.pipe.as_ptr(), state) }
}