rusticl/program: rework source code tracking

For the CL spec it really matters how a program object was created. We
never really cared all that much, but it didn't support the corner case of
having an empty string as the OpenCL C source code.

Enums feel like the more Rust way to do this kind of stuff anyway.

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/22280>
This commit is contained in:
Karol Herbst 2023-04-07 00:42:05 +02:00 committed by Marge Bot
parent 6d7b705125
commit 55c9356d29
2 changed files with 64 additions and 50 deletions

View file

@ -5,7 +5,6 @@ use crate::core::device::*;
use crate::core::platform::*;
use crate::core::program::*;
use mesa_rust::compiler::clc::spirv::SPIRVBin;
use mesa_rust::compiler::clc::*;
use mesa_rust_util::string::*;
use rusticl_opencl_gen::*;
@ -42,19 +41,20 @@ impl CLInfo<cl_program_info> for cl_program {
.collect(),
)
}
CL_PROGRAM_IL => prog
.il
.as_ref()
.map(SPIRVBin::to_bin)
.unwrap_or_default()
.to_vec(),
CL_PROGRAM_IL => match &prog.src {
ProgramSourceType::Il(il) => il.to_bin().to_vec(),
_ => Vec::new(),
},
CL_PROGRAM_KERNEL_NAMES => cl_prop::<String>(prog.kernels().join(";")),
CL_PROGRAM_NUM_DEVICES => cl_prop::<cl_uint>(prog.devs.len() as cl_uint),
CL_PROGRAM_NUM_KERNELS => cl_prop::<usize>(prog.kernels().len()),
CL_PROGRAM_REFERENCE_COUNT => cl_prop::<cl_uint>(self.refcnt()?),
CL_PROGRAM_SCOPE_GLOBAL_CTORS_PRESENT => cl_prop::<cl_bool>(CL_FALSE),
CL_PROGRAM_SCOPE_GLOBAL_DTORS_PRESENT => cl_prop::<cl_bool>(CL_FALSE),
CL_PROGRAM_SOURCE => cl_prop::<&CStr>(prog.src.as_c_str()),
CL_PROGRAM_SOURCE => match &prog.src {
ProgramSourceType::Src(src) => cl_prop::<&CStr>(src.as_c_str()),
_ => Vec::new(),
},
// CL_INVALID_VALUE if param_name is not one of the supported values
_ => return Err(CL_INVALID_VALUE),
})
@ -327,18 +327,27 @@ pub fn compile_program(
}
let mut headers = Vec::new();
for h in 0..num_input_headers as usize {
unsafe {
headers.push(spirv::CLCHeader {
name: CStr::from_ptr(*header_include_names.add(h)).to_owned(),
source: &(*input_headers.add(h)).get_ref()?.src,
});
// If program was created using clCreateProgramWithIL, then num_input_headers, input_headers,
// and header_include_names are ignored.
if !p.is_il() {
for h in 0..num_input_headers as usize {
// SAFETY: have to trust the application here
let header = unsafe { (*input_headers.add(h)).get_ref()? };
match &header.src {
ProgramSourceType::Src(src) => headers.push(spirv::CLCHeader {
// SAFETY: have to trust the application here
name: unsafe { CStr::from_ptr(*header_include_names.add(h)).to_owned() },
source: src,
}),
_ => return Err(CL_INVALID_OPERATION),
}
}
}
// CL_INVALID_OPERATION if program has no source or IL available, i.e. it has not been created
// with clCreateProgramWithSource or clCreateProgramWithIL.
if p.is_binary() {
if !(p.is_src() || p.is_il()) {
return Err(CL_INVALID_OPERATION);
}
@ -442,7 +451,7 @@ pub fn set_program_specialization_constant(
// CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate
// language (e.g. SPIR-V)
// TODO: or if the intermediate language does not support specialization constants.
if program.il.is_none() {
if !program.is_il() {
return Err(CL_INVALID_PROGRAM);
}

View file

@ -47,13 +47,19 @@ fn get_disk_cache() -> &'static Option<DiskCache> {
}
}
pub enum ProgramSourceType {
Binary,
Linked,
Src(CString),
Il(spirv::SPIRVBin),
}
#[repr(C)]
pub struct Program {
pub base: CLObjectBase<CL_INVALID_PROGRAM>,
pub context: Arc<Context>,
pub devs: Vec<Arc<Device>>,
pub src: CString,
pub il: Option<spirv::SPIRVBin>,
pub src: ProgramSourceType,
pub kernel_count: AtomicU32,
spec_constants: Mutex<HashMap<u32, nir_const_value>>,
build: Mutex<ProgramBuild>,
@ -144,8 +150,7 @@ impl Program {
base: CLObjectBase::new(),
context: context.clone(),
devs: devs.to_vec(),
src: src,
il: None,
src: ProgramSourceType::Src(src),
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
@ -217,8 +222,7 @@ impl Program {
base: CLObjectBase::new(),
context: context,
devs: devs,
src: CString::new("").unwrap(),
il: None,
src: ProgramSourceType::Binary,
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
@ -234,8 +238,7 @@ impl Program {
base: CLObjectBase::new(),
devs: context.devs.clone(),
context: context,
src: CString::new("").unwrap(),
il: Some(SPIRVBin::from_bin(spirv)),
src: ProgramSourceType::Il(SPIRVBin::from_bin(spirv)),
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
@ -388,23 +391,25 @@ impl Program {
headers: &[spirv::CLCHeader],
info: &mut MutexGuard<ProgramBuild>,
) -> bool {
if self.is_binary() {
return true;
}
let d = Self::dev_build_info(info, dev);
let (spirv, log) = if let Some(il) = self.il.as_ref() {
il.clone_on_validate()
} else {
let args = prepare_options(&options, dev);
spirv::SPIRVBin::from_clc(
&self.src,
&args,
headers,
get_disk_cache(),
dev.cl_features(),
dev.address_bits(),
)
let (spirv, log) = match &self.src {
ProgramSourceType::Il(spirv) => spirv.clone_on_validate(),
ProgramSourceType::Src(src) => {
let args = prepare_options(&options, dev);
spirv::SPIRVBin::from_clc(
src,
&args,
headers,
get_disk_cache(),
dev.cl_features(),
dev.address_bits(),
)
}
// do nothing if we got a library or binary
_ => {
return true;
}
};
d.spirv = spirv;
@ -484,8 +489,7 @@ impl Program {
base: CLObjectBase::new(),
context: context,
devs: devs,
src: CString::new("").unwrap(),
il: None,
src: ProgramSourceType::Linked,
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
@ -585,20 +589,21 @@ impl Program {
nir.unwrap()
}
pub fn is_binary(&self) -> bool {
self.src.to_bytes().is_empty() && self.il.is_none()
pub fn is_il(&self) -> bool {
matches!(self.src, ProgramSourceType::Il(_))
}
pub fn is_src(&self) -> bool {
!self.src.to_bytes().is_empty()
matches!(self.src, ProgramSourceType::Src(_))
}
pub fn get_spec_constant_size(&self, spec_id: u32) -> u8 {
self.il
.as_ref()
.unwrap()
.spec_constant(spec_id)
.map_or(0, spirv::CLCSpecConstantType::size)
match &self.src {
ProgramSourceType::Il(il) => il
.spec_constant(spec_id)
.map_or(0, spirv::CLCSpecConstantType::size),
_ => unreachable!(),
}
}
pub fn set_spec_constant(&self, spec_id: u32, data: &[u8]) {