From 652af97f8b5757b9803cd5235312d0060b0f2a6a Mon Sep 17 00:00:00 2001 From: Samuel Pitoiset Date: Tue, 7 Apr 2026 17:55:14 +0200 Subject: [PATCH] radv/nir: lower descriptor heap in radv_nir_lower_descriptors Signed-off-by: Samuel Pitoiset Part-of: --- .../vulkan/nir/radv_nir_lower_descriptors.c | 199 ++++++++++++++---- 1 file changed, 156 insertions(+), 43 deletions(-) diff --git a/src/amd/vulkan/nir/radv_nir_lower_descriptors.c b/src/amd/vulkan/nir/radv_nir_lower_descriptors.c index bd3da26249a..a80e85c6dae 100644 --- a/src/amd/vulkan/nir/radv_nir_lower_descriptors.c +++ b/src/amd/vulkan/nir/radv_nir_lower_descriptors.c @@ -20,6 +20,7 @@ typedef struct { enum amd_gfx_level gfx_level; uint32_t address32_hi; + uint32_t sampled_image_desc_size; uint32_t combined_image_sampler_desc_size; uint32_t combined_image_sampler_offset; bool disable_aniso_single_level; @@ -102,6 +103,16 @@ load_desc_ptr(nir_builder *b, lower_descriptors_state *state, unsigned set) return get_scalar_arg(b, 1, state->args->descriptors[set]); } +static nir_def * +load_heap_ptr(nir_builder *b, lower_descriptors_state *state, unsigned heap_idx) +{ + if (mesa_shader_stage_is_rt(b->shader->info.stage)) + return nir_load_param(b, heap_idx == RADV_HEAP_RESOURCE ? RT_ARG_HEAP_RESOURCE : RT_ARG_HEAP_SAMPLER); + + assert(state->args->descriptors[heap_idx].used); + return get_scalar_arg(b, 1, state->args->descriptors[heap_idx]); +} + static void visit_vulkan_resource_index(nir_builder *b, lower_descriptors_state *state, nir_intrinsic_instr *intrin) { @@ -229,20 +240,57 @@ visit_ssbo_descriptor_amd(nir_builder *b, lower_descriptors_state *state, nir_in } static nir_def * -get_sampler_desc(nir_builder *b, lower_descriptors_state *state, nir_deref_instr *deref, +get_sampler_desc(nir_builder *b, lower_descriptors_state *state, nir_deref_instr *deref, nir_def *index, enum ac_descriptor_type desc_type, bool non_uniform, nir_tex_instr *tex, bool write) { - nir_variable *var = nir_deref_instr_get_variable(deref); - assert(var); - unsigned desc_set = var->data.descriptor_set; - unsigned binding_index = var->data.binding; - bool indirect = nir_deref_instr_has_indirect(deref); + nir_def *desc_ptr = NULL; + uint32_t offset = 0; + bool indirect = false; + uint32_t plane_offset; - struct radv_descriptor_set_layout *layout = state->layout->set[desc_set].layout; - struct radv_descriptor_set_binding_layout *binding = &layout->binding[binding_index]; + if (deref) { + nir_variable *var = nir_deref_instr_get_variable(deref); + assert(var && !index); + unsigned desc_set = var->data.descriptor_set; + unsigned binding_index = var->data.binding; + indirect = nir_deref_instr_has_indirect(deref); + + struct radv_descriptor_set_layout *layout = state->layout->set[desc_set].layout; + struct radv_descriptor_set_binding_layout *binding = &layout->binding[binding_index]; + + if (desc_type == AC_DESC_SAMPLER) { + /* Immutable/embedded samplers are lowered earlier. */ + assert(!binding->immutable_samplers_offset || indirect); + } + + while (deref->deref_type != nir_deref_type_var) { + assert(deref->deref_type == nir_deref_type_array); + unsigned array_size = MAX2(glsl_get_aoa_size(deref->type), 1); + array_size *= binding->size; + + nir_def *tmp = nir_imul_imm_nuw(b, deref->arr.index.ssa, array_size); + + if (index) { + index = nir_iadd_nuw(b, tmp, index); + } else { + index = tmp; + } + + deref = nir_deref_instr_parent(deref); + } + + offset = binding->offset; + if (desc_type == AC_DESC_SAMPLER && binding->type == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER) + offset += state->combined_image_sampler_offset; + + desc_ptr = load_desc_ptr(b, state, desc_set); + plane_offset = state->combined_image_sampler_desc_size; + } else { + desc_ptr = load_heap_ptr(b, state, desc_type == AC_DESC_SAMPLER ? RADV_HEAP_SAMPLER : RADV_HEAP_RESOURCE); + plane_offset = state->sampled_image_desc_size; + } unsigned size = 8; - unsigned offset = binding->offset; switch (desc_type) { case AC_DESC_IMAGE: case AC_DESC_PLANE_0: @@ -251,47 +299,25 @@ get_sampler_desc(nir_builder *b, lower_descriptors_state *state, nir_deref_instr offset += 32; break; case AC_DESC_PLANE_1: - offset += state->combined_image_sampler_desc_size; + offset += plane_offset; break; case AC_DESC_SAMPLER: - /* Immutable/embedded samplers are lowered earlier. */ - assert(!binding->immutable_samplers_offset || indirect); - size = RADV_SAMPLER_DESC_SIZE / 4; - if (binding->type == VK_DESCRIPTOR_TYPE_COMBINED_IMAGE_SAMPLER) - offset += state->combined_image_sampler_offset; break; case AC_DESC_BUFFER: size = RADV_BUFFER_DESC_SIZE / 4; break; case AC_DESC_PLANE_2: - offset += 2 * state->combined_image_sampler_desc_size; + offset += 2 * plane_offset; break; } - nir_def *index = NULL; - while (deref->deref_type != nir_deref_type_var) { - assert(deref->deref_type == nir_deref_type_array); - unsigned array_size = MAX2(glsl_get_aoa_size(deref->type), 1); - array_size *= binding->size; - - nir_def *tmp = nir_imul_imm_nuw(b, deref->arr.index.ssa, array_size); - - if (index) { - index = nir_iadd_nuw(b, tmp, index); - } else { - index = tmp; - } - - deref = nir_deref_instr_parent(deref); - } - nir_def *index_offset = index ? nir_iadd_imm_nuw(b, index, offset) : nir_imm_int(b, offset); if (non_uniform) - return nir_iadd(b, load_desc_ptr(b, state, desc_set), index_offset); + return nir_iadd(b, desc_ptr, index_offset); - nir_def *addr = convert_pointer_to_64_bit(b, state, load_desc_ptr(b, state, desc_set)); + nir_def *addr = convert_pointer_to_64_bit(b, state, desc_ptr); nir_def *desc = ac_nir_load_smem(b, size, addr, index_offset, size * 4u, 0); if (desc_type == AC_DESC_IMAGE && state->has_image_load_dcc_bug && !tex && !write) { @@ -329,7 +355,7 @@ update_image_intrinsic(nir_builder *b, lower_descriptors_state *state, nir_intri bool is_load = intrin->intrinsic == nir_intrinsic_image_deref_load || intrin->intrinsic == nir_intrinsic_image_deref_sparse_load; - nir_def *desc = get_sampler_desc(b, state, deref, dim == GLSL_SAMPLER_DIM_BUF ? AC_DESC_BUFFER : AC_DESC_IMAGE, + nir_def *desc = get_sampler_desc(b, state, deref, NULL, dim == GLSL_SAMPLER_DIM_BUF ? AC_DESC_BUFFER : AC_DESC_IMAGE, nir_intrinsic_access(intrin) & ACCESS_NON_UNIFORM, NULL, !is_load); if (intrin->intrinsic == nir_intrinsic_image_deref_descriptor_amd) { @@ -339,6 +365,56 @@ update_image_intrinsic(nir_builder *b, lower_descriptors_state *state, nir_intri } } +static void +lower_load_heap_descriptor(nir_builder *b, lower_descriptors_state *state, nir_intrinsic_instr *intrin) +{ + nir_resource_type resource_type = nir_intrinsic_resource_type(intrin); + assert(resource_type != nir_resource_type_sampler); + + nir_def *heap_ptr = load_heap_ptr(b, state, RADV_HEAP_RESOURCE); + nir_def *desc; + + if (resource_type == nir_resource_type_acceleration_structure) { + nir_def *addr = convert_pointer_to_64_bit(b, state, nir_iadd(b, heap_ptr, intrin->src[0].ssa)); + desc = nir_build_load_global(b, 1, 64, addr, .access = ACCESS_NON_WRITEABLE); + } else { + desc = nir_vec3(b, heap_ptr, intrin->src[0].ssa, nir_imm_int(b, 0)); + } + + nir_def_replace(&intrin->def, desc); +} + +static void +lower_load_resource_heap_data(nir_builder *b, lower_descriptors_state *state, nir_intrinsic_instr *intrin) +{ + nir_def *heap_ptr = load_heap_ptr(b, state, RADV_HEAP_RESOURCE); + nir_def *addr = convert_pointer_to_64_bit(b, state, nir_iadd(b, heap_ptr, intrin->src[0].ssa)); + + nir_def *data = nir_build_load_global(b, intrin->def.num_components, intrin->def.bit_size, addr, + .access = ACCESS_NON_WRITEABLE, .align_mul = nir_intrinsic_align_mul(intrin), + .align_offset = nir_intrinsic_align_offset(intrin)); + + nir_def_replace(&intrin->def, data); +} + +static void +lower_image_heap_intrinsic(nir_builder *b, lower_descriptors_state *state, nir_intrinsic_instr *intrin) +{ + const bool is_load = + intrin->intrinsic == nir_intrinsic_image_heap_load || intrin->intrinsic == nir_intrinsic_image_heap_sparse_load; + enum glsl_sampler_dim dim = nir_intrinsic_image_dim(intrin); + + nir_def *desc = + get_sampler_desc(b, state, NULL, intrin->src[0].ssa, dim == GLSL_SAMPLER_DIM_BUF ? AC_DESC_BUFFER : AC_DESC_IMAGE, + nir_intrinsic_access(intrin) & ACCESS_NON_UNIFORM, NULL, !is_load); + + if (intrin->intrinsic == nir_intrinsic_image_heap_descriptor_amd) { + nir_def_replace(&intrin->def, desc); + } else { + nir_rewrite_image_intrinsic(intrin, desc, nir_image_intrinsic_type_bindless); + } +} + static bool can_increase_load_size(nir_intrinsic_instr *intrin, unsigned offset, unsigned old, unsigned new) { @@ -459,6 +535,22 @@ lower_descriptors_intrin(nir_builder *b, lower_descriptors_state *state, nir_int nir_def_replace(&intrin->def, load_push_constant(b, state, intrin)); break; } + case nir_intrinsic_load_heap_descriptor: + lower_load_heap_descriptor(b, state, intrin); + break; + case nir_intrinsic_load_resource_heap_data: + lower_load_resource_heap_data(b, state, intrin); + break; + case nir_intrinsic_image_heap_load: + case nir_intrinsic_image_heap_sparse_load: + case nir_intrinsic_image_heap_store: + case nir_intrinsic_image_heap_atomic: + case nir_intrinsic_image_heap_atomic_swap: + case nir_intrinsic_image_heap_size: + case nir_intrinsic_image_heap_samples: + case nir_intrinsic_image_heap_descriptor_amd: + lower_image_heap_intrinsic(b, state, intrin); + break; default: return false; } @@ -473,6 +565,8 @@ lower_descriptors_tex(nir_builder *b, lower_descriptors_state *state, nir_tex_in nir_deref_instr *texture_deref_instr = NULL; nir_deref_instr *sampler_deref_instr = NULL; + nir_def *texture_heap_offset = NULL; + nir_def *sampler_heap_offset = NULL; int plane = -1; nir_def *image = NULL; @@ -492,6 +586,12 @@ lower_descriptors_tex(nir_builder *b, lower_descriptors_state *state, nir_tex_in case nir_tex_src_sampler_handle: sampler = tex->src[i].src.ssa; break; + case nir_tex_src_texture_heap_offset: + texture_heap_offset = tex->src[i].src.ssa; + break; + case nir_tex_src_sampler_heap_offset: + sampler_heap_offset = tex->src[i].src.ssa; + break; default: break; } @@ -500,19 +600,23 @@ lower_descriptors_tex(nir_builder *b, lower_descriptors_state *state, nir_tex_in if (plane >= 0) { assert(tex->op != nir_texop_txf_ms && tex->op != nir_texop_samples_identical); assert(tex->sampler_dim != GLSL_SAMPLER_DIM_BUF); - image = - get_sampler_desc(b, state, texture_deref_instr, AC_DESC_PLANE_0 + plane, tex->texture_non_uniform, tex, false); + image = get_sampler_desc(b, state, texture_deref_instr, texture_heap_offset, AC_DESC_PLANE_0 + plane, + tex->texture_non_uniform, tex, false); } else if (tex->sampler_dim == GLSL_SAMPLER_DIM_BUF) { - image = get_sampler_desc(b, state, texture_deref_instr, AC_DESC_BUFFER, tex->texture_non_uniform, tex, false); + image = get_sampler_desc(b, state, texture_deref_instr, texture_heap_offset, AC_DESC_BUFFER, + tex->texture_non_uniform, tex, false); } else if (tex->op == nir_texop_fragment_mask_fetch_amd) { - image = get_sampler_desc(b, state, texture_deref_instr, AC_DESC_FMASK, tex->texture_non_uniform, tex, false); + image = get_sampler_desc(b, state, texture_deref_instr, texture_heap_offset, AC_DESC_FMASK, + tex->texture_non_uniform, tex, false); } else { - image = get_sampler_desc(b, state, texture_deref_instr, AC_DESC_IMAGE, tex->texture_non_uniform, tex, false); + image = get_sampler_desc(b, state, texture_deref_instr, texture_heap_offset, AC_DESC_IMAGE, + tex->texture_non_uniform, tex, false); } - if (sampler_deref_instr) { + if (sampler_deref_instr || sampler_heap_offset) { assert(!sampler); - sampler = get_sampler_desc(b, state, sampler_deref_instr, AC_DESC_SAMPLER, tex->sampler_non_uniform, tex, false); + sampler = get_sampler_desc(b, state, sampler_deref_instr, sampler_heap_offset, AC_DESC_SAMPLER, + tex->sampler_non_uniform, tex, false); } if (sampler && state->disable_aniso_single_level && tex->sampler_dim < GLSL_SAMPLER_DIM_RECT) { @@ -555,6 +659,14 @@ lower_descriptors_tex(nir_builder *b, lower_descriptors_state *state, nir_tex_in case nir_tex_src_sampler_handle: nir_src_rewrite(&tex->src[i].src, sampler); break; + case nir_tex_src_texture_heap_offset: + tex->src[i].src_type = nir_tex_src_texture_handle; + nir_src_rewrite(&tex->src[i].src, image); + break; + case nir_tex_src_sampler_heap_offset: + tex->src[i].src_type = nir_tex_src_sampler_handle; + nir_src_rewrite(&tex->src[i].src, sampler); + break; default: break; } @@ -572,6 +684,7 @@ radv_nir_lower_descriptors(nir_shader *shader, struct radv_device *device, const lower_descriptors_state state = { .gfx_level = pdev->info.gfx_level, .address32_hi = pdev->info.address32_hi, + .sampled_image_desc_size = radv_get_sampled_image_desc_size(pdev), .combined_image_sampler_desc_size = radv_get_combined_image_sampler_desc_size(pdev), .combined_image_sampler_offset = radv_get_combined_image_sampler_offset(pdev), .disable_aniso_single_level = pdev->cache_key.disable_aniso_single_level,