mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-04 20:38:06 +02:00
radv/rt: Lower ray payloads like hit attribs
Reviewed-by: Friedrich Vock <friedrich.vock@gmx.de> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27051>
This commit is contained in:
parent
4f0c33196c
commit
c925b6019d
6 changed files with 121 additions and 37 deletions
|
|
@ -51,6 +51,8 @@ void radv_nir_lower_abi(nir_shader *shader, enum amd_gfx_level gfx_level, const
|
|||
|
||||
bool radv_nir_lower_hit_attrib_derefs(nir_shader *shader);
|
||||
|
||||
bool radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset);
|
||||
|
||||
bool radv_nir_lower_ray_queries(nir_shader *shader, struct radv_device *device);
|
||||
|
||||
bool radv_nir_lower_vs_inputs(nir_shader *shader, const struct radv_shader_stage *vs_stage,
|
||||
|
|
|
|||
|
|
@ -6,8 +6,14 @@
|
|||
|
||||
#include "nir.h"
|
||||
#include "nir_builder.h"
|
||||
#include "radv_constants.h"
|
||||
#include "radv_nir.h"
|
||||
|
||||
/* Arguments handed to lower_hit_attrib_deref through its `data` pointer. */
struct lower_hit_attrib_deref_args {
|
||||
/* Variable mode to lower: load/store derefs whose mode does not match are left untouched. */
nir_variable_mode mode;
|
||||
/* Byte offset added to each variable's driver_location when the access offset is computed. */
uint32_t base_offset;
|
||||
};
|
||||
|
||||
static bool
|
||||
lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
|
||||
{
|
||||
|
|
@ -18,8 +24,9 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
|
|||
if (intrin->intrinsic != nir_intrinsic_load_deref && intrin->intrinsic != nir_intrinsic_store_deref)
|
||||
return false;
|
||||
|
||||
struct lower_hit_attrib_deref_args *args = data;
|
||||
nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
|
||||
if (!nir_deref_mode_is(deref, nir_var_ray_hit_attrib))
|
||||
if (!nir_deref_mode_is(deref, args->mode))
|
||||
return false;
|
||||
|
||||
assert(deref->deref_type == nir_deref_type_var);
|
||||
|
|
@ -33,7 +40,7 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
|
|||
nir_def *components[NIR_MAX_VEC_COMPONENTS];
|
||||
|
||||
for (uint32_t comp = 0; comp < num_components; comp++) {
|
||||
uint32_t offset = deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
|
||||
uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
|
||||
uint32_t base = offset / 4;
|
||||
uint32_t comp_offset = offset % 4;
|
||||
|
||||
|
|
@ -61,7 +68,7 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
|
|||
uint32_t bit_size = value->bit_size;
|
||||
|
||||
for (uint32_t comp = 0; comp < num_components; comp++) {
|
||||
uint32_t offset = deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
|
||||
uint32_t offset = args->base_offset + deref->var->data.driver_location + comp * DIV_ROUND_UP(bit_size, 8);
|
||||
uint32_t base = offset / 4;
|
||||
uint32_t comp_offset = offset % 4;
|
||||
|
||||
|
|
@ -95,24 +102,43 @@ lower_hit_attrib_deref(nir_builder *b, nir_instr *instr, void *data)
|
|||
return true;
|
||||
}
|
||||
|
||||
bool
|
||||
radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
|
||||
static bool
|
||||
radv_nir_lower_rt_vars(nir_shader *shader, nir_variable_mode mode, uint32_t base_offset)
|
||||
{
|
||||
bool progress = false;
|
||||
|
||||
progress |= nir_split_struct_vars(shader, nir_var_ray_hit_attrib);
|
||||
progress |= nir_lower_indirect_derefs(shader, nir_var_ray_hit_attrib, UINT32_MAX);
|
||||
progress |= nir_split_array_vars(shader, nir_var_ray_hit_attrib);
|
||||
progress |= nir_split_struct_vars(shader, mode);
|
||||
progress |= nir_lower_indirect_derefs(shader, mode, UINT32_MAX);
|
||||
progress |= nir_split_array_vars(shader, mode);
|
||||
|
||||
progress |= nir_lower_vars_to_explicit_types(shader, nir_var_ray_hit_attrib, glsl_get_natural_size_align_bytes);
|
||||
progress |= nir_lower_vars_to_explicit_types(shader, mode, glsl_get_natural_size_align_bytes);
|
||||
|
||||
struct lower_hit_attrib_deref_args args = {
|
||||
.mode = mode,
|
||||
.base_offset = base_offset,
|
||||
};
|
||||
|
||||
progress |= nir_shader_instructions_pass(shader, lower_hit_attrib_deref,
|
||||
nir_metadata_block_index | nir_metadata_dominance, NULL);
|
||||
nir_metadata_block_index | nir_metadata_dominance, &args);
|
||||
|
||||
if (progress) {
|
||||
nir_remove_dead_derefs(shader);
|
||||
nir_remove_dead_variables(shader, nir_var_ray_hit_attrib, NULL);
|
||||
nir_remove_dead_variables(shader, mode, NULL);
|
||||
}
|
||||
|
||||
return progress;
|
||||
}
|
||||
|
||||
/* Lowers derefs of nir_var_ray_hit_attrib variables by forwarding to radv_nir_lower_rt_vars
 * with a base offset of 0. Returns the progress flag reported by that helper.
 */
bool
|
||||
radv_nir_lower_hit_attrib_derefs(nir_shader *shader)
|
||||
{
|
||||
return radv_nir_lower_rt_vars(shader, nir_var_ray_hit_attrib, 0);
|
||||
}
|
||||
|
||||
/* Lowers ray payload derefs by running radv_nir_lower_rt_vars twice: once for
 * nir_var_function_temp variables and once for nir_var_shader_call_data variables.
 *
 * Both runs use a base offset of RADV_MAX_HIT_ATTRIB_SIZE + offset.
 * NOTE(review): presumably this places payload storage after the hit attribute
 * storage -- confirm against the consumers of the lowered intrinsics.
 *
 * Returns true if either run reported progress.
 */
bool
|
||||
radv_nir_lower_ray_payload_derefs(nir_shader *shader, uint32_t offset)
|
||||
{
|
||||
bool progress = radv_nir_lower_rt_vars(shader, nir_var_function_temp, RADV_MAX_HIT_ATTRIB_SIZE + offset);
|
||||
progress |= radv_nir_lower_rt_vars(shader, nir_var_shader_call_data, RADV_MAX_HIT_ATTRIB_SIZE + offset);
|
||||
return progress;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -187,6 +187,7 @@ lower_rt_derefs(nir_shader *shader)
|
|||
struct rt_variables {
|
||||
struct radv_device *device;
|
||||
const VkPipelineCreateFlags2KHR flags;
|
||||
bool monolithic;
|
||||
|
||||
/* idx of the next shader to run in the next iteration of the main loop.
|
||||
* During traversal, idx is used to store the SBT index and will contain
|
||||
|
|
@ -198,6 +199,7 @@ struct rt_variables {
|
|||
|
||||
/* scratch offset of the argument area relative to stack_ptr */
|
||||
nir_variable *arg;
|
||||
uint32_t payload_offset;
|
||||
|
||||
nir_variable *stack_ptr;
|
||||
|
||||
|
|
@ -230,11 +232,13 @@ struct rt_variables {
|
|||
};
|
||||
|
||||
static struct rt_variables
|
||||
create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags)
|
||||
create_rt_variables(nir_shader *shader, struct radv_device *device, const VkPipelineCreateFlags2KHR flags,
|
||||
bool monolithic)
|
||||
{
|
||||
struct rt_variables vars = {
|
||||
.device = device,
|
||||
.flags = flags,
|
||||
.monolithic = monolithic,
|
||||
};
|
||||
vars.idx = nir_variable_create(shader, nir_var_shader_temp, glsl_uint_type(), "idx");
|
||||
vars.shader_addr = nir_variable_create(shader, nir_var_shader_temp, glsl_uint64_t_type(), "shader_addr");
|
||||
|
|
@ -790,7 +794,7 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
|
|||
|
||||
nir_opt_dead_cf(shader);
|
||||
|
||||
struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags);
|
||||
struct rt_variables src_vars = create_rt_variables(shader, vars->device, vars->flags, vars->monolithic);
|
||||
map_rt_variables(var_remap, &src_vars, vars);
|
||||
|
||||
NIR_PASS_V(shader, lower_rt_instructions, &src_vars, false);
|
||||
|
|
@ -807,20 +811,48 @@ insert_rt_case(nir_builder *b, nir_shader *shader, struct rt_variables *vars, ni
|
|||
ralloc_free(var_remap);
|
||||
}
|
||||
|
||||
nir_shader *
|
||||
radv_parse_rt_stage(struct radv_device *device, const struct radv_shader_stage *rt_stage)
|
||||
static bool
|
||||
radv_lower_payload_arg_to_offset(nir_builder *b, nir_intrinsic_instr *instr, void *data)
|
||||
{
|
||||
nir_shader *shader = radv_shader_spirv_to_nir(device, rt_stage, NULL, false);
|
||||
if (instr->intrinsic != nir_intrinsic_trace_ray)
|
||||
return false;
|
||||
|
||||
NIR_PASS(_, shader, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
|
||||
glsl_get_natural_size_align_bytes);
|
||||
nir_deref_instr *payload = nir_src_as_deref(instr->src[10]);
|
||||
assert(payload->deref_type == nir_deref_type_var);
|
||||
|
||||
NIR_PASS(_, shader, lower_rt_derefs);
|
||||
NIR_PASS(_, shader, radv_nir_lower_hit_attrib_derefs);
|
||||
b->cursor = nir_before_instr(&instr->instr);
|
||||
nir_def *offset = nir_imm_int(b, payload->var->data.driver_location);
|
||||
|
||||
NIR_PASS(_, shader, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
|
||||
nir_src_rewrite(&instr->src[10], offset);
|
||||
|
||||
return shader;
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Lowers ray tracing variable I/O for one shader, with a different strategy per compilation mode.
 *
 * nir:            shader to lower in place.
 * monolithic:     selects the lowering path (see below).
 * payload_offset: only used on the monolithic path, where it is forwarded to
 *                 radv_nir_lower_ray_payload_derefs.
 *
 * Non-monolithic: function_temp and shader_call_data variables get explicit types,
 * lower_rt_derefs runs, and function_temp accesses are lowered to explicit I/O with
 * 32-bit offsets.
 *
 * Monolithic: raygen shaders first get their trace_ray payload arguments rewritten to
 * offsets, then every stage has its payload derefs lowered by
 * radv_nir_lower_ray_payload_derefs.
 */
void
|
||||
radv_nir_lower_rt_io(nir_shader *nir, bool monolithic, uint32_t payload_offset)
|
||||
{
|
||||
if (!monolithic) {
|
||||
NIR_PASS(_, nir, nir_lower_vars_to_explicit_types, nir_var_function_temp | nir_var_shader_call_data,
|
||||
glsl_get_natural_size_align_bytes);
|
||||
|
||||
NIR_PASS(_, nir, lower_rt_derefs);
|
||||
|
||||
NIR_PASS(_, nir, nir_lower_explicit_io, nir_var_function_temp, nir_address_format_32bit_offset);
|
||||
} else {
|
||||
if (nir->info.stage == MESA_SHADER_RAYGEN) {
|
||||
/* Use nir_lower_vars_to_explicit_types to assign the payload locations. We call
|
||||
* nir_lower_vars_to_explicit_types later after splitting the payloads.
|
||||
*/
|
||||
/* Save scratch_size so it can be restored after the location-assignment call below. */
uint32_t scratch_size = nir->scratch_size;
|
||||
nir_lower_vars_to_explicit_types(nir, nir_var_function_temp, glsl_get_natural_size_align_bytes);
|
||||
nir->scratch_size = scratch_size;
|
||||
|
||||
/* Replace the payload deref source of each trace_ray with an immediate offset. */
nir_shader_intrinsics_pass(nir, radv_lower_payload_arg_to_offset,
|
||||
nir_metadata_block_index | nir_metadata_dominance, NULL);
|
||||
}
|
||||
|
||||
NIR_PASS(_, nir, radv_nir_lower_ray_payload_derefs, payload_offset);
|
||||
}
|
||||
}
|
||||
|
||||
static nir_function_impl *
|
||||
|
|
@ -1118,6 +1150,8 @@ radv_build_ahit_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g
|
|||
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
|
||||
assert(nir_stage);
|
||||
|
||||
radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
|
||||
|
||||
insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.any_hit_index);
|
||||
ralloc_free(nir_stage);
|
||||
}
|
||||
|
|
@ -1140,12 +1174,16 @@ radv_build_isec_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_tracing_g
|
|||
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->intersection_shader].nir);
|
||||
assert(nir_stage);
|
||||
|
||||
radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
|
||||
|
||||
nir_shader *any_hit_stage = NULL;
|
||||
if (group->any_hit_shader != VK_SHADER_UNUSED_KHR) {
|
||||
any_hit_stage =
|
||||
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->any_hit_shader].nir);
|
||||
assert(any_hit_stage);
|
||||
|
||||
radv_nir_lower_rt_io(any_hit_stage, data->vars->monolithic, data->vars->payload_offset);
|
||||
|
||||
/* reserve stack size for any_hit before it is inlined */
|
||||
data->pipeline->stages[group->any_hit_shader].stack_size = any_hit_stage->scratch_size;
|
||||
|
||||
|
|
@ -1188,6 +1226,8 @@ radv_build_recursive_case(nir_builder *b, nir_def *sbt_idx, struct radv_ray_trac
|
|||
radv_pipeline_cache_handle_to_nir(data->device, data->pipeline->stages[group->recursive_shader].nir);
|
||||
assert(nir_stage);
|
||||
|
||||
radv_nir_lower_rt_io(nir_stage, data->vars->monolithic, data->vars->payload_offset);
|
||||
|
||||
insert_rt_case(b, nir_stage, data->vars, sbt_idx, group->handle.general_index);
|
||||
ralloc_free(nir_stage);
|
||||
}
|
||||
|
|
@ -1509,7 +1549,7 @@ radv_build_traversal_shader(struct radv_device *device, struct radv_ray_tracing_
|
|||
b.shader->info.workgroup_size[0] = 8;
|
||||
b.shader->info.workgroup_size[1] = device->physical_device->rt_wave_size == 64 ? 8 : 4;
|
||||
b.shader->info.shared_size = device->physical_device->rt_wave_size * MAX_STACK_ENTRY_COUNT * sizeof(uint32_t);
|
||||
struct rt_variables vars = create_rt_variables(b.shader, device, create_flags);
|
||||
struct rt_variables vars = create_rt_variables(b.shader, device, create_flags, false);
|
||||
|
||||
/* initialize trace_ray arguments */
|
||||
nir_store_var(&b, vars.accel_struct, nir_load_accel_struct_amd(&b), 1);
|
||||
|
|
@ -1561,7 +1601,7 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
|
|||
case nir_intrinsic_execute_callable:
|
||||
unreachable("nir_intrinsic_execute_callable");
|
||||
case nir_intrinsic_trace_ray: {
|
||||
nir_store_var(b, vars->arg, nir_iadd_imm(b, intr->src[10].ssa, -b->shader->scratch_size), 1);
|
||||
vars->payload_offset = nir_src_as_uint(intr->src[10]);
|
||||
|
||||
nir_src cull_mask = intr->src[2];
|
||||
bool ignore_cull_mask = nir_src_is_const(cull_mask) && (nir_src_as_uint(cull_mask) & 0xFF) == 0xFF;
|
||||
|
|
@ -1603,6 +1643,16 @@ lower_rt_instruction_monolithic(nir_builder *b, nir_instr *instr, void *data)
|
|||
}
|
||||
}
|
||||
|
||||
/* Intrinsics-pass callback that tracks how many hit attribute slots a shader uses.
 *
 * `data` points to a uint32_t running count. For every load_hit_attrib_amd or
 * store_hit_attrib_amd intrinsic, the count is raised to at least the intrinsic's base
 * index plus one.
 *
 * Always returns false: the callback only inspects instructions and never reports progress.
 */
static bool
|
||||
radv_count_hit_attrib_slots(nir_builder *b, nir_intrinsic_instr *instr, void *data)
|
||||
{
|
||||
uint32_t *count = data;
|
||||
if (instr->intrinsic == nir_intrinsic_load_hit_attrib_amd || instr->intrinsic == nir_intrinsic_store_hit_attrib_amd)
|
||||
*count = MAX2(*count, nir_intrinsic_base(instr) + 1);
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
static void
|
||||
lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
|
||||
struct radv_ray_tracing_pipeline *pipeline,
|
||||
|
|
@ -1620,10 +1670,12 @@ lower_rt_instructions_monolithic(nir_shader *shader, struct radv_device *device,
|
|||
nir_shader_instructions_pass(shader, lower_rt_instruction_monolithic, nir_metadata_none, &state);
|
||||
nir_index_ssa_defs(impl);
|
||||
|
||||
/* Register storage for hit attributes */
|
||||
nir_variable *hit_attribs[RADV_MAX_HIT_ATTRIB_SIZE / sizeof(uint32_t)];
|
||||
uint32_t hit_attrib_count = 0;
|
||||
nir_shader_intrinsics_pass(shader, radv_count_hit_attrib_slots, nir_metadata_all, &hit_attrib_count);
|
||||
|
||||
for (uint32_t i = 0; i < ARRAY_SIZE(hit_attribs); i++)
|
||||
/* Register storage for hit attributes */
|
||||
STACK_ARRAY(nir_variable *, hit_attribs, hit_attrib_count);
|
||||
for (uint32_t i = 0; i < hit_attrib_count; i++)
|
||||
hit_attribs[i] = nir_local_variable_create(impl, glsl_uint_type(), "ahit_attrib");
|
||||
|
||||
lower_hit_attribs(shader, hit_attribs, 0);
|
||||
|
|
@ -1676,7 +1728,7 @@ radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKH
|
|||
|
||||
const VkPipelineCreateFlagBits2KHR create_flags = vk_rt_pipeline_create_flags(pCreateInfo);
|
||||
|
||||
struct rt_variables vars = create_rt_variables(shader, device, create_flags);
|
||||
struct rt_variables vars = create_rt_variables(shader, device, create_flags, monolithic);
|
||||
|
||||
if (monolithic)
|
||||
lower_rt_instructions_monolithic(shader, device, pipeline, pCreateInfo, &vars);
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
#include "nir/nir.h"
|
||||
#include "nir/nir_builder.h"
|
||||
|
||||
#include "nir/radv_nir.h"
|
||||
#include "radv_debug.h"
|
||||
#include "radv_private.h"
|
||||
#include "radv_shader.h"
|
||||
|
|
@ -318,10 +319,11 @@ radv_create_merged_rt_create_info(const VkRayTracingPipelineCreateInfoKHR *pCrea
|
|||
}
|
||||
|
||||
static bool
|
||||
should_move_rt_instruction(nir_intrinsic_op intrinsic)
|
||||
should_move_rt_instruction(nir_intrinsic_instr *instr)
|
||||
{
|
||||
switch (intrinsic) {
|
||||
switch (instr->intrinsic) {
|
||||
case nir_intrinsic_load_hit_attrib_amd:
|
||||
return nir_intrinsic_base(instr) < RADV_MAX_HIT_ATTRIB_DWORDS;
|
||||
case nir_intrinsic_load_rt_arg_scratch_offset_amd:
|
||||
case nir_intrinsic_load_ray_flags:
|
||||
case nir_intrinsic_load_ray_object_origin:
|
||||
|
|
@ -348,7 +350,7 @@ move_rt_instructions(nir_shader *shader)
|
|||
|
||||
nir_intrinsic_instr *intrinsic = nir_instr_as_intrinsic(instr);
|
||||
|
||||
if (!should_move_rt_instruction(intrinsic->intrinsic))
|
||||
if (!should_move_rt_instruction(intrinsic))
|
||||
continue;
|
||||
|
||||
nir_instr_move(target, instr);
|
||||
|
|
@ -368,6 +370,8 @@ radv_rt_nir_to_asm(struct radv_device *device, struct vk_pipeline_cache *cache,
|
|||
bool keep_executable_info = radv_pipeline_capture_shaders(device, pipeline->base.base.create_flags);
|
||||
bool keep_statistic_info = radv_pipeline_capture_shader_stats(device, pipeline->base.base.create_flags);
|
||||
|
||||
radv_nir_lower_rt_io(stage->nir, monolithic, 0);
|
||||
|
||||
/* Gather shader info. */
|
||||
nir_shader_gather_info(stage->nir, nir_shader_get_entrypoint(stage->nir));
|
||||
radv_nir_shader_info_init(stage->stage, MESA_SHADER_NONE, &stage->info);
|
||||
|
|
@ -525,7 +529,9 @@ radv_rt_compile_shaders(struct radv_device *device, struct vk_pipeline_cache *ca
|
|||
radv_pipeline_stage_init(&pCreateInfo->pStages[i], pipeline_layout, &stage_keys[s], stage);
|
||||
|
||||
/* precompile the shader */
|
||||
stage->nir = radv_parse_rt_stage(device, stage);
|
||||
stage->nir = radv_shader_spirv_to_nir(device, stage, NULL, false);
|
||||
|
||||
NIR_PASS(_, stage->nir, radv_nir_lower_hit_attrib_derefs);
|
||||
|
||||
rt_stages[i].can_inline = radv_rt_can_inline_shader(stage->nir);
|
||||
|
||||
|
|
|
|||
|
|
@ -774,7 +774,7 @@ void radv_postprocess_nir(struct radv_device *device, const struct radv_graphics
|
|||
|
||||
bool radv_shader_should_clear_lds(const struct radv_device *device, const nir_shader *shader);
|
||||
|
||||
nir_shader *radv_parse_rt_stage(struct radv_device *device, const struct radv_shader_stage *rt_stage);
|
||||
void radv_nir_lower_rt_io(nir_shader *shader, bool monolithic, uint32_t payload_offset);
|
||||
|
||||
void radv_nir_lower_rt_abi(nir_shader *shader, const VkRayTracingPipelineCreateInfoKHR *pCreateInfo,
|
||||
const struct radv_shader_args *args, const struct radv_shader_info *info,
|
||||
|
|
|
|||
|
|
@ -915,10 +915,8 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
|
|||
struct hash_table *var_info_map = _mesa_pointer_hash_table_create(mem_ctx);
|
||||
struct set *complex_vars = NULL;
|
||||
|
||||
assert((modes & (nir_var_shader_temp | nir_var_ray_hit_attrib | nir_var_function_temp)) == modes);
|
||||
|
||||
bool has_global_array = false;
|
||||
if (modes & (nir_var_shader_temp | nir_var_ray_hit_attrib)) {
|
||||
if (modes & (~nir_var_function_temp)) {
|
||||
has_global_array = init_var_list_array_infos(shader,
|
||||
&shader->variables,
|
||||
modes,
|
||||
|
|
@ -953,7 +951,7 @@ nir_split_array_vars(nir_shader *shader, nir_variable_mode modes)
|
|||
}
|
||||
|
||||
bool has_global_splits = false;
|
||||
if (modes & (nir_var_shader_temp | nir_var_ray_hit_attrib)) {
|
||||
if (modes & (~nir_var_function_temp)) {
|
||||
has_global_splits = split_var_list_arrays(shader, NULL,
|
||||
&shader->variables,
|
||||
modes,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue