rusticl/program: update binary format

This adds a magic number and the device name to the binary in order to
verify we indeed have a binary we can parse and matches the device.

Also save the binary header explicitly in little-endian order, so that we
at least make sure that's always the same.

Cc: mesa-stable
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/29946>
(cherry picked from commit f08f770f16)
This commit is contained in:
Karol Herbst 2024-07-02 22:58:18 +02:00 committed by Eric Engestrom
parent 42ba26edd4
commit ea640f35a9
2 changed files with 58 additions and 27 deletions

View file

@ -254,7 +254,7 @@
"description": "rusticl/program: update binary format",
"nominated": true,
"nomination_type": 0,
"resolution": 0,
"resolution": 1,
"main_sha": null,
"because_sha": null,
"notes": null

View file

@ -26,12 +26,21 @@ use std::sync::Mutex;
use std::sync::MutexGuard;
use std::sync::Once;
const BIN_HEADER_SIZE_V1: usize =
// 1. format version
// 8 bytes so we don't have any padding.
const BIN_RUSTICL_MAGIC_STRING: &[u8; 8] = b"rusticl\0";
const BIN_HEADER_SIZE_BASE: usize =
// 1. magic number
size_of::<[u8; 8]>() +
// 2. format version
size_of::<u32>();
const BIN_HEADER_SIZE_V1: usize = BIN_HEADER_SIZE_BASE +
// 3. device name length
size_of::<u32>() +
// 2. spirv len
// 4. spirv len
size_of::<u32>() +
// 3. binary_type
// 5. binary_type
size_of::<cl_program_binary_type>();
const BIN_HEADER_SIZE: usize = BIN_HEADER_SIZE_V1;
@ -373,24 +382,36 @@ impl Program {
})
}
fn spirv_from_bin_for_dev(bin: &[u8]) -> CLResult<(SPIRVBin, cl_program_binary_type)> {
fn spirv_from_bin_for_dev(
dev: &Device,
bin: &[u8],
) -> CLResult<(SPIRVBin, cl_program_binary_type)> {
if bin.is_empty() {
return Err(CL_INVALID_VALUE);
}
if bin.len() < BIN_HEADER_SIZE_BASE {
return Err(CL_INVALID_BINARY);
}
unsafe {
let mut blob = blob_reader::default();
blob_reader_init(&mut blob, bin.as_ptr().cast(), bin.len());
// 1. version
let version = blob_read_uint32(&mut blob);
let read_magic: &[u8] = slice::from_raw_parts(
blob_read_bytes(&mut blob, BIN_RUSTICL_MAGIC_STRING.len()).cast(),
BIN_RUSTICL_MAGIC_STRING.len(),
);
if read_magic != *BIN_RUSTICL_MAGIC_STRING {
return Err(CL_INVALID_BINARY);
}
let version = u32::from_le(blob_read_uint32(&mut blob));
match version {
1 => {
// 2. size of the spirv
let spirv_size = blob_read_uint32(&mut blob) as usize;
// 3. binary_type
let bin_type = blob_read_uint32(&mut blob);
let name_length = u32::from_le(blob_read_uint32(&mut blob)) as usize;
let spirv_size = u32::from_le(blob_read_uint32(&mut blob)) as usize;
let bin_type = u32::from_le(blob_read_uint32(&mut blob));
debug_assert!(
// `blob_read_*` doesn't advance the pointer on failure to read
@ -398,7 +419,7 @@ impl Program {
|| blob.overrun,
);
// 4. the spirv
let name = blob_read_bytes(&mut blob, name_length);
let spirv_data = blob_read_bytes(&mut blob, spirv_size);
// check that all the reads are valid before accessing the data, which might
@ -407,6 +428,11 @@ impl Program {
return Err(CL_INVALID_BINARY);
}
let name: &[u8] = slice::from_raw_parts(name.cast(), name_length);
if dev.screen().name().as_bytes() != name {
return Err(CL_INVALID_BINARY);
}
let spirv = spirv::SPIRVBin::from_bin(slice::from_raw_parts(
spirv_data.cast(),
spirv_size,
@ -429,7 +455,7 @@ impl Program {
let mut errors = vec![CL_SUCCESS as cl_int; devs.len()];
for (idx, (&d, b)) in devs.iter().zip(bins).enumerate() {
let build = match Self::spirv_from_bin_for_dev(b) {
let build = match Self::spirv_from_bin_for_dev(d, b) {
Ok((spirv, bin_type)) => {
for k in spirv.kernels() {
kernels.insert(k);
@ -517,11 +543,9 @@ impl Program {
for d in &self.devs {
let info = lock.dev_build(d);
res.push(
info.spirv
.as_ref()
.map_or(0, |s| s.to_bin().len() + BIN_HEADER_SIZE),
);
res.push(info.spirv.as_ref().map_or(0, |s| {
s.to_bin().len() + d.screen().name().as_bytes().len() + BIN_HEADER_SIZE
}));
}
res
}
@ -557,17 +581,24 @@ impl Program {
// sadly we have to trust the buffer to be correctly sized...
blob_init_fixed(&mut blob, ptrs[i].cast(), usize::MAX);
// 1. binary format version
blob_write_uint32(&mut blob, 1);
blob_write_bytes(
&mut blob,
BIN_RUSTICL_MAGIC_STRING.as_ptr().cast(),
BIN_RUSTICL_MAGIC_STRING.len(),
);
// 2. size of the spirv
blob_write_uint32(&mut blob, spirv.len() as u32);
// binary format version
blob_write_uint32(&mut blob, 1_u32.to_le());
// 3. binary_type
blob_write_uint32(&mut blob, info.bin_type);
let device_name = d.screen().name();
let device_name = device_name.as_bytes();
blob_write_uint32(&mut blob, (device_name.len() as u32).to_le());
blob_write_uint32(&mut blob, (spirv.len() as u32).to_le());
blob_write_uint32(&mut blob, info.bin_type.to_le());
debug_assert!(blob.size == BIN_HEADER_SIZE);
// 4. the spirv
blob_write_bytes(&mut blob, device_name.as_ptr().cast(), device_name.len());
blob_write_bytes(&mut blob, spirv.as_ptr().cast(), spirv.len());
blob_finish(&mut blob);
}