intel/compiler/rt: Calculate barycentrics on demand

This commit moves the calculation of tri_bary out of
brw_nir_rt_load_mem_hit_from_addr(), and only do the calculation on
demand, since unorm_float_convert can be expensive. We do this for both
Xe1/2 and Xe3+ for consistency.

Signed-off-by: Kevin Chuang <kaiwenjon23@gmail.com>
Reviewed-by: Sagar Ghuge <sagar.ghuge@intel.com>
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33047>
This commit is contained in:
Kevin Chuang 2025-03-25 18:55:15 -07:00 committed by Marge Bot
parent afc23dffa4
commit 7b526de18f
5 changed files with 73 additions and 38 deletions

View file

@ -432,7 +432,8 @@ lower_ray_query_intrinsic(nir_builder *b,
break; break;
case nir_ray_query_value_intersection_barycentrics: case nir_ray_query_value_intersection_barycentrics:
sysval = hit_in.tri_bary; sysval = brw_nir_rt_load_tri_bary_from_addr(b, stack_addr, committed,
state->devinfo);
break; break;
case nir_ray_query_value_intersection_front_face: case nir_ray_query_value_intersection_front_face:

View file

@ -55,7 +55,7 @@ resize_deref(nir_builder *b, nir_deref_instr *deref,
} }
static bool static bool
lower_rt_io_derefs(nir_shader *shader) lower_rt_io_derefs(nir_shader *shader, const struct intel_device_info *devinfo)
{ {
nir_function_impl *impl = nir_shader_get_entrypoint(shader); nir_function_impl *impl = nir_shader_get_entrypoint(shader);
@ -94,13 +94,24 @@ lower_rt_io_derefs(nir_shader *shader)
assert(stage == MESA_SHADER_ANY_HIT || assert(stage == MESA_SHADER_ANY_HIT ||
stage == MESA_SHADER_CLOSEST_HIT || stage == MESA_SHADER_CLOSEST_HIT ||
stage == MESA_SHADER_INTERSECTION); stage == MESA_SHADER_INTERSECTION);
nir_def *hit_addr = hit_attrib_addr = brw_nir_rt_hit_attrib_data_addr(&b);
brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
/* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */ /* For tri, we store tri_bary at hit_attrib_data_addr.
nir_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4); * The reason we don't directly provide the address where u and v is
hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b), * located is that for Xe3+ u and v needs extra unorm_to_float
brw_nir_rt_hit_attrib_data_addr(&b), * calculation, so we write the computed value to hit_attrib_data_addr
bary_addr); * for shader to dereference.
*/
nir_push_if(&b, nir_inot(&b, nir_load_leaf_procedural_intel(&b)));
{
nir_def* tri_bary =
brw_nir_rt_load_tri_bary_from_addr(&b,
brw_nir_rt_stack_addr(&b),
stage == MESA_SHADER_CLOSEST_HIT,
devinfo);
nir_store_global(&b, hit_attrib_addr, 4, tri_bary, 0x3);
}
nir_pop_if(&b, NULL);
progress = true; progress = true;
} }
@ -183,7 +194,7 @@ lower_rt_io_derefs(nir_shader *shader)
* variable down the call stack. * variable down the call stack.
*/ */
static void static void
lower_rt_io_and_scratch(nir_shader *nir) lower_rt_io_and_scratch(nir_shader *nir, const struct intel_device_info *devinfo)
{ {
/* First, we to ensure all the I/O variables have explicit types. Because /* First, we to ensure all the I/O variables have explicit types. Because
* these are shader-internal and don't come in from outside, they don't * these are shader-internal and don't come in from outside, they don't
@ -196,7 +207,7 @@ lower_rt_io_and_scratch(nir_shader *nir)
glsl_get_natural_size_align_bytes); glsl_get_natural_size_align_bytes);
/* Now patch any derefs to I/O vars */ /* Now patch any derefs to I/O vars */
NIR_PASS_V(nir, lower_rt_io_derefs); NIR_PASS_V(nir, lower_rt_io_derefs, devinfo);
/* Finally, lower any remaining function_temp, mem_constant, or /* Finally, lower any remaining function_temp, mem_constant, or
* ray_hit_attrib access to 64-bit global memory access. * ray_hit_attrib access to 64-bit global memory access.
@ -329,11 +340,11 @@ lower_ray_walk_intrinsics(nir_shader *shader,
} }
void void
brw_nir_lower_raygen(nir_shader *nir) brw_nir_lower_raygen(nir_shader *nir, const struct intel_device_info *devinfo)
{ {
assert(nir->info.stage == MESA_SHADER_RAYGEN); assert(nir->info.stage == MESA_SHADER_RAYGEN);
NIR_PASS_V(nir, brw_nir_lower_shader_returns); NIR_PASS_V(nir, brw_nir_lower_shader_returns);
lower_rt_io_and_scratch(nir); lower_rt_io_and_scratch(nir, devinfo);
} }
void void
@ -342,31 +353,31 @@ brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo)
assert(nir->info.stage == MESA_SHADER_ANY_HIT); assert(nir->info.stage == MESA_SHADER_ANY_HIT);
NIR_PASS_V(nir, brw_nir_lower_shader_returns); NIR_PASS_V(nir, brw_nir_lower_shader_returns);
NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo); NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo);
lower_rt_io_and_scratch(nir); lower_rt_io_and_scratch(nir, devinfo);
} }
void void
brw_nir_lower_closest_hit(nir_shader *nir) brw_nir_lower_closest_hit(nir_shader *nir, const struct intel_device_info *devinfo)
{ {
assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT); assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
NIR_PASS_V(nir, brw_nir_lower_shader_returns); NIR_PASS_V(nir, brw_nir_lower_shader_returns);
lower_rt_io_and_scratch(nir); lower_rt_io_and_scratch(nir, devinfo);
} }
void void
brw_nir_lower_miss(nir_shader *nir) brw_nir_lower_miss(nir_shader *nir, const struct intel_device_info *devinfo)
{ {
assert(nir->info.stage == MESA_SHADER_MISS); assert(nir->info.stage == MESA_SHADER_MISS);
NIR_PASS_V(nir, brw_nir_lower_shader_returns); NIR_PASS_V(nir, brw_nir_lower_shader_returns);
lower_rt_io_and_scratch(nir); lower_rt_io_and_scratch(nir, devinfo);
} }
void void
brw_nir_lower_callable(nir_shader *nir) brw_nir_lower_callable(nir_shader *nir, const struct intel_device_info *devinfo)
{ {
assert(nir->info.stage == MESA_SHADER_CALLABLE); assert(nir->info.stage == MESA_SHADER_CALLABLE);
NIR_PASS_V(nir, brw_nir_lower_shader_returns); NIR_PASS_V(nir, brw_nir_lower_shader_returns);
lower_rt_io_and_scratch(nir); lower_rt_io_and_scratch(nir, devinfo);
} }
void void
@ -380,7 +391,7 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
NIR_PASS_V(intersection, brw_nir_lower_intersection_shader, NIR_PASS_V(intersection, brw_nir_lower_intersection_shader,
any_hit, devinfo); any_hit, devinfo);
NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo); NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo);
lower_rt_io_and_scratch(intersection); lower_rt_io_and_scratch(intersection, devinfo);
} }
static nir_def * static nir_def *

View file

@ -30,12 +30,16 @@
extern "C" { extern "C" {
#endif #endif
void brw_nir_lower_raygen(nir_shader *nir); void brw_nir_lower_raygen(nir_shader *nir,
const struct intel_device_info *devinfo);
void brw_nir_lower_any_hit(nir_shader *nir, void brw_nir_lower_any_hit(nir_shader *nir,
const struct intel_device_info *devinfo); const struct intel_device_info *devinfo);
void brw_nir_lower_closest_hit(nir_shader *nir); void brw_nir_lower_closest_hit(nir_shader *nir,
void brw_nir_lower_miss(nir_shader *nir); const struct intel_device_info *devinfo);
void brw_nir_lower_callable(nir_shader *nir); void brw_nir_lower_miss(nir_shader *nir,
const struct intel_device_info *devinfo);
void brw_nir_lower_callable(nir_shader *nir,
const struct intel_device_info *devinfo);
void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection, void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
const nir_shader *any_hit, const nir_shader *any_hit,
const struct intel_device_info *devinfo); const struct intel_device_info *devinfo);

View file

@ -465,7 +465,6 @@ brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_def *vec2,
*/ */
struct brw_nir_rt_mem_hit_defs { struct brw_nir_rt_mem_hit_defs {
nir_def *t; nir_def *t;
nir_def *tri_bary; /**< Only valid for triangle geometry */
nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */ nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */
nir_def *valid; nir_def *valid;
nir_def *leaf_type; nir_def *leaf_type;
@ -478,6 +477,34 @@ struct brw_nir_rt_mem_hit_defs {
nir_def *inst_leaf_ptr; nir_def *inst_leaf_ptr;
}; };
/* For Xe3+, barycentric coordinates are stored as 24 bit unorm.
* Since unorm_float could be expensive, we calculate tri_bary on
* demand. We do this for Xe3+ and Xe1/2 for consistency.
*/
static inline nir_def *
brw_nir_rt_load_tri_bary_from_addr(nir_builder *b,
nir_def *stack_addr,
bool committed,
const struct intel_device_info *devinfo)
{
nir_def *hit_addr =
brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);
nir_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
nir_def *tri_bary;
if (devinfo->ver >= 30) {
nir_def *u = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff);
nir_def *v = nir_iand_imm(b, nir_channel(b, data, 2), 0xffffff);
const unsigned bits[1] = {24};
tri_bary = nir_vec2(b,
nir_format_unorm_to_float_precise(b, u, bits),
nir_format_unorm_to_float_precise(b, v, bits));
} else {
tri_bary = nir_channels(b, data, 0x6);
}
return tri_bary;
}
static inline void static inline void
brw_nir_rt_load_mem_hit_from_addr(nir_builder *b, brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
struct brw_nir_rt_mem_hit_defs *defs, struct brw_nir_rt_mem_hit_defs *defs,
@ -495,19 +522,11 @@ brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
if (devinfo->ver >= 30) { if (devinfo->ver >= 30) {
defs->aabb_hit_kind = nir_iand_imm(b, nir_channel(b, data, 1), defs->aabb_hit_kind = nir_iand_imm(b, nir_channel(b, data, 1),
0xffffff); 0xffffff);
nir_def *u = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff);
nir_def *v = nir_iand_imm(b, nir_channel(b, data, 2), 0xffffff);
/* For Xe3+, barycentric coordinates are stored as 24 bit unorm */
const unsigned bits[1] = {24};
defs->tri_bary = nir_vec2(b,
nir_format_unorm_to_float_precise(b, u, bits),
nir_format_unorm_to_float_precise(b, v, bits));
defs->prim_index_delta = nir_ubitfield_extract(b, bitfield, defs->prim_index_delta = nir_ubitfield_extract(b, bitfield,
nir_imm_int(b, 0), nir_imm_int(b, 0),
nir_imm_int(b, 5)); nir_imm_int(b, 5));
} else { } else {
defs->aabb_hit_kind = nir_channel(b, data, 1); defs->aabb_hit_kind = nir_channel(b, data, 1);
defs->tri_bary = nir_channels(b, data, 0x6);
defs->prim_index_delta = nir_ubitfield_extract(b, bitfield, defs->prim_index_delta = nir_ubitfield_extract(b, bitfield,
nir_imm_int(b, 0), nir_imm_int(b, 0),
nir_imm_int(b, 16)); nir_imm_int(b, 16));

View file

@ -3728,7 +3728,7 @@ anv_pipeline_compile_ray_tracing(struct anv_ray_tracing_pipeline *pipeline,
nir_shader *nir = nir_shader_clone(tmp_stage_ctx, stages[i].nir); nir_shader *nir = nir_shader_clone(tmp_stage_ctx, stages[i].nir);
switch (stages[i].stage) { switch (stages[i].stage) {
case MESA_SHADER_RAYGEN: case MESA_SHADER_RAYGEN:
brw_nir_lower_raygen(nir); brw_nir_lower_raygen(nir, devinfo);
break; break;
case MESA_SHADER_ANY_HIT: case MESA_SHADER_ANY_HIT:
@ -3736,18 +3736,18 @@ anv_pipeline_compile_ray_tracing(struct anv_ray_tracing_pipeline *pipeline,
break; break;
case MESA_SHADER_CLOSEST_HIT: case MESA_SHADER_CLOSEST_HIT:
brw_nir_lower_closest_hit(nir); brw_nir_lower_closest_hit(nir, devinfo);
break; break;
case MESA_SHADER_MISS: case MESA_SHADER_MISS:
brw_nir_lower_miss(nir); brw_nir_lower_miss(nir, devinfo);
break; break;
case MESA_SHADER_INTERSECTION: case MESA_SHADER_INTERSECTION:
unreachable("These are handled later"); unreachable("These are handled later");
case MESA_SHADER_CALLABLE: case MESA_SHADER_CALLABLE:
brw_nir_lower_callable(nir); brw_nir_lower_callable(nir, devinfo);
break; break;
default: default: