rusticl/kernel: convert name and type_name to Option<CString>

This also lets us throw CL_KERNEL_ARG_INFO_NOT_AVAILABLE easily on non
existing metadata.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/32253>
This commit is contained in:
Karol Herbst 2024-11-19 03:19:41 +01:00 committed by Marge Bot
parent 813edb6cea
commit 4750619491
3 changed files with 41 additions and 25 deletions

View file

@ -15,6 +15,7 @@ use rusticl_proc_macros::cl_entrypoint;
use rusticl_proc_macros::cl_info_entrypoint;
use std::cmp;
use std::ffi::CStr;
use std::mem::{self, MaybeUninit};
use std::os::raw::c_void;
use std::ptr;
@ -61,8 +62,16 @@ impl CLInfoObj<cl_kernel_arg_info, cl_uint> for cl_kernel {
CL_KERNEL_ARG_ADDRESS_QUALIFIER => {
cl_prop::<cl_kernel_arg_address_qualifier>(kernel.address_qualifier(idx))
}
CL_KERNEL_ARG_NAME => cl_prop::<&str>(kernel.arg_name(idx)),
CL_KERNEL_ARG_TYPE_NAME => cl_prop::<&str>(kernel.arg_type_name(idx)),
CL_KERNEL_ARG_NAME => cl_prop::<&CStr>(
kernel
.arg_name(idx)
.ok_or(CL_KERNEL_ARG_INFO_NOT_AVAILABLE)?,
),
CL_KERNEL_ARG_TYPE_NAME => cl_prop::<&CStr>(
kernel
.arg_type_name(idx)
.ok_or(CL_KERNEL_ARG_INFO_NOT_AVAILABLE)?,
),
CL_KERNEL_ARG_TYPE_QUALIFIER => {
cl_prop::<cl_kernel_arg_type_qualifier>(kernel.type_qualifier(idx))
}

View file

@ -22,6 +22,7 @@ use spirv::SpirvKernelInfo;
use std::cmp;
use std::collections::HashMap;
use std::convert::TryInto;
use std::ffi::CStr;
use std::fmt::Debug;
use std::fmt::Display;
use std::ops::Index;
@ -1692,12 +1693,14 @@ impl Kernel {
self.kernel_info.subgroup_size
}
pub fn arg_name(&self, idx: cl_uint) -> &String {
&self.kernel_info.args[idx as usize].spirv.name
pub fn arg_name(&self, idx: cl_uint) -> Option<&CStr> {
let name = &self.kernel_info.args[idx as usize].spirv.name;
name.is_empty().not().then_some(name)
}
pub fn arg_type_name(&self, idx: cl_uint) -> &String {
&self.kernel_info.args[idx as usize].spirv.type_name
pub fn arg_type_name(&self, idx: cl_uint) -> Option<&CStr> {
let type_name = &self.kernel_info.args[idx as usize].spirv.type_name;
type_name.is_empty().not().then_some(type_name)
}
pub fn priv_mem_size(&self, dev: &Device) -> cl_ulong {

View file

@ -33,8 +33,8 @@ unsafe impl Sync for SPIRVBin {}
#[derive(PartialEq, Eq, Hash, Clone)]
pub struct SPIRVKernelArg {
pub name: String,
pub type_name: String,
pub name: CString,
pub type_name: CString,
pub access_qualifier: clc_kernel_arg_access_qualifier,
pub address_qualifier: clc_kernel_arg_address_qualifier,
pub type_qualifier: clc_kernel_arg_type_qualifier,
@ -253,8 +253,19 @@ impl SPIRVBin {
unsafe { slice::from_raw_parts(info.args, info.num_args) }
.iter()
.map(|a| SPIRVKernelArg {
name: c_string_to_string(a.name),
type_name: c_string_to_string(a.type_name),
// SAFETY: we have a valid C string pointer here
name: a
.name
.is_null()
.not()
.then(|| unsafe { CStr::from_ptr(a.name) }.to_owned())
.unwrap_or_default(),
type_name: a
.type_name
.is_null()
.not()
.then(|| unsafe { CStr::from_ptr(a.type_name) }.to_owned())
.unwrap_or_default(),
access_qualifier: clc_kernel_arg_access_qualifier(a.access_qualifier),
address_qualifier: a.address_qualifier,
type_qualifier: clc_kernel_arg_type_qualifier(a.type_qualifier),
@ -450,18 +461,12 @@ impl Drop for SPIRVBin {
impl SPIRVKernelArg {
pub fn serialize(&self, blob: &mut blob) {
let name_arr = self.name.as_bytes();
let type_name_arr = self.type_name.as_bytes();
unsafe {
blob_write_uint32(blob, self.access_qualifier.0);
blob_write_uint32(blob, self.type_qualifier.0);
blob_write_uint16(blob, name_arr.len() as u16);
blob_write_uint16(blob, type_name_arr.len() as u16);
blob_write_bytes(blob, name_arr.as_ptr().cast(), name_arr.len());
blob_write_bytes(blob, type_name_arr.as_ptr().cast(), type_name_arr.len());
blob_write_string(blob, self.name.as_ptr());
blob_write_string(blob, self.type_name.as_ptr());
blob_write_uint8(blob, self.address_qualifier as u8);
}
@ -472,11 +477,8 @@ impl SPIRVKernelArg {
let access_qualifier = blob_read_uint32(blob);
let type_qualifier = blob_read_uint32(blob);
let name_len = blob_read_uint16(blob) as usize;
let type_len = blob_read_uint16(blob) as usize;
let name = slice::from_raw_parts(blob_read_bytes(blob, name_len).cast(), name_len);
let type_name = slice::from_raw_parts(blob_read_bytes(blob, type_len).cast(), type_len);
let name = blob_read_string(blob);
let type_name = blob_read_string(blob);
let address_qualifier = match blob_read_uint8(blob) {
0 => clc_kernel_arg_address_qualifier::CLC_KERNEL_ARG_ADDRESS_PRIVATE,
@ -488,8 +490,10 @@ impl SPIRVKernelArg {
// check overrun to ensure nothing went wrong
blob.overrun.not().then(|| Self {
name: String::from_utf8_unchecked(name.to_owned()),
type_name: String::from_utf8_unchecked(type_name.to_owned()),
// SAFETY: blob_read_string checks for a valid nul character already and sets the
// blob to overrun state if none was found.
name: CStr::from_ptr(name).to_owned(),
type_name: CStr::from_ptr(type_name).to_owned(),
access_qualifier: clc_kernel_arg_access_qualifier(access_qualifier),
address_qualifier: address_qualifier,
type_qualifier: clc_kernel_arg_type_qualifier(type_qualifier),