rusticl/kernel: big kernel arg rework

The main change here is that instead of having two vectors for API and
internal arguments, there is just one per built kernel.

Some of the API level information is still in its own structure and
referenced by the above mentioned merged vector, but with this change each
device and also each kernel variant can have arguments placed at different
locations or even have a different set of arguments.

This rework will be necessary to add kernel variants in a non-messy way.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30602>
This commit is contained in:
Karol Herbst 2024-08-11 17:32:28 +02:00 committed by Marge Bot
parent 0b98e47d83
commit d26d17bbaf

View file

@ -92,6 +92,7 @@ impl KernelArgType {
#[derive(Hash, PartialEq, Eq, Clone)]
enum CompiledKernelArgType {
APIArg(u32),
ConstantBuffer,
GlobalWorkOffsets,
GlobalWorkSize,
@ -108,16 +109,14 @@ enum CompiledKernelArgType {
pub struct KernelArg {
spirv: spirv::SPIRVKernelArg,
pub kind: KernelArgType,
/// The offset into the input buffer
offset: u32,
/// The actual binding slot
binding: u32,
pub dead: bool,
}
#[derive(Hash, PartialEq, Eq, Clone)]
struct CompiledKernelArg {
kind: CompiledKernelArgType,
/// The binding for image/sampler args, the offset into the input buffer
/// for anything else.
offset: u32,
}
@ -169,8 +168,6 @@ impl KernelArg {
spirv: s.clone(),
// we'll update it later in the 2nd pass
kind: kind,
offset: 0,
binding: 0,
dead: true,
});
}
@ -185,16 +182,19 @@ impl KernelArg {
for var in nir.variables_with_mode(
nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
) {
if let Some(arg) = args.get_mut(var.data.location as usize) {
arg.offset = var.data.driver_location;
arg.binding = var.data.binding;
arg.dead = false;
} else {
compiled_args
.get_mut(var.data.location as usize - args.len())
.unwrap()
.offset = var.data.driver_location;
let arg = &mut compiled_args[var.data.location as usize];
if let CompiledKernelArgType::APIArg(idx) = arg.kind {
args[idx as usize].dead = false;
}
let t = var.type_;
arg.offset = if unsafe {
glsl_type_is_image(t) || glsl_type_is_texture(t) || glsl_type_is_sampler(t)
} {
var.data.binding
} else {
var.data.driver_location
};
}
}
@ -204,8 +204,6 @@ impl KernelArg {
for arg in args {
arg.spirv.serialize(blob);
blob_write_uint32(blob, arg.offset);
blob_write_uint32(blob, arg.binding);
blob_write_uint8(blob, arg.dead.into());
arg.kind.serialize(blob);
}
@ -219,16 +217,12 @@ impl KernelArg {
for _ in 0..len {
let spirv = spirv::SPIRVKernelArg::deserialize(blob)?;
let offset = blob_read_uint32(blob);
let binding = blob_read_uint32(blob);
let dead = blob_read_uint8(blob) != 0;
let kind = KernelArgType::deserialize(blob)?;
res.push(Self {
spirv: spirv,
kind: kind,
offset: offset,
binding: binding,
dead: dead,
});
}
@ -260,6 +254,10 @@ impl CompiledKernelArg {
CompiledKernelArgType::WorkGroupOffsets => blob_write_uint8(blob, 7),
CompiledKernelArgType::NumWorkgroups => blob_write_uint8(blob, 8),
CompiledKernelArgType::GlobalWorkSize => blob_write_uint8(blob, 9),
CompiledKernelArgType::APIArg(idx) => {
blob_write_uint8(blob, 10);
blob_write_uint32(blob, idx)
}
};
}
}
@ -289,6 +287,10 @@ impl CompiledKernelArg {
7 => CompiledKernelArgType::WorkGroupOffsets,
8 => CompiledKernelArgType::NumWorkgroups,
9 => CompiledKernelArgType::GlobalWorkSize,
10 => {
let idx = blob_read_uint32(blob);
CompiledKernelArgType::APIArg(idx)
}
_ => return None,
};
@ -581,7 +583,14 @@ fn lower_and_optimize_nir(
opt_nir(nir, dev, false);
let mut args = KernelArg::from_spirv_nir(args, nir);
let mut compiled_args = Vec::new();
// add all API kernel args
let mut compiled_args: Vec<_> = (0..args.len())
.map(|idx| CompiledKernelArg {
kind: CompiledKernelArgType::APIArg(idx as u32),
offset: 0,
})
.collect();
// assign locations for inline samplers.
// IMPORTANT: this needs to happen before nir_remove_dead_variables.
@ -654,11 +663,11 @@ fn lower_and_optimize_nir(
nir.gather_info();
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_GLOBAL_INVOCATION_ID) {
lower_state.base_global_invoc_id_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::GlobalWorkOffsets,
offset: 0,
});
lower_state.base_global_invoc_id_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_vector_type(address_bits_base_type, 3) },
@ -668,11 +677,11 @@ fn lower_and_optimize_nir(
}
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_GLOBAL_GROUP_SIZE) {
lower_state.global_size_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::GlobalWorkSize,
offset: 0,
});
lower_state.global_size_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_vector_type(address_bits_base_type, 3) },
@ -682,11 +691,11 @@ fn lower_and_optimize_nir(
}
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_BASE_WORKGROUP_ID) {
lower_state.base_workgroup_id_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::WorkGroupOffsets,
offset: 0,
});
lower_state.base_workgroup_id_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_vector_type(address_bits_base_type, 3) },
@ -696,12 +705,11 @@ fn lower_and_optimize_nir(
}
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_NUM_WORKGROUPS) {
lower_state.num_workgroups_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::NumWorkgroups,
offset: 0,
});
lower_state.num_workgroups_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_vector_type(glsl_base_type::GLSL_TYPE_UINT, 3) },
@ -711,11 +719,11 @@ fn lower_and_optimize_nir(
}
if nir.has_constant() {
lower_state.const_buf_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::ConstantBuffer,
offset: 0,
});
lower_state.const_buf_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
address_bits_ptr_type,
@ -724,11 +732,11 @@ fn lower_and_optimize_nir(
);
}
if nir.has_printf() {
lower_state.printf_buf_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::PrintfBuffer,
offset: 0,
});
lower_state.printf_buf_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
address_bits_ptr_type,
@ -738,6 +746,9 @@ fn lower_and_optimize_nir(
}
if nir.num_images() > 0 || nir.num_textures() > 0 {
lower_state.format_arr_loc = compiled_args.len();
lower_state.order_arr_loc = compiled_args.len() + 1;
let count = nir.num_images() + nir.num_textures();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::FormatArray,
@ -749,7 +760,6 @@ fn lower_and_optimize_nir(
offset: 0,
});
lower_state.format_arr_loc = args.len() + compiled_args.len() - 2;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
@ -757,7 +767,6 @@ fn lower_and_optimize_nir(
"image_formats",
);
lower_state.order_arr_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_array_type(glsl_int16_t_type(), count as u32, 2) },
@ -767,11 +776,11 @@ fn lower_and_optimize_nir(
}
if nir.reads_sysval(gl_system_value::SYSTEM_VALUE_WORK_DIM) {
lower_state.work_dim_loc = compiled_args.len();
compiled_args.push(CompiledKernelArg {
kind: CompiledKernelArgType::WorkDim,
offset: 0,
});
lower_state.work_dim_loc = args.len() + compiled_args.len() - 1;
nir.add_var(
nir_variable_mode::nir_var_uniform,
unsafe { glsl_uint8_t_type() },
@ -1120,103 +1129,6 @@ impl Kernel {
}
}
for (arg, val) in kernel_info.args.iter().zip(arg_values.iter()) {
if arg.dead {
continue;
}
if arg.kind != KernelArgType::Image
&& arg.kind != KernelArgType::RWImage
&& arg.kind != KernelArgType::Texture
&& arg.kind != KernelArgType::Sampler
{
input.resize(arg.offset as usize, 0);
}
match val.as_ref().unwrap() {
KernelArgValue::Constant(c) => input.extend_from_slice(c),
KernelArgValue::Buffer(buffer) => {
let res = buffer.get_res_of_dev(q.device)?;
add_global(q, &mut input, &mut resource_info, res, buffer.offset);
}
KernelArgValue::Image(image) => {
let res = image.get_res_of_dev(q.device)?;
// If resource is a buffer, the image was created from a buffer. Use strides and
// dimensions of the image then.
let app_img_info = if res.as_ref().is_buffer()
&& image.mem_type == CL_MEM_OBJECT_IMAGE2D
{
Some(AppImgInfo::new(
image.image_desc.row_pitch()? / image.image_elem_size as u32,
image.image_desc.width()?,
image.image_desc.height()?,
))
} else {
None
};
let format = image.pipe_format;
let (formats, orders) = if arg.kind == KernelArgType::Image {
iviews.push(res.pipe_image_view(
format,
false,
image.pipe_image_host_access(),
app_img_info.as_ref(),
));
(&mut img_formats, &mut img_orders)
} else if arg.kind == KernelArgType::RWImage {
iviews.push(res.pipe_image_view(
format,
true,
image.pipe_image_host_access(),
app_img_info.as_ref(),
));
(&mut img_formats, &mut img_orders)
} else {
sviews.push((res.clone(), format, app_img_info));
(&mut tex_formats, &mut tex_orders)
};
let binding = arg.binding as usize;
assert!(binding >= formats.len());
formats.resize(binding, 0);
orders.resize(binding, 0);
formats.push(image.image_format.image_channel_data_type as u16);
orders.push(image.image_format.image_channel_order as u16);
}
KernelArgValue::LocalMem(size) => {
// TODO 32 bit
let pot = cmp::min(*size, 0x80);
variable_local_size =
variable_local_size.next_multiple_of(pot.next_power_of_two() as u64);
if q.device.address_bits() == 64 {
let variable_local_size: [u8; 8] = variable_local_size.to_ne_bytes();
input.extend_from_slice(&variable_local_size);
} else {
let variable_local_size: [u8; 4] =
(variable_local_size as u32).to_ne_bytes();
input.extend_from_slice(&variable_local_size);
}
variable_local_size += *size as u64;
}
KernelArgValue::Sampler(sampler) => {
samplers.push(sampler.pipe());
}
KernelArgValue::None => {
assert!(
arg.kind == KernelArgType::MemGlobal
|| arg.kind == KernelArgType::MemConstant
);
input.extend_from_slice(null_ptr);
}
}
}
// subtract the shader local_size as we only request something on top of that.
variable_local_size -= static_local_size;
let mut printf_buf = None;
if nir_kernel_build.printf_info.is_some() {
let buf = q
@ -1232,10 +1144,116 @@ impl Kernel {
}
for arg in &nir_kernel_build.compiled_args {
if arg.offset as usize > input.len() {
let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
matches!(
kernel_info.args[idx as usize].kind,
KernelArgType::Image
| KernelArgType::RWImage
| KernelArgType::Texture
| KernelArgType::Sampler
)
} else {
false
};
if !is_opaque && arg.offset as usize > input.len() {
input.resize(arg.offset as usize, 0);
}
match arg.kind {
CompiledKernelArgType::APIArg(idx) => {
let api_arg = &kernel_info.args[idx as usize];
if api_arg.dead {
continue;
}
let Some(value) = &arg_values[idx as usize] else {
continue;
};
match value {
KernelArgValue::Constant(c) => input.extend_from_slice(c),
KernelArgValue::Buffer(buffer) => {
let res = buffer.get_res_of_dev(q.device)?;
add_global(q, &mut input, &mut resource_info, res, buffer.offset);
}
KernelArgValue::Image(image) => {
let res = image.get_res_of_dev(q.device)?;
// If resource is a buffer, the image was created from a buffer. Use
// strides and dimensions of the image then.
let app_img_info = if res.as_ref().is_buffer()
&& image.mem_type == CL_MEM_OBJECT_IMAGE2D
{
Some(AppImgInfo::new(
image.image_desc.row_pitch()?
/ image.image_elem_size as u32,
image.image_desc.width()?,
image.image_desc.height()?,
))
} else {
None
};
let format = image.pipe_format;
let (formats, orders) = if api_arg.kind == KernelArgType::Image {
iviews.push(res.pipe_image_view(
format,
false,
image.pipe_image_host_access(),
app_img_info.as_ref(),
));
(&mut img_formats, &mut img_orders)
} else if api_arg.kind == KernelArgType::RWImage {
iviews.push(res.pipe_image_view(
format,
true,
image.pipe_image_host_access(),
app_img_info.as_ref(),
));
(&mut img_formats, &mut img_orders)
} else {
sviews.push((res.clone(), format, app_img_info));
(&mut tex_formats, &mut tex_orders)
};
let binding = arg.offset as usize;
assert!(binding >= formats.len());
formats.resize(binding, 0);
orders.resize(binding, 0);
formats.push(image.image_format.image_channel_data_type as u16);
orders.push(image.image_format.image_channel_order as u16);
}
KernelArgValue::LocalMem(size) => {
// TODO 32 bit
let pot = cmp::min(*size, 0x80);
variable_local_size = variable_local_size
.next_multiple_of(pot.next_power_of_two() as u64);
if q.device.address_bits() == 64 {
let variable_local_size: [u8; 8] =
variable_local_size.to_ne_bytes();
input.extend_from_slice(&variable_local_size);
} else {
let variable_local_size: [u8; 4] =
(variable_local_size as u32).to_ne_bytes();
input.extend_from_slice(&variable_local_size);
}
variable_local_size += *size as u64;
}
KernelArgValue::Sampler(sampler) => {
samplers.push(sampler.pipe());
}
KernelArgValue::None => {
assert!(
api_arg.kind == KernelArgType::MemGlobal
|| api_arg.kind == KernelArgType::MemConstant
);
input.extend_from_slice(null_ptr);
}
}
}
CompiledKernelArgType::ConstantBuffer => {
assert!(nir_kernel_build.constant_buffer.is_some());
let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
@ -1277,6 +1295,9 @@ impl Kernel {
}
}
// subtract the shader local_size as we only request something on top of that.
variable_local_size -= static_local_size;
let mut sviews: Vec<_> = sviews
.iter()
.map(|(s, f, aii)| ctx.create_sampler_view(s, *f, aii.as_ref()))