rusticl/kernel: rework validation in clSetKernelExecInfo

We should use the cl_slice code to get proper validation, which also makes
it simpler to read out data and gets rid of some UB there.

This also fixes CL_KERNEL_EXEC_INFO_SVM_PTRS with param_value being null.

Cc: mesa-stable
Reviewed-by: Adam Jackson <ajax@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32942>
This commit is contained in:
Karol Herbst 2025-01-19 10:20:46 +01:00 committed by Marge Bot
parent c5411351ad
commit 35a9829391
2 changed files with 33 additions and 9 deletions

View file

@ -504,22 +504,29 @@ fn set_kernel_exec_info(
return Err(CL_INVALID_OPERATION);
}
// CL_INVALID_VALUE ... if param_value is NULL
if param_value.is_null() {
return Err(CL_INVALID_VALUE);
}
// CL_INVALID_VALUE ... if the size specified by param_value_size is not valid.
match param_name {
CL_KERNEL_EXEC_INFO_SVM_PTRS | CL_KERNEL_EXEC_INFO_SVM_PTRS_ARM => {
// it's a list of pointers
if param_value_size % mem::size_of::<*const c_void>() != 0 {
return Err(CL_INVALID_VALUE);
// To specify that no SVM allocations will be accessed by a kernel other than those set
// as kernel arguments, specify an empty set by passing param_value_size equal to zero
// and param_value equal to NULL.
if !param_value.is_null() || param_value_size != 0 {
let _ = unsafe {
cl_slice::from_raw_parts_bytes_len::<*const c_void>(
param_value,
param_value_size,
)?
};
}
}
CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM
| CL_KERNEL_EXEC_INFO_SVM_FINE_GRAIN_SYSTEM_ARM => {
if param_value_size != mem::size_of::<cl_bool>() {
let val = unsafe {
cl_slice::from_raw_parts_bytes_len::<cl_bool>(param_value, param_value_size)?
};
// we must explicitly check that we only got one element
if val.len() != 1 {
return Err(CL_INVALID_VALUE);
}
}

View file

@ -541,6 +541,7 @@ pub mod cl_slice {
use crate::api::util::CLResult;
use mesa_rust_util::ptr::addr;
use rusticl_opencl_gen::CL_INVALID_VALUE;
use std::ffi::c_void;
use std::mem;
use std::slice;
@ -567,6 +568,22 @@ pub mod cl_slice {
unsafe { Ok(slice::from_raw_parts(data, len)) }
}
/// same as [self::from_raw_parts] just that `len` is provided in bytes and must be a multiple
/// of Ts size.
#[inline]
pub unsafe fn from_raw_parts_bytes_len<'a, T>(
data: *const c_void,
len: usize,
) -> CLResult<&'a [T]> {
let size = mem::size_of::<T>();
if len % size != 0 {
return Err(CL_INVALID_VALUE);
}
let len = len / size;
unsafe { self::from_raw_parts(data.cast(), len) }
}
/// Wrapper around [`std::slice::from_raw_parts_mut`] that returns `Err(CL_INVALID_VALUE)` if any of these conditions is met:
/// - `data` is null
/// - `data` is not correctly aligned for `T`