rusticl/icd: move get_ref()

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27376>
This commit is contained in:
Karol Herbst 2024-01-28 20:11:41 +01:00 committed by Marge Bot
parent 13241264f1
commit e63e21ac74
8 changed files with 52 additions and 46 deletions

View file

@ -24,7 +24,7 @@ use std::slice;
#[cl_info_entrypoint(cl_get_context_info)]
impl CLInfo<cl_context_info> for cl_context {
fn query(&self, q: cl_context_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let ctx = self.get_ref()?;
let ctx = Context::ref_from_raw(*self)?;
Ok(match q {
CL_CONTEXT_DEVICES => cl_prop::<Vec<cl_device_id>>(
ctx.devs
@ -166,7 +166,7 @@ fn create_context(
// Duplicate devices specified in devices are ignored.
let set: HashSet<_> =
HashSet::from_iter(unsafe { slice::from_raw_parts(devices, num_devices as usize) }.iter());
let devs: Result<_, _> = set.into_iter().map(cl_device_id::get_ref).collect();
let devs: Result<_, _> = set.into_iter().map(|&d| Device::ref_from_raw(d)).collect();
let devs: Vec<&Device> = devs?;
let gl_ctx_manager = GLCtxManager::new(gl_context, glx_display, egl_display)?;
@ -248,7 +248,7 @@ fn set_context_destructor_callback(
pfn_notify: ::std::option::Option<FuncDeleteContextCB>,
user_data: *mut ::std::os::raw::c_void,
) -> CLResult<()> {
let c = context.get_ref()?;
let c = Context::ref_from_raw(context)?;
// SAFETY: The requirements on `DeleteContextCB::new` match the requirements
// imposed by the OpenCL specification. It is the caller's duty to uphold them.

View file

@ -29,7 +29,7 @@ type ClDevIdpAccelProps = cl_device_integer_dot_product_acceleration_properties_
#[cl_info_entrypoint(cl_get_device_info)]
impl CLInfo<cl_device_info> for cl_device_id {
fn query(&self, q: cl_device_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let dev = self.get_ref()?;
let dev = Device::ref_from_raw(*self)?;
// curses you CL_DEVICE_INTEGER_DOT_PRODUCT_ACCELERATION_PROPERTIES_4x8BIT_PACKED_KHR
#[allow(non_upper_case_globals)]
@ -405,7 +405,7 @@ fn get_host_timer(device_id: cl_device_id, host_timestamp: *mut cl_ulong) -> CLR
return Err(CL_INVALID_VALUE);
}
let device = device_id.get_ref()?;
let device = Device::ref_from_raw(device_id)?;
if !device.has_timestamp {
// CL_INVALID_OPERATION if the platform associated with device does not support device and host timer synchronization

View file

@ -16,7 +16,7 @@ use std::sync::Arc;
#[cl_info_entrypoint(cl_get_event_info)]
impl CLInfo<cl_event_info> for cl_event {
fn query(&self, q: cl_event_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let event = self.get_ref()?;
let event = Event::ref_from_raw(*self)?;
Ok(match *q {
CL_EVENT_COMMAND_EXECUTION_STATUS => cl_prop::<cl_int>(event.status()),
CL_EVENT_CONTEXT => {
@ -42,7 +42,7 @@ impl CLInfo<cl_event_info> for cl_event {
#[cl_info_entrypoint(cl_get_event_profiling_info)]
impl CLInfo<cl_profiling_info> for cl_event {
fn query(&self, q: cl_profiling_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let event = self.get_ref()?;
let event = Event::ref_from_raw(*self)?;
if event.cmd_type == CL_COMMAND_USER {
// CL_PROFILING_INFO_NOT_AVAILABLE [...] if event is a user event object.
return Err(CL_PROFILING_INFO_NOT_AVAILABLE);
@ -118,7 +118,7 @@ fn set_event_callback(
pfn_event_notify: Option<FuncEventCB>,
user_data: *mut ::std::os::raw::c_void,
) -> CLResult<()> {
let e = event.get_ref()?;
let e = Event::ref_from_raw(event)?;
// CL_INVALID_VALUE [...] if command_exec_callback_type is not CL_SUBMITTED, CL_RUNNING, or CL_COMPLETE.
if ![CL_SUBMITTED, CL_RUNNING, CL_COMPLETE].contains(&(command_exec_callback_type as cl_uint)) {
@ -136,7 +136,7 @@ fn set_event_callback(
#[cl_entrypoint]
fn set_user_event_status(event: cl_event, execution_status: cl_int) -> CLResult<()> {
let e = event.get_ref()?;
let e = Event::ref_from_raw(event)?;
// CL_INVALID_VALUE if the execution_status is not CL_COMPLETE or a negative integer value.
if execution_status != CL_COMPLETE as cl_int && execution_status > 0 {

View file

@ -234,10 +234,6 @@ pub trait ReferenceCountedAPIPointer<T, const ERR: i32> {
// implement that as part of the macro where we know the real type.
fn from_ptr(ptr: *const T) -> Self;
fn get_ref(&self) -> CLResult<&T> {
unsafe { Ok(self.get_ptr()?.as_ref().unwrap()) }
}
fn get_arc(&self) -> CLResult<Arc<T>> {
unsafe {
let ptr = self.get_ptr()?;
@ -286,7 +282,16 @@ pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer<Self, ERR>
Ok(res as u32)
}
fn refs_from_arr(objs: *const CL, count: u32) -> CLResult<Vec<&'a Self>> {
/// Turns the API pointer `obj` into a shared reference to the object it points at.
///
/// The handle is first resolved through `get_ptr`; the resulting pointer is then
/// dereferenced and handed back with this trait's `'a` lifetime.
///
/// # Errors
///
/// Returns whatever error `obj.get_ptr()` reports.
// NOTE(review): nothing in this method takes a reference on the object, so the
// returned borrow presumably relies on the caller keeping the handle alive — confirm.
fn ref_from_raw(obj: CL) -> CLResult<&'a Self> {
let obj = obj.get_ptr()?;
// SAFETY: `get_ptr` already checks if it's one of our pointers and not null
Ok(unsafe { &*obj })
}
fn refs_from_arr(objs: *const CL, count: u32) -> CLResult<Vec<&'a Self>>
where
CL: Copy,
{
// CL spec requires validation for obj arrays, both values have to make sense
if objs.is_null() && count > 0 || !objs.is_null() && count == 0 {
return Err(CL_INVALID_VALUE);
@ -298,9 +303,7 @@ pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer<Self, ERR>
}
for i in 0..count as usize {
unsafe {
res.push((*objs.add(i)).get_ref()?);
}
res.push(Self::ref_from_raw(unsafe { *objs.add(i) })?);
}
Ok(res)
}

View file

@ -1,6 +1,7 @@
use crate::api::event::create_and_queue;
use crate::api::icd::*;
use crate::api::util::*;
use crate::core::device::*;
use crate::core::event::*;
use crate::core::kernel::*;
@ -20,7 +21,7 @@ use std::sync::Arc;
#[cl_info_entrypoint(cl_get_kernel_info)]
impl CLInfo<cl_kernel_info> for cl_kernel {
fn query(&self, q: cl_kernel_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let kernel = self.get_ref()?;
let kernel = Kernel::ref_from_raw(*self)?;
Ok(match q {
CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.kernel_info.attributes_string),
CL_KERNEL_CONTEXT => {
@ -43,7 +44,7 @@ impl CLInfo<cl_kernel_info> for cl_kernel {
#[cl_info_entrypoint(cl_get_kernel_arg_info)]
impl CLInfoObj<cl_kernel_arg_info, cl_uint> for cl_kernel {
fn query(&self, idx: cl_uint, q: cl_kernel_arg_info) -> CLResult<Vec<MaybeUninit<u8>>> {
let kernel = self.get_ref()?;
let kernel = Kernel::ref_from_raw(*self)?;
// CL_INVALID_ARG_INDEX if arg_index is not a valid argument index.
if idx as usize >= kernel.kernel_info.args.len() {
@ -75,7 +76,7 @@ impl CLInfoObj<cl_kernel_work_group_info, cl_device_id> for cl_kernel {
dev: cl_device_id,
q: cl_kernel_work_group_info,
) -> CLResult<Vec<MaybeUninit<u8>>> {
let kernel = self.get_ref()?;
let kernel = Kernel::ref_from_raw(*self)?;
// CL_INVALID_DEVICE [..] if device is NULL but there is more than one device associated with kernel.
let dev = if dev.is_null() {
@ -85,7 +86,7 @@ impl CLInfoObj<cl_kernel_work_group_info, cl_device_id> for cl_kernel {
kernel.prog.devs[0]
}
} else {
dev.get_ref()?
Device::ref_from_raw(dev)?
};
// CL_INVALID_DEVICE if device is not in the list of devices associated with kernel
@ -120,7 +121,7 @@ impl CLInfoObj<cl_kernel_sub_group_info, (cl_device_id, usize, *const c_void, us
),
q: cl_program_build_info,
) -> CLResult<Vec<MaybeUninit<u8>>> {
let kernel = self.get_ref()?;
let kernel = Kernel::ref_from_raw(*self)?;
// CL_INVALID_DEVICE [..] if device is NULL but there is more than one device associated
// with kernel.
@ -131,7 +132,7 @@ impl CLInfoObj<cl_kernel_sub_group_info, (cl_device_id, usize, *const c_void, us
kernel.prog.devs[0]
}
} else {
dev.get_ref()?
Device::ref_from_raw(dev)?
};
// CL_INVALID_DEVICE if device is not in the list of devices associated with kernel
@ -421,7 +422,7 @@ fn set_kernel_arg_svm_pointer(
arg_index: cl_uint,
arg_value: *const ::std::os::raw::c_void,
) -> CLResult<()> {
let kernel = kernel.get_ref()?;
let kernel = Kernel::ref_from_raw(kernel)?;
let arg_index = arg_index as usize;
let arg_value = arg_value as usize;
@ -454,7 +455,7 @@ fn set_kernel_exec_info(
param_value_size: usize,
param_value: *const ::std::os::raw::c_void,
) -> CLResult<()> {
let k = kernel.get_ref()?;
let k = Kernel::ref_from_raw(kernel)?;
// CL_INVALID_OPERATION if no devices in the context associated with kernel support SVM.
if !k.prog.devs.iter().any(|dev| dev.svm_supported()) {
@ -643,6 +644,6 @@ fn enqueue_task(
#[cl_entrypoint]
fn clone_kernel(source_kernel: cl_kernel) -> CLResult<cl_kernel> {
let k = source_kernel.get_ref()?;
let k = Kernel::ref_from_raw(source_kernel)?;
Ok(cl_kernel::from_arc(Arc::new(k.clone())))
}

View file

@ -214,7 +214,7 @@ fn validate_matching_buffer_flags(mem: &Mem, flags: cl_mem_flags) -> CLResult<()
#[cl_info_entrypoint(cl_get_mem_object_info)]
impl CLInfo<cl_mem_info> for cl_mem {
fn query(&self, q: cl_mem_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let mem = self.get_ref()?;
let mem = Mem::ref_from_raw(*self)?;
Ok(match *q {
CL_MEM_ASSOCIATED_MEMOBJECT => {
let ptr = match mem.parent.as_ref() {
@ -355,7 +355,7 @@ fn set_mem_object_destructor_callback(
pfn_notify: Option<FuncMemCB>,
user_data: *mut ::std::os::raw::c_void,
) -> CLResult<()> {
let m = memobj.get_ref()?;
let m = Mem::ref_from_raw(memobj)?;
// SAFETY: The requirements on `MemCB::new` match the requirements
// imposed by the OpenCL specification. It is the caller's duty to uphold them.
@ -596,7 +596,7 @@ fn validate_buffer(
// the specified memory objects data store are modified, those changes are reflected in the
// contents of the image object and vice-versa at corresponding synchronization points.
if !mem_object.is_null() {
let mem = mem_object.get_ref()?;
let mem = Mem::ref_from_raw(mem_object)?;
match mem.mem_type {
CL_MEM_OBJECT_BUFFER => {
@ -703,7 +703,7 @@ fn validate_buffer(
#[cl_info_entrypoint(cl_get_image_info)]
impl CLInfo<cl_image_info> for cl_mem {
fn query(&self, q: cl_image_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let mem = self.get_ref()?;
let mem = Mem::ref_from_raw(*self)?;
Ok(match *q {
CL_IMAGE_ARRAY_SIZE => cl_prop::<usize>(mem.image_desc.image_array_size),
CL_IMAGE_BUFFER => cl_prop::<cl_mem>(unsafe { mem.image_desc.anon_1.buffer }),
@ -859,7 +859,7 @@ fn get_supported_image_formats(
image_formats: *mut cl_image_format,
num_image_formats: *mut cl_uint,
) -> CLResult<()> {
let c = context.get_ref()?;
let c = Context::ref_from_raw(context)?;
// CL_INVALID_VALUE if flags
validate_mem_flags(flags, true)?;
@ -898,7 +898,7 @@ fn get_supported_image_formats(
#[cl_info_entrypoint(cl_get_sampler_info)]
impl CLInfo<cl_sampler_info> for cl_sampler {
fn query(&self, q: cl_sampler_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let sampler = self.get_ref()?;
let sampler = Sampler::ref_from_raw(*self)?;
Ok(match q {
CL_SAMPLER_ADDRESSING_MODE => cl_prop::<cl_addressing_mode>(sampler.addressing_mode),
CL_SAMPLER_CONTEXT => {
@ -2276,7 +2276,7 @@ pub fn svm_alloc(
// clSVMAlloc will fail if
// context is not a valid context
let c = context.get_ref()?;
let c = Context::ref_from_raw(context)?;
// or no devices in context support SVM.
if !c.has_svm_devs() {
@ -2337,7 +2337,7 @@ fn svm_free_impl(c: &Context, svm_pointer: *mut c_void) {
}
pub fn svm_free(context: cl_context, svm_pointer: *mut c_void) -> CLResult<()> {
let c = context.get_ref()?;
let c = Context::ref_from_raw(context)?;
svm_free_impl(c, svm_pointer);
Ok(())
}
@ -2964,7 +2964,7 @@ fn create_pipe(
#[cl_info_entrypoint(cl_get_gl_texture_info)]
impl CLInfo<cl_gl_texture_info> for cl_mem {
fn query(&self, q: cl_gl_texture_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let mem = self.get_ref()?;
let mem = Mem::ref_from_raw(*self)?;
Ok(match *q {
CL_GL_MIPMAP_LEVEL => cl_prop::<cl_GLint>(0),
CL_GL_TEXTURE_TARGET => cl_prop::<cl_GLenum>(
@ -3093,7 +3093,7 @@ fn get_gl_object_info(
gl_object_type: *mut cl_gl_object_type,
gl_object_name: *mut cl_GLuint,
) -> CLResult<()> {
let m = memobj.get_ref()?;
let m = Mem::ref_from_raw(memobj)?;
match &m.gl_obj {
Some(gl_obj) => {

View file

@ -24,7 +24,7 @@ use std::sync::Arc;
#[cl_info_entrypoint(cl_get_program_info)]
impl CLInfo<cl_program_info> for cl_program {
fn query(&self, q: cl_program_info, vals: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let prog = self.get_ref()?;
let prog = Program::ref_from_raw(*self)?;
Ok(match q {
CL_PROGRAM_BINARIES => cl_prop::<Vec<*mut u8>>(prog.binaries(vals)),
CL_PROGRAM_BINARY_SIZES => cl_prop::<Vec<usize>>(prog.bin_sizes()),
@ -62,7 +62,7 @@ impl CLInfo<cl_program_info> for cl_program {
#[cl_info_entrypoint(cl_get_program_build_info)]
impl CLInfoObj<cl_program_build_info, cl_device_id> for cl_program {
fn query(&self, d: cl_device_id, q: cl_program_build_info) -> CLResult<Vec<MaybeUninit<u8>>> {
let prog = self.get_ref()?;
let prog = Program::ref_from_raw(*self)?;
let dev = d.get_arc()?;
Ok(match q {
CL_PROGRAM_BINARY_TYPE => cl_prop::<cl_program_binary_type>(prog.bin_type(&dev)),
@ -281,7 +281,7 @@ fn build_program(
user_data: *mut ::std::os::raw::c_void,
) -> CLResult<()> {
let mut res = true;
let p = program.get_ref()?;
let p = Program::ref_from_raw(program)?;
let devs = validate_devices(device_list, num_devices, &p.devs)?;
// SAFETY: The requirements on `ProgramCB::try_new` match the requirements
@ -329,7 +329,7 @@ fn compile_program(
user_data: *mut ::std::os::raw::c_void,
) -> CLResult<()> {
let mut res = true;
let p = program.get_ref()?;
let p = Program::ref_from_raw(program)?;
let devs = validate_devices(device_list, num_devices, &p.devs)?;
// SAFETY: The requirements on `ProgramCB::try_new` match the requirements
@ -352,7 +352,7 @@ fn compile_program(
if !p.is_il() {
for h in 0..num_input_headers as usize {
// SAFETY: have to trust the application here
let header = unsafe { (*input_headers.add(h)).get_ref()? };
let header = Program::ref_from_raw(unsafe { *input_headers.add(h) })?;
match &header.src {
ProgramSourceType::Src(src) => headers.push(spirv::CLCHeader {
// SAFETY: have to trust the application here
@ -468,7 +468,7 @@ fn set_program_specialization_constant(
spec_size: usize,
spec_value: *const ::std::os::raw::c_void,
) -> CLResult<()> {
let program = program.get_ref()?;
let program = Program::ref_from_raw(program)?;
// CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate
// language (e.g. SPIR-V)

View file

@ -17,7 +17,7 @@ use std::sync::Arc;
#[cl_info_entrypoint(cl_get_command_queue_info)]
impl CLInfo<cl_command_queue_info> for cl_command_queue {
fn query(&self, q: cl_command_queue_info, _: &[u8]) -> CLResult<Vec<MaybeUninit<u8>>> {
let queue = self.get_ref()?;
let queue = Queue::ref_from_raw(*self)?;
Ok(match q {
CL_QUEUE_CONTEXT => {
// Note we use as_ptr here which doesn't increase the reference count.
@ -74,7 +74,9 @@ pub fn create_command_queue_impl(
properties_v2: Option<Properties<cl_queue_properties>>,
) -> CLResult<cl_command_queue> {
let c = context.get_arc()?;
let d = device.get_ref()?.to_static().ok_or(CL_INVALID_DEVICE)?;
let d = Device::ref_from_raw(device)?
.to_static()
.ok_or(CL_INVALID_DEVICE)?;
// CL_INVALID_DEVICE if device [...] is not associated with context.
if !c.devs.contains(&d) {
@ -206,13 +208,13 @@ fn enqueue_barrier_with_wait_list(
#[cl_entrypoint]
fn flush(command_queue: cl_command_queue) -> CLResult<()> {
// CL_INVALID_COMMAND_QUEUE if command_queue is not a valid host command-queue.
command_queue.get_ref()?.flush(false)
Queue::ref_from_raw(command_queue)?.flush(false)
}
#[cl_entrypoint]
fn finish(command_queue: cl_command_queue) -> CLResult<()> {
// CL_INVALID_COMMAND_QUEUE if command_queue is not a valid host command-queue.
command_queue.get_ref()?.flush(true)
Queue::ref_from_raw(command_queue)?.flush(true)
}
#[cl_entrypoint]