rusticl/kernel: reduce CPU overhead of set_global_binding

This gets rid of the slice conversion we had to do previously.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36917>
This commit is contained in:
Karol Herbst 2025-08-21 12:42:36 +02:00 committed by Marge Bot
parent 357299052a
commit fc6f646262
2 changed files with 12 additions and 7 deletions

View file

@ -23,6 +23,7 @@ use mesa_rust_util::serialize::*;
use rusticl_opencl_gen::*;
use spirv::SpirvKernelInfo;
use std::borrow::Borrow;
use std::cmp;
use std::collections::HashMap;
use std::collections::HashSet;
@ -1422,11 +1423,11 @@ impl Kernel {
fn add_global<'a>(
ctx: &QueueContext,
input: &mut Vec<u8>,
resource_info: &mut Vec<(&'a PipeResourceOwned, usize)>,
resource_info: &mut Vec<(&'a PipeResource, usize)>,
res: &'a PipeResourceOwned,
offset: usize,
) {
resource_info.push((res, input.len()));
resource_info.push((res.borrow(), input.len()));
add_pointer(ctx, input, offset as u64);
}
@ -1656,7 +1657,7 @@ impl Kernel {
ctx.bind_sampler_states(samplers);
ctx.bind_sampler_views(sviews);
ctx.bind_shader_images(&iviews);
ctx.set_global_binding(resources.as_slice(), &mut globals);
ctx.set_global_binding(resources.as_mut_slice(), &mut globals);
for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
for y in 0..grid[1].div_ceil(hw_max_grid[1]) {

View file

@ -517,14 +517,18 @@ impl PipeContext {
unsafe { self.pipe.as_ref().launch_grid.unwrap()(self.pipe.as_ptr(), &info) }
}
pub fn set_global_binding(&self, res: &[&PipeResourceOwned], out: &mut [*mut u32]) {
let mut res: Vec<_> = res.iter().copied().map(PipeResourceOwned::pipe).collect();
pub fn set_global_binding(&self, res: &mut [&PipeResource], out: &mut [*mut u32]) {
let len = res.len();
let res = PipeResource::slice_as_mut_ptr_slice(res);
// SAFETY: We can safely cast the *mut *const pointer to *mut *mut as drivers aren't going
// to change any of the pipe_resource fields; they are merely allowed to change
// fields of their own subclass.
unsafe {
self.pipe.as_ref().set_global_binding.unwrap()(
self.pipe.as_ptr(),
0,
res.len() as u32,
res.as_mut_ptr(),
len as u32,
res.cast(),
out.as_mut_ptr(),
)
}