diff --git a/src/gallium/frontends/rusticl/api/kernel.rs b/src/gallium/frontends/rusticl/api/kernel.rs index eda3e498f49..9e86859fbe7 100644 --- a/src/gallium/frontends/rusticl/api/kernel.rs +++ b/src/gallium/frontends/rusticl/api/kernel.rs @@ -15,6 +15,7 @@ use rusticl_proc_macros::cl_entrypoint; use rusticl_proc_macros::cl_info_entrypoint; use std::cmp; +use std::ffi::CStr; use std::mem::{self, MaybeUninit}; use std::os::raw::c_void; use std::ptr; @@ -61,8 +62,16 @@ impl CLInfoObj for cl_kernel { CL_KERNEL_ARG_ADDRESS_QUALIFIER => { cl_prop::(kernel.address_qualifier(idx)) } - CL_KERNEL_ARG_NAME => cl_prop::<&str>(kernel.arg_name(idx)), - CL_KERNEL_ARG_TYPE_NAME => cl_prop::<&str>(kernel.arg_type_name(idx)), + CL_KERNEL_ARG_NAME => cl_prop::<&CStr>( + kernel + .arg_name(idx) + .ok_or(CL_KERNEL_ARG_INFO_NOT_AVAILABLE)?, + ), + CL_KERNEL_ARG_TYPE_NAME => cl_prop::<&CStr>( + kernel + .arg_type_name(idx) + .ok_or(CL_KERNEL_ARG_INFO_NOT_AVAILABLE)?, + ), CL_KERNEL_ARG_TYPE_QUALIFIER => { cl_prop::(kernel.type_qualifier(idx)) } diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index c9cfdfdebda..5bb3a4d7d83 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -22,6 +22,7 @@ use spirv::SpirvKernelInfo; use std::cmp; use std::collections::HashMap; use std::convert::TryInto; +use std::ffi::CStr; use std::fmt::Debug; use std::fmt::Display; use std::ops::Index; @@ -1692,12 +1693,14 @@ impl Kernel { self.kernel_info.subgroup_size } - pub fn arg_name(&self, idx: cl_uint) -> &String { - &self.kernel_info.args[idx as usize].spirv.name + pub fn arg_name(&self, idx: cl_uint) -> Option<&CStr> { + let name = &self.kernel_info.args[idx as usize].spirv.name; + name.is_empty().not().then_some(name) } - pub fn arg_type_name(&self, idx: cl_uint) -> &String { - &self.kernel_info.args[idx as usize].spirv.type_name + pub fn arg_type_name(&self, idx: cl_uint) -> Option<&CStr> { + let type_name = &self.kernel_info.args[idx as usize].spirv.type_name; + type_name.is_empty().not().then_some(type_name) } pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong { diff --git a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs index b7afd0a69c1..a204fb51cb8 100644 --- a/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs +++ b/src/gallium/frontends/rusticl/mesa/compiler/clc/spirv.rs @@ -33,8 +33,8 @@ unsafe impl Sync for SPIRVBin {} #[derive(PartialEq, Eq, Hash, Clone)] pub struct SPIRVKernelArg { - pub name: String, - pub type_name: String, + pub name: CString, + pub type_name: CString, pub access_qualifier: clc_kernel_arg_access_qualifier, pub address_qualifier: clc_kernel_arg_address_qualifier, pub type_qualifier: clc_kernel_arg_type_qualifier, @@ -253,8 +253,19 @@ impl SPIRVBin { unsafe { slice::from_raw_parts(info.args, info.num_args) } .iter() .map(|a| SPIRVKernelArg { - name: c_string_to_string(a.name), - type_name: c_string_to_string(a.type_name), + // SAFETY: we have a valid C string pointer here + name: a + .name + .is_null() + .not() + .then(|| unsafe { CStr::from_ptr(a.name) }.to_owned()) + .unwrap_or_default(), + type_name: a + .type_name + .is_null() + .not() + .then(|| unsafe { CStr::from_ptr(a.type_name) }.to_owned()) + .unwrap_or_default(), access_qualifier: clc_kernel_arg_access_qualifier(a.access_qualifier), address_qualifier: a.address_qualifier, type_qualifier: clc_kernel_arg_type_qualifier(a.type_qualifier), @@ -450,18 +461,12 @@ impl Drop for SPIRVBin { impl SPIRVKernelArg { pub fn serialize(&self, blob: &mut blob) { - let name_arr = self.name.as_bytes(); - let type_name_arr = self.type_name.as_bytes(); - unsafe { blob_write_uint32(blob, self.access_qualifier.0); blob_write_uint32(blob, self.type_qualifier.0); - blob_write_uint16(blob, name_arr.len() as u16); - blob_write_uint16(blob, type_name_arr.len() as u16); - - blob_write_bytes(blob, name_arr.as_ptr().cast(), name_arr.len()); - blob_write_bytes(blob, type_name_arr.as_ptr().cast(), type_name_arr.len()); + blob_write_string(blob, self.name.as_ptr()); + blob_write_string(blob, self.type_name.as_ptr()); blob_write_uint8(blob, self.address_qualifier as u8); } @@ -472,11 +477,8 @@ impl SPIRVKernelArg { let access_qualifier = blob_read_uint32(blob); let type_qualifier = blob_read_uint32(blob); - let name_len = blob_read_uint16(blob) as usize; - let type_len = blob_read_uint16(blob) as usize; - - let name = slice::from_raw_parts(blob_read_bytes(blob, name_len).cast(), name_len); - let type_name = slice::from_raw_parts(blob_read_bytes(blob, type_len).cast(), type_len); + let name = blob_read_string(blob); + let type_name = blob_read_string(blob); let address_qualifier = match blob_read_uint8(blob) { 0 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE, @@ -488,8 +490,10 @@ impl SPIRVKernelArg { // check overrun to ensure nothing went wrong blob.overrun.not().then(|| Self { - name: String::from_utf8_unchecked(name.to_owned()), - type_name: String::from_utf8_unchecked(type_name.to_owned()), + // SAFETY: blob_read_string checks for a valid nul character already and sets the + // blob to overrun state if none was found. + name: CStr::from_ptr(name).to_owned(), + type_name: CStr::from_ptr(type_name).to_owned(), access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier), address_qualifier: address_qualifier, type_qualifier: clc_kernel_arg_type_qualifier(type_qualifier),