diff --git a/src/gallium/frontends/rusticl/api/device.rs b/src/gallium/frontends/rusticl/api/device.rs index 5cd9d6d7ad6..38f2e1a9e1e 100644 --- a/src/gallium/frontends/rusticl/api/device.rs +++ b/src/gallium/frontends/rusticl/api/device.rs @@ -348,7 +348,9 @@ fn get_device_ids( // num_devices returns the number of OpenCL devices available that match device_type. If // num_devices is NULL, this argument is ignored. - num_devices.write_checked(devs.len() as cl_uint); + // SAFETY: Caller is responsible for providing a null pointer or one valid + // for a write of `size_of::()`. + unsafe { num_devices.write_checked(devs.len() as cl_uint) }; if !devices.is_null() { let n = min(num_entries as usize, devs.len()); @@ -429,7 +431,9 @@ fn get_host_timer(device_id: cl_device_id, host_timestamp: *mut cl_ulong) -> CLR } // Currently the best clock we have for the host_timestamp - host_timestamp.write_checked(device.screen().get_timestamp()); + // SAFETY: Caller is responsible for providing a pointer valid for a write + // of `size_of::()`. + unsafe { host_timestamp.write_checked(device.screen().get_timestamp()) }; Ok(()) } diff --git a/src/gallium/frontends/rusticl/api/icd.rs b/src/gallium/frontends/rusticl/api/icd.rs index 836be77133b..32ca834524c 100644 --- a/src/gallium/frontends/rusticl/api/icd.rs +++ b/src/gallium/frontends/rusticl/api/icd.rs @@ -549,7 +549,12 @@ extern "C" fn clLinkProgram( Err(e) => (ptr::null_mut(), e), }; - errcode_ret.write_checked(err); + // Correct behavior when `errcode_ret` is null is unspecified, but by + // analogy, we fail silently in that case. + // SAFETY: Caller is responsible for providing a pointer valid for a write + // of `size_of::()`. + unsafe { errcode_ret.write_checked(err) }; + ptr } diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index e038bc23212..ae864a67fdc 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -343,7 +343,11 @@ fn create_kernels_in_program( } num_kernels += 1; } - num_kernels_ret.write_checked(num_kernels); + + // SAFETY: Caller is responsible for providing a pointer valid for a write + // of `size_of::()`. + unsafe { num_kernels_ret.write_checked(num_kernels) }; + Ok(()) } diff --git a/src/gallium/frontends/rusticl/api/memory.rs b/src/gallium/frontends/rusticl/api/memory.rs index 5991d2a06b5..1e25eda386d 100644 --- a/src/gallium/frontends/rusticl/api/memory.rs +++ b/src/gallium/frontends/rusticl/api/memory.rs @@ -919,7 +919,9 @@ fn get_supported_image_formats( // `num_image_formats` should be the full count of supported formats, // regardless of the value of `num_entries`. It may be null, in which case // it is ignored. - num_image_formats.write_checked(res.len() as cl_uint); + // SAFETY: Callers are responsible for providing either a null pointer or + // one for which a write of `size_of::()` is valid. + unsafe { num_image_formats.write_checked(res.len() as cl_uint) }; // `image_formats` may be null, in which case it is ignored. let num_entries_to_write = cmp::min(res.len(), num_entries as usize); @@ -3155,8 +3157,12 @@ fn get_gl_object_info( match &m.gl_obj { Some(gl_obj) => { - gl_object_type.write_checked(gl_obj.gl_object_type); - gl_object_name.write_checked(gl_obj.gl_object_name); + // Either `gl_object_type` or `gl_object_name` may be null, in which + // case they are ignored. + // SAFETY: Caller is responsible for providing null pointers or ones + // which are valid for a write of the appropriate size. + unsafe { gl_object_type.write_checked(gl_obj.gl_object_type) }; + unsafe { gl_object_name.write_checked(gl_obj.gl_object_name) }; } None => { // CL_INVALID_GL_OBJECT if there is no GL object associated with memobj. diff --git a/src/gallium/frontends/rusticl/api/platform.rs b/src/gallium/frontends/rusticl/api/platform.rs index 975c093f6c8..242a8057026 100644 --- a/src/gallium/frontends/rusticl/api/platform.rs +++ b/src/gallium/frontends/rusticl/api/platform.rs @@ -57,11 +57,17 @@ fn get_platform_ids( // specific OpenCL platform. If the platforms argument is NULL, then this argument is ignored. The // number of OpenCL platforms returned is the minimum of the value specified by num_entries or the // number of OpenCL platforms available. - platforms.write_checked(Platform::get().as_ptr()); + // SAFETY: Caller is responsible for providing a null pointer or one valid + // for a write of `num_entries * size_of::()`. We are + // guaranteed to write at most one value, and if `platforms` is non-null, + // `num_entries` is guaranteed to be at least 1. + unsafe { platforms.write_checked(Platform::get().as_ptr()) }; // num_platforms returns the number of OpenCL platforms available. If num_platforms is NULL, then // this argument is ignored. - num_platforms.write_checked(1); + // SAFETY: Caller is responsible for providing a null pointer or one valid + // for a write of `size_of::()`. + unsafe { num_platforms.write_checked(1) }; Ok(()) } diff --git a/src/gallium/frontends/rusticl/util/ptr.rs b/src/gallium/frontends/rusticl/util/ptr.rs index 9af454e0720..2bb4f9fa6ff 100644 --- a/src/gallium/frontends/rusticl/util/ptr.rs +++ b/src/gallium/frontends/rusticl/util/ptr.rs @@ -63,7 +63,14 @@ pub trait CheckedPtr { /// other invariants of [`std::ptr::copy`]. unsafe fn copy_from_checked(self, src: *const T, count: usize); - fn write_checked(self, val: T); + /// Overwrites a memory location with the given value without reading or + /// dropping the old value. + /// + /// # Safety + /// + /// The nullity of `self` is checked. `self` must fulfill all other + /// invariants of [`std::ptr::write`]. + unsafe fn write_checked(self, val: T); } impl CheckedPtr for *mut T { @@ -77,10 +84,12 @@ impl CheckedPtr for *mut T { } } - fn write_checked(self, val: T) { + unsafe fn write_checked(self, val: T) { if !self.is_null() { + // SAFETY: Caller is responsible for satisfying all invariants save + // pointer nullity. unsafe { - *self = val; + self.write(val); } } }