From 7b526de18ffcfd61fa63aecc9980a2d4b157bcf5 Mon Sep 17 00:00:00 2001 From: Kevin Chuang Date: Tue, 25 Mar 2025 18:55:15 -0700 Subject: [PATCH] intel/compiler/rt: Calculate barycentrics on demand This commit moves the calculation of tri_bary out of brw_nir_rt_load_mem_hit_from_addr(), and only do the calculation on demand, since unorm_float_convert can be expensive. We do this for both Xe1/2 and Xe3+ for consistency. Signed-off-by: Kevin Chuang Reviewed-by: Sagar Ghuge Reviewed-by: Lionel Landwerlin Part-of: --- .../compiler/brw_nir_lower_ray_queries.c | 3 +- src/intel/compiler/brw_nir_rt.c | 51 +++++++++++-------- src/intel/compiler/brw_nir_rt.h | 12 +++-- src/intel/compiler/brw_nir_rt_builder.h | 37 ++++++++++---- src/intel/vulkan/anv_pipeline.c | 8 +-- 5 files changed, 73 insertions(+), 38 deletions(-) diff --git a/src/intel/compiler/brw_nir_lower_ray_queries.c b/src/intel/compiler/brw_nir_lower_ray_queries.c index e3362677c6a..d11e98fb55e 100644 --- a/src/intel/compiler/brw_nir_lower_ray_queries.c +++ b/src/intel/compiler/brw_nir_lower_ray_queries.c @@ -432,7 +432,8 @@ lower_ray_query_intrinsic(nir_builder *b, break; case nir_ray_query_value_intersection_barycentrics: - sysval = hit_in.tri_bary; + sysval = brw_nir_rt_load_tri_bary_from_addr(b, stack_addr, committed, + state->devinfo); break; case nir_ray_query_value_intersection_front_face: diff --git a/src/intel/compiler/brw_nir_rt.c b/src/intel/compiler/brw_nir_rt.c index 853a6ece8fc..7c869d19eae 100644 --- a/src/intel/compiler/brw_nir_rt.c +++ b/src/intel/compiler/brw_nir_rt.c @@ -55,7 +55,7 @@ resize_deref(nir_builder *b, nir_deref_instr *deref, } static bool -lower_rt_io_derefs(nir_shader *shader) +lower_rt_io_derefs(nir_shader *shader, const struct intel_device_info *devinfo) { nir_function_impl *impl = nir_shader_get_entrypoint(shader); @@ -94,13 +94,24 @@ lower_rt_io_derefs(nir_shader *shader) assert(stage == MESA_SHADER_ANY_HIT || stage == MESA_SHADER_CLOSEST_HIT || stage == MESA_SHADER_INTERSECTION); - nir_def *hit_addr = - brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT); - /* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */ - nir_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4); - hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b), - brw_nir_rt_hit_attrib_data_addr(&b), - bary_addr); + hit_attrib_addr = brw_nir_rt_hit_attrib_data_addr(&b); + + /* For tri, we store tri_bary at hit_attrib_data_addr. + * The reason we don't directly provide the address where u and v is + * located is that for Xe3+ u and v needs extra unorm_to_float + * calculation, so we write the computed value to hit_attrib_data_addr + * for shader to dereference. + */ + nir_push_if(&b, nir_inot(&b, nir_load_leaf_procedural_intel(&b))); + { + nir_def* tri_bary = + brw_nir_rt_load_tri_bary_from_addr(&b, + brw_nir_rt_stack_addr(&b), + stage == MESA_SHADER_CLOSEST_HIT, + devinfo); + nir_store_global(&b, hit_attrib_addr, 4, tri_bary, 0x3); + } + nir_pop_if(&b, NULL); progress = true; } @@ -183,7 +194,7 @@ lower_rt_io_derefs(nir_shader *shader) * variable down the call stack. */ static void -lower_rt_io_and_scratch(nir_shader *nir) +lower_rt_io_and_scratch(nir_shader *nir, const struct intel_device_info *devinfo) { /* First, we to ensure all the I/O variables have explicit types. Because * these are shader-internal and don't come in from outside, they don't @@ -196,7 +207,7 @@ lower_rt_io_and_scratch(nir_shader *nir) glsl_get_natural_size_align_bytes); /* Now patch any derefs to I/O vars */ - NIR_PASS_V(nir, lower_rt_io_derefs); + NIR_PASS_V(nir, lower_rt_io_derefs, devinfo); /* Finally, lower any remaining function_temp, mem_constant, or * ray_hit_attrib access to 64-bit global memory access. @@ -329,11 +340,11 @@ lower_ray_walk_intrinsics(nir_shader *shader, } void -brw_nir_lower_raygen(nir_shader *nir) +brw_nir_lower_raygen(nir_shader *nir, const struct intel_device_info *devinfo) { assert(nir->info.stage == MESA_SHADER_RAYGEN); NIR_PASS_V(nir, brw_nir_lower_shader_returns); - lower_rt_io_and_scratch(nir); + lower_rt_io_and_scratch(nir, devinfo); } void @@ -342,31 +353,31 @@ brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo) assert(nir->info.stage == MESA_SHADER_ANY_HIT); NIR_PASS_V(nir, brw_nir_lower_shader_returns); NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo); - lower_rt_io_and_scratch(nir); + lower_rt_io_and_scratch(nir, devinfo); } void -brw_nir_lower_closest_hit(nir_shader *nir) +brw_nir_lower_closest_hit(nir_shader *nir, const struct intel_device_info *devinfo) { assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT); NIR_PASS_V(nir, brw_nir_lower_shader_returns); - lower_rt_io_and_scratch(nir); + lower_rt_io_and_scratch(nir, devinfo); } void -brw_nir_lower_miss(nir_shader *nir) +brw_nir_lower_miss(nir_shader *nir, const struct intel_device_info *devinfo) { assert(nir->info.stage == MESA_SHADER_MISS); NIR_PASS_V(nir, brw_nir_lower_shader_returns); - lower_rt_io_and_scratch(nir); + lower_rt_io_and_scratch(nir, devinfo); } void -brw_nir_lower_callable(nir_shader *nir) +brw_nir_lower_callable(nir_shader *nir, const struct intel_device_info *devinfo) { assert(nir->info.stage == MESA_SHADER_CALLABLE); NIR_PASS_V(nir, brw_nir_lower_shader_returns); - lower_rt_io_and_scratch(nir); + lower_rt_io_and_scratch(nir, devinfo); } void @@ -380,7 +391,7 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection, NIR_PASS_V(intersection, brw_nir_lower_intersection_shader, any_hit, devinfo); NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo); - lower_rt_io_and_scratch(intersection); + lower_rt_io_and_scratch(intersection, devinfo); } static nir_def * diff --git a/src/intel/compiler/brw_nir_rt.h b/src/intel/compiler/brw_nir_rt.h index 0893016d51c..c577a25617b 100644 --- a/src/intel/compiler/brw_nir_rt.h +++ b/src/intel/compiler/brw_nir_rt.h @@ -30,12 +30,16 @@ extern "C" { #endif -void brw_nir_lower_raygen(nir_shader *nir); +void brw_nir_lower_raygen(nir_shader *nir, + const struct intel_device_info *devinfo); void brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo); -void brw_nir_lower_closest_hit(nir_shader *nir); -void brw_nir_lower_miss(nir_shader *nir); -void brw_nir_lower_callable(nir_shader *nir); +void brw_nir_lower_closest_hit(nir_shader *nir, + const struct intel_device_info *devinfo); +void brw_nir_lower_miss(nir_shader *nir, + const struct intel_device_info *devinfo); +void brw_nir_lower_callable(nir_shader *nir, + const struct intel_device_info *devinfo); void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection, const nir_shader *any_hit, const struct intel_device_info *devinfo); diff --git a/src/intel/compiler/brw_nir_rt_builder.h b/src/intel/compiler/brw_nir_rt_builder.h index 9bae756ef6f..987997bafcb 100644 --- a/src/intel/compiler/brw_nir_rt_builder.h +++ b/src/intel/compiler/brw_nir_rt_builder.h @@ -465,7 +465,6 @@ brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_def *vec2, */ struct brw_nir_rt_mem_hit_defs { nir_def *t; - nir_def *tri_bary; /**< Only valid for triangle geometry */ nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */ nir_def *valid; nir_def *leaf_type; @@ -478,6 +477,34 @@ struct brw_nir_rt_mem_hit_defs { nir_def *inst_leaf_ptr; }; +/* For Xe3+, barycentric coordinates are stored as 24 bit unorm. + * Since unorm_float could be expensive, we calculate tri_bary on + * demand. We do this for Xe3+ and Xe1/2 for consistency. + */ +static inline nir_def * +brw_nir_rt_load_tri_bary_from_addr(nir_builder *b, + nir_def *stack_addr, + bool committed, + const struct intel_device_info *devinfo) +{ + nir_def *hit_addr = + brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed); + + nir_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32); + nir_def *tri_bary; + if (devinfo->ver >= 30) { + nir_def *u = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff); + nir_def *v = nir_iand_imm(b, nir_channel(b, data, 2), 0xffffff); + const unsigned bits[1] = {24}; + tri_bary = nir_vec2(b, + nir_format_unorm_to_float_precise(b, u, bits), + nir_format_unorm_to_float_precise(b, v, bits)); + } else { + tri_bary = nir_channels(b, data, 0x6); + } + + return tri_bary; +} static inline void brw_nir_rt_load_mem_hit_from_addr(nir_builder *b, struct brw_nir_rt_mem_hit_defs *defs, @@ -495,19 +522,11 @@ brw_nir_rt_load_mem_hit_from_addr(nir_builder *b, if (devinfo->ver >= 30) { defs->aabb_hit_kind = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff); - nir_def *u = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff); - nir_def *v = nir_iand_imm(b, nir_channel(b, data, 2), 0xffffff); - /* For Xe3+, barycentric coordinates are stored as 24 bit unorm */ - const unsigned bits[1] = {24}; - defs->tri_bary = nir_vec2(b, - nir_format_unorm_to_float_precise(b, u, bits), - nir_format_unorm_to_float_precise(b, v, bits)); defs->prim_index_delta = nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 0), nir_imm_int(b, 5)); } else { defs->aabb_hit_kind = nir_channel(b, data, 1); - defs->tri_bary = nir_channels(b, data, 0x6); defs->prim_index_delta = nir_ubitfield_extract(b, bitfield, nir_imm_int(b, 0), nir_imm_int(b, 16)); diff --git a/src/intel/vulkan/anv_pipeline.c b/src/intel/vulkan/anv_pipeline.c index 74daba16286..d65afaf99c0 100644 --- a/src/intel/vulkan/anv_pipeline.c +++ b/src/intel/vulkan/anv_pipeline.c @@ -3728,7 +3728,7 @@ anv_pipeline_compile_ray_tracing(struct anv_ray_tracing_pipeline *pipeline, nir_shader *nir = nir_shader_clone(tmp_stage_ctx, stages[i].nir); switch (stages[i].stage) { case MESA_SHADER_RAYGEN: - brw_nir_lower_raygen(nir); + brw_nir_lower_raygen(nir, devinfo); break; case MESA_SHADER_ANY_HIT: @@ -3736,18 +3736,18 @@ anv_pipeline_compile_ray_tracing(struct anv_ray_tracing_pipeline *pipeline, break; case MESA_SHADER_CLOSEST_HIT: - brw_nir_lower_closest_hit(nir); + brw_nir_lower_closest_hit(nir, devinfo); break; case MESA_SHADER_MISS: - brw_nir_lower_miss(nir); + brw_nir_lower_miss(nir, devinfo); break; case MESA_SHADER_INTERSECTION: unreachable("These are handled later"); case MESA_SHADER_CALLABLE: - brw_nir_lower_callable(nir); + brw_nir_lower_callable(nir, devinfo); break; default: