Merge branch 'rusticl/ubos/new' into 'main'

rusticl: make use of constant buffers for constant memory kernel inputs

See merge request mesa/mesa!38852
Karol Herbst 2025-12-20 00:47:21 +00:00
commit 2678dabe27
10 changed files with 409 additions and 123 deletions
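In practice this affects kernels that take __constant pointer arguments. A minimal, hypothetical OpenCL C kernel of that shape (name and body invented for illustration, not part of this merge request) would be:

__kernel void
scale_by_lut(__global float *out, __constant float *lut)
{
    /* `lut` lives in the constant address space; with this change rusticl can
     * bind it as a gallium constant buffer (UBO slots 1 and up, slot 0 stays
     * reserved for the kernel input buffer) instead of a global binding. */
    size_t gid = get_global_id(0);
    out[gid] *= lut[gid % 16];
}

Arguments beyond the available UBO slots keep using the existing global-memory path; the promotion pass simply declines to promote them.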


@ -177,7 +177,7 @@ unsafe impl CLInfo<cl_device_info> for cl_device_id {
CL_DEVICE_MAX_COMPUTE_UNITS => v.write::<cl_uint>(dev.max_compute_units()),
// TODO atm implemented as mem_const
CL_DEVICE_MAX_CONSTANT_ARGS => v.write::<cl_uint>(dev.const_max_count()),
CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE => v.write::<cl_ulong>(dev.const_max_size()),
CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE => v.write::<cl_ulong>(dev.const_max_size().into()),
CL_DEVICE_MAX_GLOBAL_VARIABLE_SIZE => v.write::<usize>(0),
CL_DEVICE_MAX_MEM_ALLOC_SIZE => v.write::<cl_ulong>(dev.max_mem_alloc()),
CL_DEVICE_MAX_NUM_SUB_GROUPS => v.write::<cl_uint>(if dev.subgroups_supported() {


@ -163,23 +163,18 @@ impl Context {
for &dev in &self.devs {
let mut resource = None;
let bind = PIPE_BIND_GLOBAL | PIPE_BIND_CONSTANT_BUFFER;
if !user_ptr.is_null() && !copy {
resource = dev.screen().resource_create_buffer_from_user(
adj_size,
user_ptr,
PIPE_BIND_GLOBAL,
pipe_flags,
)
resource = dev
.screen()
.resource_create_buffer_from_user(adj_size, user_ptr, bind, pipe_flags)
}
if resource.is_none() {
resource = dev.screen().resource_create_buffer(
adj_size,
res_type,
PIPE_BIND_GLOBAL,
pipe_flags,
)
resource = dev
.screen()
.resource_create_buffer(adj_size, res_type, bind, pipe_flags)
}
let resource = resource.ok_or(CL_OUT_OF_RESOURCES);
@ -376,16 +371,17 @@ impl Context {
let mut buffers = HashMap::new();
for &dev in &self.devs {
let size: u32 = size.get().try_into().map_err(|_| CL_OUT_OF_HOST_MEMORY)?;
let bind = PIPE_BIND_GLOBAL | PIPE_BIND_CONSTANT_BUFFER;
// For system SVM devices we simply create a userptr resource.
let res = if dev.system_svm_supported() {
dev.screen()
.resource_create_buffer_from_user(size, ptr, PIPE_BIND_GLOBAL, 0)
.resource_create_buffer_from_user(size, ptr, bind, 0)
} else {
dev.screen().resource_create_buffer(
size,
ResourceType::Normal,
PIPE_BIND_GLOBAL,
bind,
PIPE_RESOURCE_FLAG_FRONTEND_VM,
)
};


@ -425,9 +425,11 @@ impl DeviceBase {
return true;
}
// TODO
// CL_DEVICE_MAX_CONSTANT_ARGS
// The minimum value is 4 for devices that are not of type CL_DEVICE_TYPE_CUSTOM.
if self.const_max_count() < 4 {
return true;
}
// CL_DEVICE_LOCAL_MEM_SIZE
// The minimum value is 1 KB for devices that are not of type CL_DEVICE_TYPE_CUSTOM.
@ -448,9 +450,11 @@ impl DeviceBase {
return true;
}
// TODO
// CL_DEVICE_MAX_CONSTANT_ARGS
// The minimum value is 8 for devices that are not of type CL_DEVICE_TYPE_CUSTOM.
if self.const_max_count() < 8 {
return true;
}
// CL 1.0 spec:
// CL_DEVICE_LOCAL_MEM_SIZE
@ -794,7 +798,7 @@ impl DeviceBase {
self.screen.compute_caps().address_bits
}
pub fn const_max_size(&self) -> cl_ulong {
pub fn const_max_size(&self) -> u32 {
min(
// Needed to fix the `api min_max_constant_buffer_size` CL CTS test as it can't really
// handle arbitrary values here. We might want to reconsider later and figure out how to
@ -802,10 +806,7 @@ impl DeviceBase {
// should be at least 1 << 16 (native UBO size on NVidia)
// advertising more just in case it benefits other hardware
1 << 26,
min(
self.max_mem_alloc(),
self.screen.caps().max_shader_buffer_size.into(),
),
self.screen.caps().max_constant_buffer_size,
)
}


@ -18,6 +18,7 @@ use mesa_rust::pipe::context::RWFlags;
use mesa_rust::pipe::resource::*;
use mesa_rust::pipe::screen::ResourceType;
use mesa_rust_gen::*;
use mesa_rust_util::conversion::TryIntoWithErr;
use mesa_rust_util::math::*;
use mesa_rust_util::serialize::*;
use rusticl_opencl_gen::*;
@ -31,6 +32,7 @@ use std::convert::TryInto;
use std::ffi::CStr;
use std::fmt::Debug;
use std::fmt::Display;
use std::mem;
use std::ops::Index;
use std::ops::Not;
use std::os::raw::c_void;
@ -140,6 +142,7 @@ enum CompiledKernelArgType {
WorkDim,
WorkGroupOffsets,
NumWorkgroups,
Ubo { api_arg: u32 },
}
#[derive(Hash, PartialEq, Eq, Clone)]
@ -251,6 +254,10 @@ impl CompiledKernelArg {
for var in nir.variables_with_mode(
nir_variable_mode::nir_var_uniform | nir_variable_mode::nir_var_image,
) {
if var.data.location == -1 {
continue;
}
let arg = &mut compiled_args[var.data.location as usize];
let t = var.type_;
@ -265,6 +272,14 @@ impl CompiledKernelArg {
}
}
fn is_opaque(&self, api_args: &[KernelArg]) -> bool {
match &self.kind {
CompiledKernelArgType::APIArg(idx) => api_args[*idx].kind.is_opaque(),
CompiledKernelArgType::Ubo { .. } => true,
_ => false,
}
}
fn serialize(args: &[Self], blob: &mut blob) {
unsafe {
blob_write_uint16(blob, args.len() as u16);
@ -291,6 +306,10 @@ impl CompiledKernelArg {
blob_write_uint8(blob, 10);
blob_write_uint32(blob, idx as u32)
}
CompiledKernelArgType::Ubo { api_arg } => {
blob_write_uint8(blob, 11);
blob_write_uint32(blob, api_arg)
}
};
}
}
@ -325,6 +344,10 @@ impl CompiledKernelArg {
let idx = blob_read_uint32(blob) as usize;
CompiledKernelArgType::APIArg(idx)
}
11 => {
let api_arg = blob_read_uint32(blob);
CompiledKernelArgType::Ubo { api_arg: api_arg }
}
_ => return None,
};
@ -492,6 +515,7 @@ pub struct NirKernelBuild {
constant_buffer: Option<PipeResourceOwned>,
shared_size: u64,
input_size: u32,
num_ubos: u32,
printf_info: Option<NirPrintfInfo>,
compiled_args: Vec<CompiledKernelArg>,
}
@ -506,6 +530,8 @@ impl NirKernelBuild {
let cb = Self::create_nir_constant_buffer(dev, &out.nir);
let shared_size = out.nir.shared_size() as u64;
let printf_info = out.nir.take_printf_info();
// Need to subtract one to exclude the kernel input buffer
let num_ubos = out.nir.info().num_ubos - 1;
let nir_or_cso = if dev.shareable_shaders() {
// SAFETY: The device supports shareable shaders, upholding the
@ -524,6 +550,7 @@ impl NirKernelBuild {
input_size: out.input_size,
printf_info: printf_info,
compiled_args: out.compiled_args,
num_ubos: num_ubos.into(),
}
}
@ -989,7 +1016,6 @@ fn compile_nir_variant(
nir_variable_mode::nir_var_mem_shared
| nir_variable_mode::nir_var_function_temp
| nir_variable_mode::nir_var_shader_temp
| nir_variable_mode::nir_var_uniform
| nir_variable_mode::nir_var_mem_global
| nir_variable_mode::nir_var_mem_generic,
Some(glsl_get_cl_type_size_align),
@ -1028,7 +1054,44 @@ fn compile_nir_variant(
global_address_format,
);
// promote constant to ubo
nir_pass!(
nir,
rusticl_promote_constant_to_ubo,
// We reserve one ubo slot for the kernel input buffer
dev.const_max_count() - 1,
dev.const_max_size()
);
compiled_args.extend(
nir.variables_with_mode(nir_variable_mode::nir_var_mem_ubo)
.map(|ubo| CompiledKernelArg {
kind: CompiledKernelArgType::Ubo {
api_arg: ubo.data.location as u32,
},
offset: ubo.data.binding as usize,
dead: false,
}),
);
nir_pass!(nir, rusticl_lower_intrinsics, &mut lower_state);
// We can only dce uniform variables _after_ lowering intrinsics
nir_pass!(
nir,
nir_remove_dead_variables,
nir_variable_mode::nir_var_uniform,
&dv_opts,
);
// This calculates uniform usage, which we'll need for lowering the uniforms to a UBO
nir_pass!(
nir,
nir_lower_vars_to_explicit_types,
nir_variable_mode::nir_var_uniform,
Some(glsl_get_cl_type_size_align),
);
nir_pass!(
nir,
nir_lower_explicit_io,
@ -1284,16 +1347,18 @@ struct KernelExecBuilder<'a> {
resources: Vec<&'a PipeResource>,
resource_offsets: Vec<usize>,
workgroup_id_offset_loc: Option<usize>,
ubos: Vec<QueueContextUboBind>,
}
impl<'a> KernelExecBuilder<'a> {
fn new(dev: &'static Device, input_size: u32, num_globals: usize) -> Self {
fn new(dev: &'static Device, input_size: u32, num_globals: usize, num_ubos: u32) -> Self {
Self {
dev: dev,
input: Vec::with_capacity(input_size as usize),
resources: Vec::with_capacity(num_globals),
resource_offsets: Vec::with_capacity(num_globals),
workgroup_id_offset_loc: None,
ubos: Vec::with_capacity(num_ubos as usize),
}
}
@ -1323,6 +1388,14 @@ impl<'a> KernelExecBuilder<'a> {
}
}
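// Records a constant buffer binding for this launch; it will be bound at
// UBO slot index + 1 by `bind_ubos` (slot 0 holds the kernel inputs).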
fn add_ubo(&mut self, res: Option<PipeResourceOwned>, size: u32, offset: u32) {
self.ubos.push(QueueContextUboBind {
res: res,
size: size,
offset: offset,
});
}
fn add_values(&mut self, value: &[u8]) {
self.input.extend_from_slice(value);
}
@ -1371,6 +1444,10 @@ impl<'a> KernelExecBuilder<'a> {
}
}
}
fn ubos(&mut self) -> Vec<QueueContextUboBind> {
mem::take(&mut self.ubos)
}
}
impl Kernel {
@ -1573,8 +1650,12 @@ impl Kernel {
};
let nir_kernel_build = &nir_kernel_builds[variant];
let mut exec_builder =
KernelExecBuilder::new(ctx.dev, nir_kernel_build.input_size, buffer_arcs.len() + 2);
let mut exec_builder = KernelExecBuilder::new(
ctx.dev,
nir_kernel_build.input_size,
buffer_arcs.len() + 2,
nir_kernel_build.num_ubos,
);
// Set it once so we get the alignment padding right
let static_local_size: u64 = nir_kernel_build.shared_size;
let mut variable_local_size: u64 = static_local_size;
@ -1618,12 +1699,7 @@ impl Kernel {
.collect();
for arg in &nir_kernel_build.compiled_args {
let is_opaque = if let CompiledKernelArgType::APIArg(idx) = arg.kind {
args[idx].kind.is_opaque()
} else {
false
};
let is_opaque = arg.is_opaque(&args);
if !is_opaque {
exec_builder.add_zero_padding(arg.offset);
}
@ -1731,6 +1807,38 @@ impl Kernel {
}
}
}
CompiledKernelArgType::Ubo { api_arg } => {
let value = &arg_values[api_arg as usize];
match value {
Some(KernelArgValue::Buffer(buffer)) => {
let buffer = &buffer_arcs[&(buffer.as_ptr() as usize)];
let res = buffer.get_res_for_access(ctx, RWFlags::RD)?;
exec_builder.add_ubo(
Some(res.new_ref()),
buffer.size.try_into_with_err(CL_OUT_OF_RESOURCES)?,
buffer.offset().try_into_with_err(CL_OUT_OF_RESOURCES)?,
);
}
Some(KernelArgValue::SVM(handle)) => {
if let Some((base, size)) = cl_ctx.find_svm_alloc(*handle) {
let base = base as usize;
if let Some(res) = cl_ctx.copy_svm_to_dev(ctx, base)? {
let offset = *handle - base;
exec_builder.add_ubo(
Some(res),
(size - offset)
.try_into_with_err(CL_OUT_OF_RESOURCES)?,
offset.try_into_with_err(CL_OUT_OF_RESOURCES)?,
);
}
}
}
Some(KernelArgValue::None) | None => {
continue;
}
_ => panic!("uhh"),
}
}
CompiledKernelArgType::ConstantBuffer => {
assert!(nir_kernel_build.constant_buffer.is_some());
let res = nir_kernel_build.constant_buffer.as_ref().unwrap();
@ -1805,6 +1913,7 @@ impl Kernel {
ctx.bind_sampler_views(sviews);
ctx.bind_shader_images(iviews);
ctx.set_global_binding(resources, &mut globals);
ctx.bind_ubos(exec_builder.ubos());
for z in 0..grid[2].div_ceil(hw_max_grid[2]) {
for y in 0..grid[1].div_ceil(hw_max_grid[1]) {


@ -15,11 +15,13 @@ use mesa_rust::pipe::context::PipeContext;
use mesa_rust::pipe::context::PipeContextPrio;
use mesa_rust::pipe::fence::PipeFence;
use mesa_rust::pipe::resource::PipeImageView;
use mesa_rust::pipe::resource::PipeResourceOwned;
use mesa_rust::pipe::resource::PipeSamplerView;
use mesa_rust_gen::*;
use mesa_rust_util::properties::*;
use rusticl_opencl_gen::*;
use std::borrow::Borrow;
use std::cmp;
use std::collections::HashMap;
use std::ffi::c_void;
@ -80,6 +82,7 @@ impl<'a> QueueContext<'a> {
bound_sampler_views: Vec::new(),
bound_shader_images: Vec::new(),
samplers: HashMap::new(),
bound_ubos: Vec::new(),
}
}
}
@ -96,6 +99,13 @@ pub struct QueueContextWithState<'a> {
bound_sampler_views: Vec<PipeSamplerView<'a>>,
bound_shader_images: Vec<PipeImageView>,
samplers: HashMap<PipeSamplerState, *mut c_void>,
bound_ubos: Vec<QueueContextUboBind>,
}
pub struct QueueContextUboBind {
pub res: Option<PipeResourceOwned>,
pub size: u32,
pub offset: u32,
}
impl<'c> QueueContextWithState<'c> {
@ -178,6 +188,26 @@ impl<'c> QueueContextWithState<'c> {
}
Ok(())
}
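// Binds the constant buffers backing promoted __constant kernel arguments.
// Slot 0 is reserved for the kernel input buffer, so bindings start at slot 1;
// any slots left over from the previous dispatch are cleared.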
pub fn bind_ubos(&mut self, ubos: Vec<QueueContextUboBind>) {
for (idx, ubo) in ubos.iter().enumerate() {
let idx = idx as u32 + 1;
if let Some(res) = ubo.res.as_ref() {
self.ctx
.bind_constant_buffer(idx, res.borrow(), ubo.size, ubo.offset);
} else {
self.ctx.set_constant_buffer(idx, &[]);
}
}
// unbind trailing slots
for idx in ubos.len()..self.bound_ubos.len() {
let idx = idx as u32 + 1;
self.ctx.set_constant_buffer(idx, &[])
}
self.bound_ubos = ubos;
}
}
impl<'a> Deref for QueueContextWithState<'a> {


@ -504,6 +504,10 @@ impl NirShader {
(*var).data.location = loc.try_into().unwrap();
}
}
pub fn info(&self) -> &shader_info {
&unsafe { self.nir.as_ref() }.info
}
}
impl Clone for NirShader {


@ -425,11 +425,12 @@ impl PipeContext {
unsafe { self.pipe.as_ref().delete_sampler_state.unwrap()(self.pipe.as_ptr(), ptr) }
}
pub fn bind_constant_buffer(&self, idx: u32, res: &PipeResourceOwned) {
pub fn bind_constant_buffer(&self, idx: u32, res: &PipeResource, size: u32, offset: u32) {
assert!(size <= res.width() - offset);
let cb = pipe_constant_buffer {
buffer: res.pipe(),
buffer_offset: 0,
buffer_size: res.width(),
buffer: res.as_mut_ptr(),
buffer_offset: offset,
buffer_size: size,
user_buffer: ptr::null(),
};
unsafe {


@ -323,7 +323,7 @@ impl PipeResource {
this.as_mut_ptr().cast()
}
fn as_mut_ptr(&self) -> *mut pipe_resource {
pub(super) fn as_mut_ptr(&self) -> *mut pipe_resource {
(&self.0 as *const pipe_resource).cast_mut()
}
@ -332,6 +332,10 @@ impl PipeResource {
// a ref as we don't impose any further restrictions on the type.
unsafe { mem::transmute(res) }
}
pub fn width(&self) -> u32 {
self.0.width0
}
}
impl ToOwned for PipeResource {


@ -14,116 +14,127 @@
static bool
rusticl_lower_intrinsics_filter(const nir_instr* instr, const void* state)
{
return instr->type == nir_instr_type_intrinsic;
return instr->type == nir_instr_type_intrinsic;
}
static nir_def*
rusticl_lower_intrinsics_instr(
nir_builder *b,
nir_instr *instr,
void* _state
nir_builder *b,
nir_instr *instr,
void* _state
) {
nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
struct rusticl_lower_state *state = _state;
nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
struct rusticl_lower_state *state = _state;
switch (intrins->intrinsic) {
case nir_intrinsic_image_deref_format:
case nir_intrinsic_image_deref_order: {
int32_t offset;
nir_deref_instr *deref;
nir_def *val;
nir_variable *var;
switch (intrins->intrinsic) {
case nir_intrinsic_image_deref_format:
case nir_intrinsic_image_deref_order: {
int32_t offset;
nir_deref_instr *deref;
nir_def *val;
nir_variable *var;
if (intrins->intrinsic == nir_intrinsic_image_deref_format) {
offset = CL_SNORM_INT8;
var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->format_arr_loc);
} else {
offset = CL_R;
var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->order_arr_loc);
}
if (intrins->intrinsic == nir_intrinsic_image_deref_format) {
offset = CL_SNORM_INT8;
var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->format_arr_loc);
} else {
offset = CL_R;
var = nir_find_variable_with_location(b->shader, nir_var_uniform, state->order_arr_loc);
}
val = intrins->src[0].ssa;
val = intrins->src[0].ssa;
if (nir_def_is_deref(val)) {
nir_deref_instr *deref = nir_def_as_deref(val);
nir_variable *var = nir_deref_instr_get_variable(deref);
assert(var);
val = nir_imm_intN_t(b, var->data.binding, val->bit_size);
}
if (nir_def_is_deref(val)) {
nir_deref_instr *deref = nir_def_as_deref(val);
nir_variable *var = nir_deref_instr_get_variable(deref);
assert(var);
val = nir_imm_intN_t(b, var->data.binding, val->bit_size);
}
// we put write images after read images
if (glsl_type_is_image(var->type)) {
val = nir_iadd_imm(b, val, b->shader->info.num_textures);
}
// we put write images after read images
if (glsl_type_is_image(var->type)) {
val = nir_iadd_imm(b, val, b->shader->info.num_textures);
}
deref = nir_build_deref_var(b, var);
deref = nir_build_deref_array(b, deref, val);
val = nir_u2uN(b, nir_load_deref(b, deref), 32);
deref = nir_build_deref_var(b, var);
deref = nir_build_deref_array(b, deref, val);
val = nir_u2uN(b, nir_load_deref(b, deref), 32);
// we have to fix up the value base
val = nir_iadd_imm(b, val, -offset);
// we have to fix up the value base
val = nir_iadd_imm(b, val, -offset);
return val;
}
case nir_intrinsic_load_global_invocation_id:
if (intrins->def.bit_size == 64)
return nir_u2u64(b, nir_load_global_invocation_id(b, 32));
return NULL;
case nir_intrinsic_load_base_global_invocation_id:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_global_invoc_id_loc));
case nir_intrinsic_load_base_workgroup_id:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_workgroup_id_loc));
case nir_intrinsic_load_global_size:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->global_size_loc));
case nir_intrinsic_load_num_workgroups:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->num_workgroups_loc));
case nir_intrinsic_load_constant_base_ptr:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->const_buf_loc));
case nir_intrinsic_load_printf_buffer_address:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->printf_buf_loc));
case nir_intrinsic_load_work_dim:
assert(nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc));
return nir_u2uN(b, nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc)),
intrins->def.bit_size);
default:
return NULL;
}
return val;
}
case nir_intrinsic_load_global_invocation_id:
if (intrins->def.bit_size == 64)
return nir_u2u64(b, nir_load_global_invocation_id(b, 32));
return NULL;
case nir_intrinsic_load_base_global_invocation_id:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_global_invoc_id_loc));
case nir_intrinsic_load_base_workgroup_id:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->base_workgroup_id_loc));
case nir_intrinsic_load_global_size:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->global_size_loc));
case nir_intrinsic_load_num_workgroups:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->num_workgroups_loc));
case nir_intrinsic_load_constant_base_ptr:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->const_buf_loc));
case nir_intrinsic_load_printf_buffer_address:
return nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->printf_buf_loc));
case nir_intrinsic_load_work_dim:
assert(nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc));
return nir_u2uN(b, nir_load_var(b, nir_find_variable_with_location(b->shader, nir_var_uniform, state->work_dim_loc)),
intrins->def.bit_size);
default:
return NULL;
}
}
bool
rusticl_lower_intrinsics(nir_shader *nir, struct rusticl_lower_state* state)
{
return nir_shader_lower_instructions(
nir,
rusticl_lower_intrinsics_filter,
rusticl_lower_intrinsics_instr,
state
);
return nir_shader_lower_instructions(
nir,
rusticl_lower_intrinsics_filter,
rusticl_lower_intrinsics_instr,
state
);
}
static nir_def*
rusticl_lower_input_instr(struct nir_builder *b, nir_instr *instr, void *_)
static bool
rusticl_lower_input_instr(struct nir_builder *b, nir_intrinsic_instr *intrins, void *_)
{
nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(instr);
if (intrins->intrinsic != nir_intrinsic_load_kernel_input)
return NULL;
switch (intrins->intrinsic) {
case nir_intrinsic_load_kernel_input: {
b->cursor = nir_after_instr(&intrins->instr);
nir_def *ubo_idx = nir_imm_int(b, 0);
nir_def *uniform_offset = intrins->src[0].ssa;
nir_def *ubo_idx = nir_imm_int(b, 0);
nir_def *uniform_offset = intrins->src[0].ssa;
assert(intrins->def.bit_size >= 8);
nir_def *load_result =
nir_load_ubo(b, intrins->num_components, intrins->def.bit_size,
ubo_idx, nir_iadd_imm(b, uniform_offset, nir_intrinsic_base(intrins)));
assert(intrins->def.bit_size >= 8);
nir_def *load_result =
nir_load_ubo(b, intrins->num_components, intrins->def.bit_size,
ubo_idx, nir_iadd_imm(b, uniform_offset, nir_intrinsic_base(intrins)));
nir_intrinsic_instr *load = nir_def_as_intrinsic(load_result);
nir_intrinsic_instr *load = nir_def_as_intrinsic(load_result);
nir_intrinsic_set_align_mul(load, nir_intrinsic_align_mul(intrins));
nir_intrinsic_set_align_offset(load, nir_intrinsic_align_offset(intrins));
nir_intrinsic_set_range_base(load, nir_intrinsic_base(intrins));
nir_intrinsic_set_range(load, nir_intrinsic_range(intrins));
nir_intrinsic_set_align_mul(load, nir_intrinsic_align_mul(intrins));
nir_intrinsic_set_align_offset(load, nir_intrinsic_align_offset(intrins));
nir_intrinsic_set_range_base(load, nir_intrinsic_base(intrins));
nir_intrinsic_set_range(load, nir_intrinsic_range(intrins));
return load_result;
nir_def_replace(&intrins->def, load_result);
return true;
}
case nir_intrinsic_load_ubo: {
b->cursor = nir_before_instr(&intrins->instr);
nir_def *new_index = nir_iadd_imm(b, intrins->src[0].ssa, 1);
nir_src_rewrite(&intrins->src[0], new_index);
return true;
}
default:
return false;
}
}
bool
@ -133,10 +144,10 @@ rusticl_lower_inputs(nir_shader *shader)
assert(!shader->info.first_ubo_is_default_ubo);
progress = nir_shader_lower_instructions(
progress = nir_shader_intrinsics_pass(
shader,
rusticl_lower_intrinsics_filter,
rusticl_lower_input_instr,
nir_metadata_control_flow,
NULL
);
@ -149,6 +160,8 @@ rusticl_lower_inputs(nir_shader *shader)
if (shader->num_uniforms > 0) {
const struct glsl_type *type = glsl_array_type(glsl_uint8_t_type(), shader->num_uniforms, 1);
nir_variable *ubo = nir_variable_create(shader, nir_var_mem_ubo, type, "kernel_input");
ubo->data.location = -1;
ubo->data.driver_location = 0;
ubo->data.binding = 0;
ubo->data.explicit_binding = 1;
}
@ -156,3 +169,130 @@ rusticl_lower_inputs(nir_shader *shader)
shader->info.first_ubo_is_default_ubo = true;
return progress;
}
struct promote_constant_state
{
uint32_t max_ubos;
uint32_t max_ubo_size;
uint32_t curr_ubo;
};
typedef struct ubo_ref
{
nir_def *index;
nir_def *address;
} ubo_ref;
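/* Walks the address computation feeding a load_global_constant: follows iadd
 * chains down to a load_deref of a nir_var_uniform kernel argument and, if
 * found, creates (or reuses) a nir_var_mem_ubo variable for that argument,
 * limited to max_ubos UBOs of max_ubo_size bytes each. Returns the UBO index
 * and the rebased 32-bit offset, or an empty ref if promotion isn't possible. */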
static struct ubo_ref
try_promote_and_fix_address(nir_builder *b, struct promote_constant_state *state, nir_instr *use)
{
struct ubo_ref ref = {};
switch (use->type) {
case nir_instr_type_alu: {
nir_alu_instr *alu = nir_instr_as_alu(use);
if (alu->op != nir_op_iadd)
return ref;
/* Vectored addresses should be impossible here, so bail out if we see one. */
if (alu->def.num_components != 1)
return ref;
int src = 1;
ubo_ref new_ref = try_promote_and_fix_address(b, state, nir_def_instr(alu->src[0].src.ssa));
if (!new_ref.index) {
src = 0;
new_ref = try_promote_and_fix_address(b, state, nir_def_instr(alu->src[1].src.ssa));
}
if (!new_ref.index)
return ref;
b->cursor = nir_after_instr(use);
ref.index = new_ref.index;
ref.address = nir_iadd(b, new_ref.address, nir_u2u32(b, alu->src[src].src.ssa));
return ref;
}
case nir_instr_type_intrinsic: {
nir_intrinsic_instr *intrins = nir_instr_as_intrinsic(use);
if (intrins->intrinsic != nir_intrinsic_load_deref)
return ref;
nir_deref_instr *deref = nir_src_as_deref(intrins->src[0]);
if (deref->deref_type != nir_deref_type_var)
return ref;
nir_variable *var = deref->var;
if (var->data.mode != nir_var_uniform)
return ref;
nir_variable *ubo = nir_find_variable_with_location(
b->shader, nir_var_mem_ubo, var->data.location
);
if (!ubo) {
if (state->curr_ubo >= state->max_ubos)
return ref;
const glsl_type *type = glsl_array_type(glsl_uint8_t_type(), state->max_ubo_size, 1);
ubo = nir_variable_create(b->shader, nir_var_mem_ubo, type, NULL);
ubo->data.location = var->data.location;
ubo->data.binding = state->curr_ubo++;
ubo->data.driver_location = ubo->data.binding;
b->shader->info.num_ubos++;
}
b->cursor = nir_after_instr(use);
ref.index = nir_imm_int(b, ubo->data.binding);
ref.address = nir_imm_zero(b, 1, 32);
return ref;
}
default:
return ref;
}
}
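/* Replaces a load_global_constant whose address chain roots in a promotable
 * __constant kernel argument with a load_ubo from the corresponding UBO, then
 * frees and DCEs the now unused global address computation. */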
static bool
rusticl_promote_constant_cb(struct nir_builder *b, nir_intrinsic_instr *intrins, void *_state)
{
if (intrins->intrinsic != nir_intrinsic_load_global_constant)
return false;
struct promote_constant_state *state = _state;
ubo_ref new_ref = try_promote_and_fix_address(b, state, nir_def_instr(intrins->src[0].ssa));
if (!new_ref.index)
return false;
b->cursor = nir_before_instr(&intrins->instr);
nir_def *new_load = nir_load_ubo(b,
intrins->def.num_components, intrins->def.bit_size,
new_ref.index, new_ref.address
);
nir_intrinsic_instr *load_ubo = nir_def_as_intrinsic(new_load);
nir_intrinsic_copy_const_indices(load_ubo, intrins);
nir_intrinsic_set_range(load_ubo, state->max_ubo_size);
nir_def_rewrite_uses(&intrins->def, new_load);
/* Clean up right away so we don't have to rely on later opt loops; the entire
 * address chains should be dead in most cases anyway. */
nir_instr_free_and_dce(&intrins->instr);
return true;
}
bool
rusticl_promote_constant_to_ubo(nir_shader *nir, uint32_t max_ubos, uint32_t max_ubo_size)
{
struct promote_constant_state state = {
.max_ubos = max_ubos,
.max_ubo_size = max_ubo_size,
.curr_ubo = 0,
};
return nir_shader_intrinsics_pass(
nir,
rusticl_promote_constant_cb,
nir_metadata_control_flow,
&state
);
}


@ -20,3 +20,4 @@ struct rusticl_lower_state {
bool rusticl_lower_intrinsics(nir_shader *nir, struct rusticl_lower_state *state);
bool rusticl_lower_inputs(nir_shader *nir);
bool rusticl_promote_constant_to_ubo(nir_shader *nir, uint32_t max_ubos, uint32_t max_ubo_size);