diff --git a/src/gallium/frontends/rusticl/api/context.rs b/src/gallium/frontends/rusticl/api/context.rs index 033bfca0423..f9711274d73 100644 --- a/src/gallium/frontends/rusticl/api/context.rs +++ b/src/gallium/frontends/rusticl/api/context.rs @@ -24,7 +24,7 @@ use std::slice; #[cl_info_entrypoint(cl_get_context_info)] impl CLInfo for cl_context { fn query(&self, q: cl_context_info, _: &[u8]) -> CLResult>> { - let ctx = self.get_ref()?; + let ctx = Context::ref_from_raw(*self)?; Ok(match q { CL_CONTEXT_DEVICES => cl_prop::>( ctx.devs @@ -166,7 +166,7 @@ fn create_context( // Duplicate devices specified in devices are ignored. let set: HashSet<_> = HashSet::from_iter(unsafe { slice::from_raw_parts(devices, num_devices as usize) }.iter()); - let devs: Result<_, _> = set.into_iter().map(cl_device_id::get_ref).collect(); + let devs: Result<_, _> = set.into_iter().map(|&d| Device::ref_from_raw(d)).collect(); let devs: Vec<&Device> = devs?; let gl_ctx_manager = GLCtxManager::new(gl_context, glx_display, egl_display)?; @@ -248,7 +248,7 @@ fn set_context_destructor_callback( pfn_notify: ::std::option::Option, user_data: *mut ::std::os::raw::c_void, ) -> CLResult<()> { - let c = context.get_ref()?; + let c = Context::ref_from_raw(context)?; // SAFETY: The requirements on `DeleteContextCB::new` match the requirements // imposed by the OpenCL specification. It is the caller's duty to uphold them. diff --git a/src/gallium/frontends/rusticl/api/device.rs b/src/gallium/frontends/rusticl/api/device.rs index fbf1b4ea7e1..3bae01361a7 100644 --- a/src/gallium/frontends/rusticl/api/device.rs +++ b/src/gallium/frontends/rusticl/api/device.rs @@ -29,7 +29,7 @@ type ClDevIdpAccelProps = cl_device_integer_dot_product_acceleration_properties_ #[cl_info_entrypoint(cl_get_device_info)] impl CLInfo for cl_device_id { fn query(&self, q: cl_device_info, _: &[u8]) -> CLResult>> { - let dev = self.get_ref()?; + let dev = Device::ref_from_raw(*self)?; // curses you CL_DEVICE_INTEGER_DOT_PRODUCT_ACCELERATION_PROPERTIES_4x8BIT_PACKED_KHR #[allow(non_upper_case_globals)] @@ -405,7 +405,7 @@ fn get_host_timer(device_id: cl_device_id, host_timestamp: *mut cl_ulong) -> CLR return Err(CL_INVALID_VALUE); } - let device = device_id.get_ref()?; + let device = Device::ref_from_raw(device_id)?; if !device.has_timestamp { // CL_INVALID_OPERATION if the platform associated with device does not support device and host timer synchronization diff --git a/src/gallium/frontends/rusticl/api/event.rs b/src/gallium/frontends/rusticl/api/event.rs index 777e5085f62..19373d7aa0b 100644 --- a/src/gallium/frontends/rusticl/api/event.rs +++ b/src/gallium/frontends/rusticl/api/event.rs @@ -16,7 +16,7 @@ use std::sync::Arc; #[cl_info_entrypoint(cl_get_event_info)] impl CLInfo for cl_event { fn query(&self, q: cl_event_info, _: &[u8]) -> CLResult>> { - let event = self.get_ref()?; + let event = Event::ref_from_raw(*self)?; Ok(match *q { CL_EVENT_COMMAND_EXECUTION_STATUS => cl_prop::(event.status()), CL_EVENT_CONTEXT => { @@ -42,7 +42,7 @@ impl CLInfo for cl_event { #[cl_info_entrypoint(cl_get_event_profiling_info)] impl CLInfo for cl_event { fn query(&self, q: cl_profiling_info, _: &[u8]) -> CLResult>> { - let event = self.get_ref()?; + let event = Event::ref_from_raw(*self)?; if event.cmd_type == CL_COMMAND_USER { // CL_PROFILING_INFO_NOT_AVAILABLE [...] if event is a user event object. return Err(CL_PROFILING_INFO_NOT_AVAILABLE); @@ -118,7 +118,7 @@ fn set_event_callback( pfn_event_notify: Option, user_data: *mut ::std::os::raw::c_void, ) -> CLResult<()> { - let e = event.get_ref()?; + let e = Event::ref_from_raw(event)?; // CL_INVALID_VALUE [...] if command_exec_callback_type is not CL_SUBMITTED, CL_RUNNING, or CL_COMPLETE. if ![CL_SUBMITTED, CL_RUNNING, CL_COMPLETE].contains(&(command_exec_callback_type as cl_uint)) { @@ -136,7 +136,7 @@ fn set_event_callback( #[cl_entrypoint] fn set_user_event_status(event: cl_event, execution_status: cl_int) -> CLResult<()> { - let e = event.get_ref()?; + let e = Event::ref_from_raw(event)?; // CL_INVALID_VALUE if the execution_status is not CL_COMPLETE or a negative integer value. if execution_status != CL_COMPLETE as cl_int && execution_status > 0 { diff --git a/src/gallium/frontends/rusticl/api/icd.rs b/src/gallium/frontends/rusticl/api/icd.rs index e8a37ce600a..3bc944d213e 100644 --- a/src/gallium/frontends/rusticl/api/icd.rs +++ b/src/gallium/frontends/rusticl/api/icd.rs @@ -234,10 +234,6 @@ pub trait ReferenceCountedAPIPointer { // implement that as part of the macro where we know the real type. fn from_ptr(ptr: *const T) -> Self; - fn get_ref(&self) -> CLResult<&T> { - unsafe { Ok(self.get_ptr()?.as_ref().unwrap()) } - } - fn get_arc(&self) -> CLResult> { unsafe { let ptr = self.get_ptr()?; @@ -286,7 +282,16 @@ pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer Ok(res as u32) } - fn refs_from_arr(objs: *const CL, count: u32) -> CLResult> { + fn ref_from_raw(obj: CL) -> CLResult<&'a Self> { + let obj = obj.get_ptr()?; + // SAFETY: `get_ptr` already checks if it's one of our pointers and not null + Ok(unsafe { &*obj }) + } + + fn refs_from_arr(objs: *const CL, count: u32) -> CLResult> + where + CL: Copy, + { // CL spec requires validation for obj arrays, both values have to make sense if objs.is_null() && count > 0 || !objs.is_null() && count == 0 { return Err(CL_INVALID_VALUE); @@ -298,9 +303,7 @@ pub trait CLObject<'a, const ERR: i32, CL: ReferenceCountedAPIPointer } for i in 0..count as usize { - unsafe { - res.push((*objs.add(i)).get_ref()?); - } + res.push(Self::ref_from_raw(unsafe { *objs.add(i) })?); } Ok(res) } diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index 0635b4dc7d1..48e4b7eb65f 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -1,6 +1,7 @@ use crate::api::event::create_and_queue; use crate::api::icd::*; use crate::api::util::*; +use crate::core::device::*; use crate::core::event::*; use crate::core::kernel::*; @@ -20,7 +21,7 @@ use std::sync::Arc; #[cl_info_entrypoint(cl_get_kernel_info)] impl CLInfo for cl_kernel { fn query(&self, q: cl_kernel_info, _: &[u8]) -> CLResult>> { - let kernel = self.get_ref()?; + let kernel = Kernel::ref_from_raw(*self)?; Ok(match q { CL_KERNEL_ATTRIBUTES => cl_prop::<&str>(&kernel.kernel_info.attributes_string), CL_KERNEL_CONTEXT => { @@ -43,7 +44,7 @@ impl CLInfo for cl_kernel { #[cl_info_entrypoint(cl_get_kernel_arg_info)] impl CLInfoObj for cl_kernel { fn query(&self, idx: cl_uint, q: cl_kernel_arg_info) -> CLResult>> { - let kernel = self.get_ref()?; + let kernel = Kernel::ref_from_raw(*self)?; // CL_INVALID_ARG_INDEX if arg_index is not a valid argument index. if idx as usize >= kernel.kernel_info.args.len() { @@ -75,7 +76,7 @@ impl CLInfoObj for cl_kernel { dev: cl_device_id, q: cl_kernel_work_group_info, ) -> CLResult>> { - let kernel = self.get_ref()?; + let kernel = Kernel::ref_from_raw(*self)?; // CL_INVALID_DEVICE [..] if device is NULL but there is more than one device associated with kernel. let dev = if dev.is_null() { @@ -85,7 +86,7 @@ impl CLInfoObj for cl_kernel { kernel.prog.devs[0] } } else { - dev.get_ref()? + Device::ref_from_raw(dev)? }; // CL_INVALID_DEVICE if device is not in the list of devices associated with kernel @@ -120,7 +121,7 @@ impl CLInfoObj CLResult>> { - let kernel = self.get_ref()?; + let kernel = Kernel::ref_from_raw(*self)?; // CL_INVALID_DEVICE [..] if device is NULL but there is more than one device associated // with kernel. @@ -131,7 +132,7 @@ impl CLInfoObj CLResult<()> { - let kernel = kernel.get_ref()?; + let kernel = Kernel::ref_from_raw(kernel)?; let arg_index = arg_index as usize; let arg_value = arg_value as usize; @@ -454,7 +455,7 @@ fn set_kernel_exec_info( param_value_size: usize, param_value: *const ::std::os::raw::c_void, ) -> CLResult<()> { - let k = kernel.get_ref()?; + let k = Kernel::ref_from_raw(kernel)?; // CL_INVALID_OPERATION if no devices in the context associated with kernel support SVM. if !k.prog.devs.iter().any(|dev| dev.svm_supported()) { @@ -643,6 +644,6 @@ fn enqueue_task( #[cl_entrypoint] fn clone_kernel(source_kernel: cl_kernel) -> CLResult { - let k = source_kernel.get_ref()?; + let k = Kernel::ref_from_raw(source_kernel)?; Ok(cl_kernel::from_arc(Arc::new(k.clone()))) } diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs index a38062e7498..7f00f489aea 100644 --- a/src/gallium/frontends/rusticl/api/memory.rs +++ b/src/gallium/frontends/rusticl/api/memory.rs @@ -214,7 +214,7 @@ fn validate_matching_buffer_flags(mem: &Mem, flags: cl_mem_flags) -> CLResult<() #[cl_info_entrypoint(cl_get_mem_object_info)] impl CLInfo for cl_mem { fn query(&self, q: cl_mem_info, _: &[u8]) -> CLResult>> { - let mem = self.get_ref()?; + let mem = Mem::ref_from_raw(*self)?; Ok(match *q { CL_MEM_ASSOCIATED_MEMOBJECT => { let ptr = match mem.parent.as_ref() { @@ -355,7 +355,7 @@ fn set_mem_object_destructor_callback( pfn_notify: Option, user_data: *mut ::std::os::raw::c_void, ) -> CLResult<()> { - let m = memobj.get_ref()?; + let m = Mem::ref_from_raw(memobj)?; // SAFETY: The requirements on `MemCB::new` match the requirements // imposed by the OpenCL specification. It is the caller's duty to uphold them. @@ -596,7 +596,7 @@ fn validate_buffer( // the specified memory objects data store are modified, those changes are reflected in the // contents of the image object and vice-versa at corresponding synchronization points. if !mem_object.is_null() { - let mem = mem_object.get_ref()?; + let mem = Mem::ref_from_raw(mem_object)?; match mem.mem_type { CL_MEM_OBJECT_BUFFER => { @@ -703,7 +703,7 @@ fn validate_buffer( #[cl_info_entrypoint(cl_get_image_info)] impl CLInfo for cl_mem { fn query(&self, q: cl_image_info, _: &[u8]) -> CLResult>> { - let mem = self.get_ref()?; + let mem = Mem::ref_from_raw(*self)?; Ok(match *q { CL_IMAGE_ARRAY_SIZE => cl_prop::(mem.image_desc.image_array_size), CL_IMAGE_BUFFER => cl_prop::(unsafe { mem.image_desc.anon_1.buffer }), @@ -859,7 +859,7 @@ fn get_supported_image_formats( image_formats: *mut cl_image_format, num_image_formats: *mut cl_uint, ) -> CLResult<()> { - let c = context.get_ref()?; + let c = Context::ref_from_raw(context)?; // CL_INVALID_VALUE if flags validate_mem_flags(flags, true)?; @@ -898,7 +898,7 @@ fn get_supported_image_formats( #[cl_info_entrypoint(cl_get_sampler_info)] impl CLInfo for cl_sampler { fn query(&self, q: cl_sampler_info, _: &[u8]) -> CLResult>> { - let sampler = self.get_ref()?; + let sampler = Sampler::ref_from_raw(*self)?; Ok(match q { CL_SAMPLER_ADDRESSING_MODE => cl_prop::(sampler.addressing_mode), CL_SAMPLER_CONTEXT => { @@ -2276,7 +2276,7 @@ pub fn svm_alloc( // clSVMAlloc will fail if // context is not a valid context - let c = context.get_ref()?; + let c = Context::ref_from_raw(context)?; // or no devices in context support SVM. if !c.has_svm_devs() { @@ -2337,7 +2337,7 @@ fn svm_free_impl(c: &Context, svm_pointer: *mut c_void) { } pub fn svm_free(context: cl_context, svm_pointer: *mut c_void) -> CLResult<()> { - let c = context.get_ref()?; + let c = Context::ref_from_raw(context)?; svm_free_impl(c, svm_pointer); Ok(()) } @@ -2964,7 +2964,7 @@ fn create_pipe( #[cl_info_entrypoint(cl_get_gl_texture_info)] impl CLInfo for cl_mem { fn query(&self, q: cl_gl_texture_info, _: &[u8]) -> CLResult>> { - let mem = self.get_ref()?; + let mem = Mem::ref_from_raw(*self)?; Ok(match *q { CL_GL_MIPMAP_LEVEL => cl_prop::(0), CL_GL_TEXTURE_TARGET => cl_prop::( @@ -3093,7 +3093,7 @@ fn get_gl_object_info( gl_object_type: *mut cl_gl_object_type, gl_object_name: *mut cl_GLuint, ) -> CLResult<()> { - let m = memobj.get_ref()?; + let m = Mem::ref_from_raw(memobj)?; match &m.gl_obj { Some(gl_obj) => { diff --git a/src/gallium/frontends/rusticl/api/program.rs b/src/gallium/frontends/rusticl/api/program.rs index eb034dff381..80767ab1345 100644 --- a/src/gallium/frontends/rusticl/api/program.rs +++ b/src/gallium/frontends/rusticl/api/program.rs @@ -24,7 +24,7 @@ use std::sync::Arc; #[cl_info_entrypoint(cl_get_program_info)] impl CLInfo for cl_program { fn query(&self, q: cl_program_info, vals: &[u8]) -> CLResult>> { - let prog = self.get_ref()?; + let prog = Program::ref_from_raw(*self)?; Ok(match q { CL_PROGRAM_BINARIES => cl_prop::>(prog.binaries(vals)), CL_PROGRAM_BINARY_SIZES => cl_prop::>(prog.bin_sizes()), @@ -62,7 +62,7 @@ impl CLInfo for cl_program { #[cl_info_entrypoint(cl_get_program_build_info)] impl CLInfoObj for cl_program { fn query(&self, d: cl_device_id, q: cl_program_build_info) -> CLResult>> { - let prog = self.get_ref()?; + let prog = Program::ref_from_raw(*self)?; let dev = d.get_arc()?; Ok(match q { CL_PROGRAM_BINARY_TYPE => cl_prop::(prog.bin_type(&dev)), @@ -281,7 +281,7 @@ fn build_program( user_data: *mut ::std::os::raw::c_void, ) -> CLResult<()> { let mut res = true; - let p = program.get_ref()?; + let p = Program::ref_from_raw(program)?; let devs = validate_devices(device_list, num_devices, &p.devs)?; // SAFETY: The requirements on `ProgramCB::try_new` match the requirements @@ -329,7 +329,7 @@ fn compile_program( user_data: *mut ::std::os::raw::c_void, ) -> CLResult<()> { let mut res = true; - let p = program.get_ref()?; + let p = Program::ref_from_raw(program)?; let devs = validate_devices(device_list, num_devices, &p.devs)?; // SAFETY: The requirements on `ProgramCB::try_new` match the requirements @@ -352,7 +352,7 @@ fn compile_program( if !p.is_il() { for h in 0..num_input_headers as usize { // SAFETY: have to trust the application here - let header = unsafe { (*input_headers.add(h)).get_ref()? }; + let header = Program::ref_from_raw(unsafe { *input_headers.add(h) })?; match &header.src { ProgramSourceType::Src(src) => headers.push(spirv::CLCHeader { // SAFETY: have to trust the application here @@ -468,7 +468,7 @@ fn set_program_specialization_constant( spec_size: usize, spec_value: *const ::std::os::raw::c_void, ) -> CLResult<()> { - let program = program.get_ref()?; + let program = Program::ref_from_raw(program)?; // CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate // language (e.g. SPIR-V) diff --git a/src/gallium/frontends/rusticl/api/queue.rs b/src/gallium/frontends/rusticl/api/queue.rs index 9625b8777a6..d6dc637ff11 100644 --- a/src/gallium/frontends/rusticl/api/queue.rs +++ b/src/gallium/frontends/rusticl/api/queue.rs @@ -17,7 +17,7 @@ use std::sync::Arc; #[cl_info_entrypoint(cl_get_command_queue_info)] impl CLInfo for cl_command_queue { fn query(&self, q: cl_command_queue_info, _: &[u8]) -> CLResult>> { - let queue = self.get_ref()?; + let queue = Queue::ref_from_raw(*self)?; Ok(match q { CL_QUEUE_CONTEXT => { // Note we use as_ptr here which doesn't increase the reference count. @@ -74,7 +74,9 @@ pub fn create_command_queue_impl( properties_v2: Option>, ) -> CLResult { let c = context.get_arc()?; - let d = device.get_ref()?.to_static().ok_or(CL_INVALID_DEVICE)?; + let d = Device::ref_from_raw(device)? + .to_static() + .ok_or(CL_INVALID_DEVICE)?; // CL_INVALID_DEVICE if device [...] is not associated with context. if !c.devs.contains(&d) { @@ -206,13 +208,13 @@ fn enqueue_barrier_with_wait_list( #[cl_entrypoint] fn flush(command_queue: cl_command_queue) -> CLResult<()> { // CL_INVALID_COMMAND_QUEUE if command_queue is not a valid host command-queue. - command_queue.get_ref()?.flush(false) + Queue::ref_from_raw(command_queue)?.flush(false) } #[cl_entrypoint] fn finish(command_queue: cl_command_queue) -> CLResult<()> { // CL_INVALID_COMMAND_QUEUE if command_queue is not a valid host command-queue. - command_queue.get_ref()?.flush(true) + Queue::ref_from_raw(command_queue)?.flush(true) } #[cl_entrypoint]