diff --git a/.pick_status.json b/.pick_status.json index 19c2de2f868..f722f59f78c 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -254,7 +254,7 @@ "description": "rusticl/program: update binary format", "nominated": true, "nomination_type": 0, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": null, "notes": null diff --git a/src/gallium/frontends/rusticl/core/program.rs b/src/gallium/frontends/rusticl/core/program.rs index c4672719098..6b8a7d1c2e8 100644 --- a/src/gallium/frontends/rusticl/core/program.rs +++ b/src/gallium/frontends/rusticl/core/program.rs @@ -26,12 +26,21 @@ use std::sync::Mutex; use std::sync::MutexGuard; use std::sync::Once; -const BIN_HEADER_SIZE_V1: usize = - // 1. format version +// 8 bytes so we don't have any padding. +const BIN_RUSTICL_MAGIC_STRING: &[u8; 8] = b"rusticl\0"; + +const BIN_HEADER_SIZE_BASE: usize = + // 1. magic number + size_of::<[u8; 8]>() + + // 2. format version + size_of::(); + +const BIN_HEADER_SIZE_V1: usize = BIN_HEADER_SIZE_BASE + + // 3. device name length size_of::() + - // 2. spirv len + // 4. spirv len size_of::() + - // 3. binary_type + // 5. binary_type size_of::(); const BIN_HEADER_SIZE: usize = BIN_HEADER_SIZE_V1; @@ -373,24 +382,36 @@ impl Program { }) } - fn spirv_from_bin_for_dev(bin: &[u8]) -> CLResult<(SPIRVBin, cl_program_binary_type)> { + fn spirv_from_bin_for_dev( + dev: &Device, + bin: &[u8], + ) -> CLResult<(SPIRVBin, cl_program_binary_type)> { if bin.is_empty() { return Err(CL_INVALID_VALUE); } + if bin.len() < BIN_HEADER_SIZE_BASE { + return Err(CL_INVALID_BINARY); + } + unsafe { let mut blob = blob_reader::default(); blob_reader_init(&mut blob, bin.as_ptr().cast(), bin.len()); - // 1. version - let version = blob_read_uint32(&mut blob); + let read_magic: &[u8] = slice::from_raw_parts( + blob_read_bytes(&mut blob, BIN_RUSTICL_MAGIC_STRING.len()).cast(), + BIN_RUSTICL_MAGIC_STRING.len(), + ); + if read_magic != *BIN_RUSTICL_MAGIC_STRING { + return Err(CL_INVALID_BINARY); + } + + let version = u32::from_le(blob_read_uint32(&mut blob)); match version { 1 => { - // 2. size of the spirv - let spirv_size = blob_read_uint32(&mut blob) as usize; - - // 3. binary_type - let bin_type = blob_read_uint32(&mut blob); + let name_length = u32::from_le(blob_read_uint32(&mut blob)) as usize; + let spirv_size = u32::from_le(blob_read_uint32(&mut blob)) as usize; + let bin_type = u32::from_le(blob_read_uint32(&mut blob)); debug_assert!( // `blob_read_*` doesn't advance the pointer on failure to read @@ -398,7 +419,7 @@ impl Program { || blob.overrun, ); - // 4. the spirv + let name = blob_read_bytes(&mut blob, name_length); let spirv_data = blob_read_bytes(&mut blob, spirv_size); // check that all the reads are valid before accessing the data, which might @@ -407,6 +428,11 @@ impl Program { return Err(CL_INVALID_BINARY); } + let name: &[u8] = slice::from_raw_parts(name.cast(), name_length); + if dev.screen().name().as_bytes() != name { + return Err(CL_INVALID_BINARY); + } + let spirv = spirv::SPIRVBin::from_bin(slice::from_raw_parts( spirv_data.cast(), spirv_size, @@ -429,7 +455,7 @@ impl Program { let mut errors = vec![CL_SUCCESS as cl_int; devs.len()]; for (idx, (&d, b)) in devs.iter().zip(bins).enumerate() { - let build = match Self::spirv_from_bin_for_dev(b) { + let build = match Self::spirv_from_bin_for_dev(d, b) { Ok((spirv, bin_type)) => { for k in spirv.kernels() { kernels.insert(k); @@ -517,11 +543,9 @@ impl Program { for d in &self.devs { let info = lock.dev_build(d); - res.push( - info.spirv - .as_ref() - .map_or(0, |s| s.to_bin().len() + BIN_HEADER_SIZE), - ); + res.push(info.spirv.as_ref().map_or(0, |s| { + s.to_bin().len() + d.screen().name().as_bytes().len() + BIN_HEADER_SIZE + })); } res } @@ -557,17 +581,24 @@ impl Program { // sadly we have to trust the buffer to be correctly sized... blob_init_fixed(&mut blob, ptrs[i].cast(), usize::MAX); - // 1. binary format version - blob_write_uint32(&mut blob, 1); + blob_write_bytes( + &mut blob, + BIN_RUSTICL_MAGIC_STRING.as_ptr().cast(), + BIN_RUSTICL_MAGIC_STRING.len(), + ); - // 2. size of the spirv - blob_write_uint32(&mut blob, spirv.len() as u32); + // binary format version + blob_write_uint32(&mut blob, 1_u32.to_le()); - // 3. binary_type - blob_write_uint32(&mut blob, info.bin_type); + let device_name = d.screen().name(); + let device_name = device_name.as_bytes(); + + blob_write_uint32(&mut blob, (device_name.len() as u32).to_le()); + blob_write_uint32(&mut blob, (spirv.len() as u32).to_le()); + blob_write_uint32(&mut blob, info.bin_type.to_le()); debug_assert!(blob.size == BIN_HEADER_SIZE); - // 4. the spirv + blob_write_bytes(&mut blob, device_name.as_ptr().cast(), device_name.len()); blob_write_bytes(&mut blob, spirv.as_ptr().cast(), spirv.len()); blob_finish(&mut blob); }