rusticl/program: enable spirv

Signed-off-by: Karol Herbst <kherbst@redhat.com>
Reviewed-by: Adam Jackson <ajax@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19008>
This commit is contained in:
Karol Herbst 2022-09-05 17:22:56 +02:00 committed by Marge Bot
parent 2a0b58434d
commit 13a4c49cb1
6 changed files with 115 additions and 37 deletions

View file

@ -2,6 +2,7 @@ use crate::api::icd::*;
use crate::api::platform::*;
use crate::api::util::*;
use crate::core::device::*;
use crate::core::version::*;
use mesa_rust_gen::*;
use mesa_rust_util::ptr::*;
@ -14,16 +15,13 @@ use std::ptr;
use std::sync::Arc;
use std::sync::Once;
// TODO spec constants need to be implemented
const SPIRV_SUPPORT_STRING: &str = "";
// "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4 SPIR-V_1.5";
const SPIRV_SUPPORT: [cl_name_version; 0] = [
/* mk_cl_version_ext(1, 0, 0, b"SPIR-V"),
mk_cl_version_ext(1, 1, 0, b"SPIR-V"),
mk_cl_version_ext(1, 2, 0, b"SPIR-V"),
mk_cl_version_ext(1, 3, 0, b"SPIR-V"),
mk_cl_version_ext(1, 4, 0, b"SPIR-V"),
mk_cl_version_ext(1, 5, 0, b"SPIR-V"),*/
const SPIRV_SUPPORT_STRING: &str = "SPIR-V_1.0 SPIR-V_1.1 SPIR-V_1.2 SPIR-V_1.3 SPIR-V_1.4";
const SPIRV_SUPPORT: [cl_name_version; 5] = [
mk_cl_version_ext(1, 0, 0, "SPIR-V"),
mk_cl_version_ext(1, 1, 0, "SPIR-V"),
mk_cl_version_ext(1, 2, 0, "SPIR-V"),
mk_cl_version_ext(1, 3, 0, "SPIR-V"),
mk_cl_version_ext(1, 4, 0, "SPIR-V"),
];
impl CLInfo<cl_device_info> for cl_device_id {

View file

@ -10,7 +10,7 @@ use rusticl_opencl_gen::*;
#[allow(non_camel_case_types)]
pub struct _cl_platform_id {
dispatch: &'static cl_icd_dispatch,
extensions: [cl_name_version; 1],
extensions: [cl_name_version; 2],
}
impl CLInfo<cl_platform_info> for cl_platform_id {
@ -18,8 +18,7 @@ impl CLInfo<cl_platform_info> for cl_platform_id {
let p = self.get_ref()?;
Ok(match q {
// TODO spirv
CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd"),
// CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"),
CL_PLATFORM_EXTENSIONS => cl_prop("cl_khr_icd cl_khr_il_program"),
CL_PLATFORM_EXTENSIONS_WITH_VERSION => {
cl_prop::<Vec<cl_name_version>>(p.extensions.to_vec())
}
@ -41,8 +40,7 @@ static PLATFORM: _cl_platform_id = _cl_platform_id {
dispatch: &DISPATCH,
extensions: [
mk_cl_version_ext(1, 0, 0, "cl_khr_icd"),
// TODO spirv
// mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"),
mk_cl_version_ext(1, 0, 0, "cl_khr_il_program"),
],
};

View file

@ -237,17 +237,16 @@ pub fn create_program_with_il(
il: *const ::std::os::raw::c_void,
length: usize,
) -> CLResult<cl_program> {
let _c = context.get_arc()?;
let c = context.get_arc()?;
// CL_INVALID_VALUE if il is NULL or if length is zero.
if il.is_null() || length == 0 {
return Err(CL_INVALID_VALUE);
}
// let spirv = unsafe { slice::from_raw_parts(il.cast(), length) };
// TODO SPIR-V
// Ok(cl_program::from_arc(Program::from_spirv(c, spirv)))
Err(CL_INVALID_OPERATION)
// SAFETY: according to API spec
let spirv = unsafe { slice::from_raw_parts(il.cast(), length) };
Ok(cl_program::from_arc(Program::from_spirv(c, spirv)))
}
pub fn build_program(
@ -417,29 +416,36 @@ pub fn link_program(
pub fn set_program_specialization_constant(
program: cl_program,
_spec_id: cl_uint,
_spec_size: usize,
spec_id: cl_uint,
spec_size: usize,
spec_value: *const ::std::os::raw::c_void,
) -> CLResult<()> {
let _program = program.get_ref()?;
let program = program.get_ref()?;
// CL_INVALID_PROGRAM if program is not a valid program object created from an intermediate
// language (e.g. SPIR-V)
// TODO: or if the intermediate language does not support specialization constants.
// if program.il.is_empty() {
// Err(CL_INVALID_PROGRAM)?
// }
if program.il.is_empty() {
return Err(CL_INVALID_PROGRAM);
}
// TODO: CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in the module,
if spec_size != program.get_spec_constant_size(spec_id).into() {
// CL_INVALID_VALUE if spec_size does not match the size of the specialization constant in
// the module,
return Err(CL_INVALID_VALUE);
}
// or if spec_value is NULL.
if spec_value.is_null() {
return Err(CL_INVALID_VALUE);
}
Err(CL_INVALID_OPERATION)
// SAFETY: according to API spec
program.set_spec_constant(spec_id, unsafe {
slice::from_raw_parts(spec_value.cast(), spec_size)
});
//• CL_INVALID_SPEC_ID if spec_id is not a valid specialization constant identifier.
Ok(())
}
pub fn set_program_release_callback(

View file

@ -483,8 +483,7 @@ impl Device {
add_ext(1, 0, 0, "cl_khr_byte_addressable_store", "");
add_ext(1, 0, 0, "cl_khr_global_int32_base_atomics", "");
add_ext(1, 0, 0, "cl_khr_global_int32_extended_atomics", "");
// TODO spirv
// add_ext(1, 0, 0, "cl_khr_il_program", "");
add_ext(1, 0, 0, "cl_khr_il_program", "");
add_ext(1, 0, 0, "cl_khr_local_int32_base_atomics", "");
add_ext(1, 0, 0, "cl_khr_local_int32_extended_atomics", "");

View file

@ -53,7 +53,7 @@ pub struct Program {
pub src: CString,
pub il: Vec<u8>,
pub kernel_count: AtomicU32,
spec_constants: Mutex<Vec<spirv::SpecConstant>>,
spec_constants: Mutex<HashMap<u32, nir_const_value>>,
build: Mutex<ProgramBuild>,
}
@ -144,7 +144,7 @@ impl Program {
src: src,
il: Vec::new(),
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(Vec::new()),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: Vec::new(),
@ -217,7 +217,7 @@ impl Program {
src: CString::new("").unwrap(),
il: Vec::new(),
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(Vec::new()),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: kernels.into_iter().collect(),
@ -250,7 +250,7 @@ impl Program {
src: CString::new("").unwrap(),
il: spirv.to_vec(),
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(Vec::new()),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: Vec::new(),
@ -518,7 +518,7 @@ impl Program {
src: CString::new("").unwrap(),
il: Vec::new(),
kernel_count: AtomicU32::new(0),
spec_constants: Mutex::new(Vec::new()),
spec_constants: Mutex::new(HashMap::new()),
build: Mutex::new(ProgramBuild {
builds: builds,
kernels: kernels.into_iter().collect(),
@ -535,6 +535,15 @@ impl Program {
let spirv = info.spirv.as_ref().unwrap();
let mut bin = spirv.to_bin().to_vec();
bin.extend_from_slice(name.as_bytes());
for (k, v) in self.spec_constants.lock().unwrap().iter() {
bin.extend_from_slice(&k.to_ne_bytes());
unsafe {
// SAFETY: we fully initialize this union
bin.extend_from_slice(&v.u64_.to_ne_bytes());
}
}
Some(cache.gen_key(&bin))
} else {
None
@ -571,7 +580,19 @@ impl Program {
}
pub fn to_nir(&self, kernel: &str, d: &Arc<Device>) -> NirShader {
let constants = self.spec_constants.lock().unwrap();
let mut spec_constants: Vec<_> = constants
.iter()
.map(|(&id, &value)| nir_spirv_specialization {
id: id,
value: value,
defined_on_module: true,
})
.collect();
drop(constants);
let mut lock = self.build_info();
let info = Self::dev_build_info(&mut lock, d);
assert_eq!(info.status, CL_BUILD_SUCCESS as cl_build_status);
info.spirv
@ -582,7 +603,7 @@ impl Program {
d.screen
.nir_shader_compiler_options(pipe_shader_type::PIPE_SHADER_COMPUTE),
&d.lib_clc,
&mut [],
&mut spec_constants,
d.address_bits(),
)
.unwrap()
@ -595,4 +616,27 @@ impl Program {
pub fn is_src(&self) -> bool {
!self.src.to_bytes().is_empty()
}
pub fn get_spec_constant_size(&self, spec_id: u32) -> u8 {
let lock = self.build_info();
let spirv = lock.builds.values().next().unwrap().spirv.as_ref().unwrap();
spirv
.spec_constant(spec_id)
.map_or(0, spirv::CLCSpecConstantType::size)
}
pub fn set_spec_constant(&self, spec_id: u32, data: &[u8]) {
let mut lock = self.spec_constants.lock().unwrap();
let mut val = nir_const_value::default();
match data.len() {
1 => val.u8_ = u8::from_ne_bytes(data.try_into().unwrap()),
2 => val.u16_ = u16::from_ne_bytes(data.try_into().unwrap()),
4 => val.u32_ = u32::from_ne_bytes(data.try_into().unwrap()),
8 => val.u64_ = u64::from_ne_bytes(data.try_into().unwrap()),
_ => unreachable!("Spec constant with invalid size!"),
};
lock.insert(spec_id, val);
}
}

View file

@ -368,6 +368,17 @@ impl SPIRVBin {
}
}
pub fn spec_constant(&self, spec_id: u32) -> Option<clc_spec_constant_type> {
let info = self.info?;
let spec_constants =
unsafe { slice::from_raw_parts(info.spec_constants, info.num_spec_constants as usize) };
spec_constants
.iter()
.find(|sc| sc.id == spec_id)
.map(|sc| sc.type_)
}
pub fn print(&self) {
unsafe {
clc_dump_spirv(&self.spirv, stderr_ptr());
@ -429,3 +440,25 @@ impl SPIRVKernelArg {
})
}
}
pub trait CLCSpecConstantType {
fn size(self) -> u8;
}
impl CLCSpecConstantType for clc_spec_constant_type {
fn size(self) -> u8 {
match self {
Self::CLC_SPEC_CONSTANT_INT64
| Self::CLC_SPEC_CONSTANT_UINT64
| Self::CLC_SPEC_CONSTANT_DOUBLE => 8,
Self::CLC_SPEC_CONSTANT_INT32
| Self::CLC_SPEC_CONSTANT_UINT32
| Self::CLC_SPEC_CONSTANT_FLOAT => 4,
Self::CLC_SPEC_CONSTANT_INT16 | Self::CLC_SPEC_CONSTANT_UINT16 => 2,
Self::CLC_SPEC_CONSTANT_INT8
| Self::CLC_SPEC_CONSTANT_UINT8
| Self::CLC_SPEC_CONSTANT_BOOL => 1,
Self::CLC_SPEC_CONSTANT_UNKNOWN => 0,
}
}
}