spirv2dxil: Support buffer_device_address

This works similarly to the CL compiler, where a 64-bit address is
decomposed into a 32-bit index and offset. But unlike CL, where the
index is into a per-kernel array of bound buffers, for Vulkan it points
into the global device-wide descriptor heap.

For all global deref chains that terminate in a load/store/atomic, create
a parallel deref chain that begins by decomposing the pointer to a vec2,
followed by a load_vulkan_descriptor, and then an SSBO deref chain. Any instance
where the original deref chain was used for something else will remain as
global derefs, so also run lower_explicit_io for global to produce appropriate
pointer math.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/28028>
Author: Jesse Natalie, 2024-03-06 16:30:29 -08:00; committed by Marge Bot.
parent 57d914b757
commit 68f43aa3ec
3 changed files with 69 additions and 0 deletions

View file

@ -58,10 +58,12 @@ spirv_to_nir_options = {
.int64 = true,
.float64 = true,
.tessellation = true,
.physical_storage_buffer_address = true,
},
.ubo_addr_format = nir_address_format_32bit_index_offset,
.ssbo_addr_format = nir_address_format_32bit_index_offset,
.shared_addr_format = nir_address_format_logical,
.phys_ssbo_addr_format = nir_address_format_32bit_index_offset_pack64,
.min_ubo_alignment = 256, /* D3D12_CONSTANT_BUFFER_DATA_PLACEMENT_ALIGNMENT */
.min_ssbo_alignment = 16, /* D3D12_RAW_UAV_SRV_BYTE_ALIGNMENT */
@ -1123,8 +1125,11 @@ dxil_spirv_nir_passes(nir_shader *nir,
conf->push_constant_cbv.base_shader_register,
&push_constant_size);
NIR_PASS_V(nir, dxil_spirv_nir_lower_buffer_device_address);
NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_ubo | nir_var_mem_ssbo,
nir_address_format_32bit_index_offset);
NIR_PASS_V(nir, nir_lower_explicit_io, nir_var_mem_global,
nir_address_format_32bit_index_offset_pack64);
if (nir->info.shared_memory_explicit_layout) {
NIR_PASS_V(nir, nir_lower_vars_to_explicit_types, nir_var_mem_shared,

View file

@ -79,6 +79,9 @@ struct dxil_spirv_bindless_entry {
bool
dxil_spirv_nir_lower_bindless(nir_shader *nir, struct dxil_spirv_nir_lower_bindless_options *options);
bool
dxil_spirv_nir_lower_buffer_device_address(nir_shader *nir);
bool
dxil_spirv_nir_lower_yz_flip(nir_shader *shader,
const struct dxil_spirv_runtime_conf *rt_conf,

View file

@ -254,3 +254,64 @@ dxil_spirv_nir_lower_bindless(nir_shader *nir, struct dxil_spirv_nir_lower_bindl
}
return true;
}
/* Given a global deref chain that starts as a pointer value and ends with a load/store/atomic,
 * create a new SSBO deref chain. The new chain starts with a load_vulkan_descriptor, then casts
 * the resulting vec2 to an SSBO deref.
 *
 * The original global deref chain is left in place: any other users of it keep
 * their global derefs and are handled by the later nir_lower_explicit_io pass
 * for nir_var_mem_global. Returns true when this intrinsic was rewritten
 * (progress), false otherwise. */
static bool
lower_buffer_device_address(nir_builder *b, nir_intrinsic_instr *intr, void *data)
{
/* Only intrinsics whose src[0] is a deref terminating a memory access are
 * candidates for rewriting. */
switch (intr->intrinsic) {
case nir_intrinsic_load_deref:
case nir_intrinsic_store_deref:
case nir_intrinsic_deref_atomic:
case nir_intrinsic_deref_atomic_swap:
break;
default:
/* copy_deref is expected to have been lowered to load/store pairs before
 * this pass runs — NOTE(review): confirm the pass ordering guarantees this. */
assert(intr->intrinsic != nir_intrinsic_copy_deref);
return false;
}
/* Only rewrite accesses through buffer_device_address (global) pointers. */
nir_deref_instr *deref = nir_src_as_deref(intr->src[0]);
if (!nir_deref_mode_is(deref, nir_var_mem_global))
return false;
/* Walk back to the head of the deref chain. A BDA chain must begin with a
 * cast from a scalar 64-bit pointer value. */
nir_deref_path path;
nir_deref_path_init(&path, deref, NULL);
nir_deref_instr *old_head = path.path[0];
assert(old_head->deref_type == nir_deref_type_cast &&
old_head->parent.ssa->bit_size == 64 &&
old_head->parent.ssa->num_components == 1);
/* Insert the replacement head right after the original cast so every
 * follower instruction we emit dominates its uses. */
b->cursor = nir_after_instr(&old_head->instr);
nir_def *pointer = old_head->parent.ssa;
/* Decompose the 64-bit address: the low 32 bits are a byte offset, and the
 * high bits (masked to 24 bits) index the device-wide descriptor heap —
 * NOTE(review): the 0xffffff mask presumably matches how addresses are
 * encoded by the runtime; verify against the address-assignment code. */
nir_def *offset = nir_unpack_64_2x32_split_x(b, pointer);
nir_def *index = nir_iand_imm(b, nir_unpack_64_2x32_split_y(b, pointer), 0xffffff);
/* Load the (index, offset) pair as a storage-buffer descriptor and cast it
 * to an SSBO deref with the same type/stride/alignment as the old head. */
nir_def *descriptor = nir_load_vulkan_descriptor(b, 2, 32, nir_vec2(b, index, offset),
.desc_type = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER);
nir_deref_instr *head = nir_build_deref_cast_with_alignment(b, descriptor, nir_var_mem_ssbo, old_head->type,
old_head->cast.ptr_stride,
old_head->cast.align_mul,
old_head->cast.align_offset);
/* Mirror the rest of the original chain (struct/array derefs) onto the new
 * SSBO head, emitting each clone next to its global counterpart. */
for (int i = 1; path.path[i]; ++i) {
nir_deref_instr *old = path.path[i];
b->cursor = nir_after_instr(&old->instr);
head = nir_build_deref_follower(b, head, old);
}
/* Point the memory-access intrinsic at the new SSBO chain; the old global
 * chain stays behind for any remaining non-memory users. */
nir_src_rewrite(&intr->src[0], &head->def);
nir_deref_path_finish(&path);
return true;
}
/* Pass entry point: rewrite every global-pointer load/store/atomic in the
 * shader to go through the descriptor heap instead. Returns true if any
 * intrinsic was rewritten. CFG-related metadata is preserved because the
 * callback only inserts instructions and rewrites sources. */
bool
dxil_spirv_nir_lower_buffer_device_address(nir_shader *nir)
{
const nir_metadata preserved = nir_metadata_block_index |
                               nir_metadata_dominance |
                               nir_metadata_loop_analysis;
return nir_shader_intrinsics_pass(nir, lower_buffer_device_address,
                                  preserved, NULL);
}