From 6e2ba679ff7be6ecbe935049eca7d4eab32aa1e0 Mon Sep 17 00:00:00 2001 From: LingMan <18294-LingMan@users.noreply.gitlab.freedesktop.org> Date: Thu, 12 Oct 2023 21:31:31 +0200 Subject: [PATCH] rusticl: add a safe abstraction to execute an SVMFreeCb Reviewed-by: Karol Herbst Part-of: --- src/gallium/frontends/rusticl/api/memory.rs | 13 ++++--------- src/gallium/frontends/rusticl/api/types.rs | 17 +++++++++++++++++ 2 files changed, 21 insertions(+), 9 deletions(-) diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs index dfac9ff39aa..33941eb6298 100644 --- a/src/gallium/frontends/rusticl/api/memory.rs +++ b/src/gallium/frontends/rusticl/api/memory.rs @@ -2369,7 +2369,7 @@ fn enqueue_svm_free_impl( // The application is allowed to reuse or free the memory referenced by `svm_pointers` after this // function returns so we have to make a copy. // SAFETY: num_svm_pointers specifies the amount of elements in svm_pointers - let svm_pointers = + let mut svm_pointers = unsafe { slice::from_raw_parts(svm_pointers, num_svm_pointers as usize) }.to_vec(); // SAFETY: The requirements on `SVMFreeCb::new` match the requirements // imposed by the OpenCL specification. It is the caller's duty to uphold them. @@ -2382,15 +2382,10 @@ fn enqueue_svm_free_impl( event, false, Box::new(move |q, _| { - if let Some(cb) = &cb_opt { - let mut svm_pointers = svm_pointers.clone(); - let ptr = svm_pointers.as_mut_ptr(); - // SAFETY: it's undefined behavior if the application screws up - unsafe { - (cb.func)(command_queue, num_svm_pointers, ptr, cb.data); - } + if let Some(cb) = cb_opt { + cb.call(q, &mut svm_pointers); } else { - for &ptr in &svm_pointers { + for ptr in svm_pointers { svm_free_impl(&q.context, ptr); } } diff --git a/src/gallium/frontends/rusticl/api/types.rs b/src/gallium/frontends/rusticl/api/types.rs index cad831cc264..0e8c7082cb3 100644 --- a/src/gallium/frontends/rusticl/api/types.rs +++ b/src/gallium/frontends/rusticl/api/types.rs @@ -3,6 +3,7 @@ use crate::api::icd::ReferenceCountedAPIPointer; use crate::core::context::Context; use crate::core::event::Event; use crate::core::memory::Mem; +use crate::core::queue::Queue; use rusticl_opencl_gen::*; @@ -172,6 +173,22 @@ cl_callback!( } ); +impl SVMFreeCb { + pub fn call(self, queue: &Queue, svm_pointers: &mut [*mut c_void]) { + let cl = cl_command_queue::from_ptr(queue); + // SAFETY: `cl` must be a valid pointer to an OpenCL queue, which is where we just got it from. + // All other requirements are covered by this callback's type invariants. + unsafe { + (self.func)( + cl, + svm_pointers.len() as u32, + svm_pointers.as_mut_ptr(), + self.data, + ) + }; + } +} + // a lot of APIs use 3 component vectors passed as C arrays #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub struct CLVec {