rusticl: add a safe abstraction to execute an SVMFreeCb

Reviewed-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/25669>
This commit is contained in:
LingMan 2023-10-12 21:31:31 +02:00 committed by Marge Bot
parent 8b1d73ff23
commit 6e2ba679ff
2 changed files with 21 additions and 9 deletions

View file

@ -2369,7 +2369,7 @@ fn enqueue_svm_free_impl(
// The application is allowed to reuse or free the memory referenced by `svm_pointers` after this
// function returns so we have to make a copy.
// SAFETY: num_svm_pointers specifies the amount of elements in svm_pointers
let svm_pointers =
let mut svm_pointers =
unsafe { slice::from_raw_parts(svm_pointers, num_svm_pointers as usize) }.to_vec();
// SAFETY: The requirements on `SVMFreeCb::new` match the requirements
// imposed by the OpenCL specification. It is the caller's duty to uphold them.
@ -2382,15 +2382,10 @@ fn enqueue_svm_free_impl(
event,
false,
Box::new(move |q, _| {
if let Some(cb) = &cb_opt {
let mut svm_pointers = svm_pointers.clone();
let ptr = svm_pointers.as_mut_ptr();
// SAFETY: it's undefined behavior if the application screws up
unsafe {
(cb.func)(command_queue, num_svm_pointers, ptr, cb.data);
}
if let Some(cb) = cb_opt {
cb.call(q, &mut svm_pointers);
} else {
for &ptr in &svm_pointers {
for ptr in svm_pointers {
svm_free_impl(&q.context, ptr);
}
}

View file

@ -3,6 +3,7 @@ use crate::api::icd::ReferenceCountedAPIPointer;
use crate::core::context::Context;
use crate::core::event::Event;
use crate::core::memory::Mem;
use crate::core::queue::Queue;
use rusticl_opencl_gen::*;
@ -172,6 +173,22 @@ cl_callback!(
}
);
impl SVMFreeCb {
pub fn call(self, queue: &Queue, svm_pointers: &mut [*mut c_void]) {
let cl = cl_command_queue::from_ptr(queue);
// SAFETY: `cl` must be a valid pointer to an OpenCL queue, which is where we just got it from.
// All other requirements are covered by this callback's type invariants.
unsafe {
(self.func)(
cl,
svm_pointers.len() as u32,
svm_pointers.as_mut_ptr(),
self.data,
)
};
}
}
// a lot of APIs use 3 component vectors passed as C arrays
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct CLVec<T> {