intel/compiler/rt: Calculate barycentrics on demand

This commit moves the calculation of tri_bary out of
brw_nir_rt_load_mem_hit_from_addr() and only does the calculation on
demand, since the unorm-to-float conversion can be expensive. We do this
for both Xe1/2 and Xe3+ for consistency.
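
Concretely, the per-component decode that nir_format_unorm_to_float_precise()
has to emit for the Xe3+ 24-bit unorm barycentrics is roughly the following
scalar C sketch (illustrative only: the helper name is made up and the real
lowering operates on NIR SSA values, not C floats):

   #include <stdint.h>

   /* Rough scalar equivalent of the 24-bit unorm decode (sketch only). */
   static inline float
   unorm24_to_float(uint32_t dword)
   {
      uint32_t bits = dword & 0xffffff;             /* u/v live in the low 24 bits */
      return (float)bits / (float)((1u << 24) - 1); /* precise decode: x / (2^24 - 1) */
   }

Doing this eagerly on every MemHit load wastes work whenever the
barycentrics are never actually read.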

Signed-off-by: Kevin Chuang <kaiwenjon23@gmail.com>
Reviewed-by: Sagar Ghuge <sagar.ghuge@intel.com>
Reviewed-by: Lionel Landwerlin <lionel.g.landwerlin@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/33047>
commit 7b526de18f (parent afc23dffa4)
Author: Kevin Chuang
Date:   2025-03-25 18:55:15 -07:00 (committed by Marge Bot)
5 changed files with 73 additions and 38 deletions


@@ -432,7 +432,8 @@ lower_ray_query_intrinsic(nir_builder *b,
       break;
    case nir_ray_query_value_intersection_barycentrics:
-      sysval = hit_in.tri_bary;
+      sysval = brw_nir_rt_load_tri_bary_from_addr(b, stack_addr, committed,
+                                                  state->devinfo);
       break;
    case nir_ray_query_value_intersection_front_face:


@@ -55,7 +55,7 @@ resize_deref(nir_builder *b, nir_deref_instr *deref,
 }
 static bool
-lower_rt_io_derefs(nir_shader *shader)
+lower_rt_io_derefs(nir_shader *shader, const struct intel_device_info *devinfo)
 {
    nir_function_impl *impl = nir_shader_get_entrypoint(shader);
@@ -94,13 +94,24 @@ lower_rt_io_derefs(nir_shader *shader)
       assert(stage == MESA_SHADER_ANY_HIT ||
              stage == MESA_SHADER_CLOSEST_HIT ||
              stage == MESA_SHADER_INTERSECTION);
-      nir_def *hit_addr =
-         brw_nir_rt_mem_hit_addr(&b, stage == MESA_SHADER_CLOSEST_HIT);
-      /* The vec2 barycentrics are in 2nd and 3rd dwords of MemHit */
-      nir_def *bary_addr = nir_iadd_imm(&b, hit_addr, 4);
-      hit_attrib_addr = nir_bcsel(&b, nir_load_leaf_procedural_intel(&b),
-                                  brw_nir_rt_hit_attrib_data_addr(&b),
-                                  bary_addr);
+      hit_attrib_addr = brw_nir_rt_hit_attrib_data_addr(&b);
+      /* For tri, we store tri_bary at hit_attrib_data_addr.
+       * The reason we don't directly provide the address where u and v is
+       * located is that for Xe3+ u and v needs extra unorm_to_float
+       * calculation, so we write the computed value to hit_attrib_data_addr
+       * for shader to dereference.
+       */
+      nir_push_if(&b, nir_inot(&b, nir_load_leaf_procedural_intel(&b)));
+      {
+         nir_def* tri_bary =
+            brw_nir_rt_load_tri_bary_from_addr(&b,
+                                               brw_nir_rt_stack_addr(&b),
+                                               stage == MESA_SHADER_CLOSEST_HIT,
+                                               devinfo);
+         nir_store_global(&b, hit_attrib_addr, 4, tri_bary, 0x3);
+      }
+      nir_pop_if(&b, NULL);
       progress = true;
    }
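
The layout assumption behind the old "bary_addr = hit_addr + 4" shortcut, and
behind the on-demand helper added later in this change, is roughly the
following (a simplified, illustrative C view of the leading MemHit dwords;
the field names are made up and this is not a complete hardware definition):

   #include <stdint.h>

   /* Simplified sketch of the MemHit dwords this code reads. */
   struct mem_hit_prefix {
      float    t;      /* dword 0: hit distance                                  */
      uint32_t u_bits; /* dword 1: Xe1/2: float u; Xe3+: low 24 bits = unorm u
                          (aliases the AABB hit kind for procedural leaves)      */
      uint32_t v_bits; /* dword 2: Xe1/2: float v; Xe3+: low 24 bits = unorm v   */
      uint32_t dw3;    /* dword 3: other packed hit state, not used by this path */
   };

On Xe1/2 the two barycentrics are plain floats, so pointing hit_attrib_addr
straight at dword 1 used to be enough; on Xe3+ they have to be decoded first,
which is why the triangle path above now writes the converted vec2 to
hit_attrib_data_addr for the shader to load.
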
@@ -183,7 +194,7 @@ lower_rt_io_derefs(nir_shader *shader)
  * variable down the call stack.
  */
 static void
-lower_rt_io_and_scratch(nir_shader *nir)
+lower_rt_io_and_scratch(nir_shader *nir, const struct intel_device_info *devinfo)
 {
    /* First, we to ensure all the I/O variables have explicit types. Because
     * these are shader-internal and don't come in from outside, they don't
@@ -196,7 +207,7 @@ lower_rt_io_and_scratch(nir_shader *nir)
               glsl_get_natural_size_align_bytes);
    /* Now patch any derefs to I/O vars */
-   NIR_PASS_V(nir, lower_rt_io_derefs);
+   NIR_PASS_V(nir, lower_rt_io_derefs, devinfo);
    /* Finally, lower any remaining function_temp, mem_constant, or
     * ray_hit_attrib access to 64-bit global memory access.
@@ -329,11 +340,11 @@ lower_ray_walk_intrinsics(nir_shader *shader,
 }
 void
-brw_nir_lower_raygen(nir_shader *nir)
+brw_nir_lower_raygen(nir_shader *nir, const struct intel_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_RAYGEN);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_io_and_scratch(nir);
+   lower_rt_io_and_scratch(nir, devinfo);
 }
 void
@@ -342,31 +353,31 @@ brw_nir_lower_any_hit(nir_shader *nir, const struct intel_device_info *devinfo)
    assert(nir->info.stage == MESA_SHADER_ANY_HIT);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
    NIR_PASS_V(nir, lower_ray_walk_intrinsics, devinfo);
-   lower_rt_io_and_scratch(nir);
+   lower_rt_io_and_scratch(nir, devinfo);
 }
 void
-brw_nir_lower_closest_hit(nir_shader *nir)
+brw_nir_lower_closest_hit(nir_shader *nir, const struct intel_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_CLOSEST_HIT);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_io_and_scratch(nir);
+   lower_rt_io_and_scratch(nir, devinfo);
 }
 void
-brw_nir_lower_miss(nir_shader *nir)
+brw_nir_lower_miss(nir_shader *nir, const struct intel_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_MISS);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_io_and_scratch(nir);
+   lower_rt_io_and_scratch(nir, devinfo);
 }
 void
-brw_nir_lower_callable(nir_shader *nir)
+brw_nir_lower_callable(nir_shader *nir, const struct intel_device_info *devinfo)
 {
    assert(nir->info.stage == MESA_SHADER_CALLABLE);
    NIR_PASS_V(nir, brw_nir_lower_shader_returns);
-   lower_rt_io_and_scratch(nir);
+   lower_rt_io_and_scratch(nir, devinfo);
 }
 void
@@ -380,7 +391,7 @@ brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
    NIR_PASS_V(intersection, brw_nir_lower_intersection_shader,
               any_hit, devinfo);
    NIR_PASS_V(intersection, lower_ray_walk_intrinsics, devinfo);
-   lower_rt_io_and_scratch(intersection);
+   lower_rt_io_and_scratch(intersection, devinfo);
 }
 static nir_def *


@@ -30,12 +30,16 @@
 extern "C" {
 #endif
-void brw_nir_lower_raygen(nir_shader *nir);
+void brw_nir_lower_raygen(nir_shader *nir,
+                          const struct intel_device_info *devinfo);
 void brw_nir_lower_any_hit(nir_shader *nir,
                            const struct intel_device_info *devinfo);
-void brw_nir_lower_closest_hit(nir_shader *nir);
-void brw_nir_lower_miss(nir_shader *nir);
-void brw_nir_lower_callable(nir_shader *nir);
+void brw_nir_lower_closest_hit(nir_shader *nir,
+                               const struct intel_device_info *devinfo);
+void brw_nir_lower_miss(nir_shader *nir,
+                        const struct intel_device_info *devinfo);
+void brw_nir_lower_callable(nir_shader *nir,
+                            const struct intel_device_info *devinfo);
 void brw_nir_lower_combined_intersection_any_hit(nir_shader *intersection,
                                                  const nir_shader *any_hit,
                                                  const struct intel_device_info *devinfo);


@@ -465,7 +465,6 @@ brw_nir_rt_unpack_leaf_ptr(nir_builder *b, nir_def *vec2,
  */
 struct brw_nir_rt_mem_hit_defs {
    nir_def *t;
-   nir_def *tri_bary; /**< Only valid for triangle geometry */
    nir_def *aabb_hit_kind; /**< Only valid for AABB geometry */
    nir_def *valid;
    nir_def *leaf_type;
@@ -478,6 +477,34 @@ struct brw_nir_rt_mem_hit_defs {
    nir_def *inst_leaf_ptr;
 };
+/* For Xe3+, barycentric coordinates are stored as 24 bit unorm.
+ * Since unorm_float could be expensive, we calculate tri_bary on
+ * demand. We do this for Xe3+ and Xe1/2 for consistency.
+ */
+static inline nir_def *
+brw_nir_rt_load_tri_bary_from_addr(nir_builder *b,
+                                   nir_def *stack_addr,
+                                   bool committed,
+                                   const struct intel_device_info *devinfo)
+{
+   nir_def *hit_addr =
+      brw_nir_rt_mem_hit_addr_from_addr(b, stack_addr, committed);
+   nir_def *data = brw_nir_rt_load(b, hit_addr, 16, 4, 32);
+   nir_def *tri_bary;
+   if (devinfo->ver >= 30) {
+      nir_def *u = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff);
+      nir_def *v = nir_iand_imm(b, nir_channel(b, data, 2), 0xffffff);
+      const unsigned bits[1] = {24};
+      tri_bary = nir_vec2(b,
+                          nir_format_unorm_to_float_precise(b, u, bits),
+                          nir_format_unorm_to_float_precise(b, v, bits));
+   } else {
+      tri_bary = nir_channels(b, data, 0x6);
+   }
+   return tri_bary;
+}
 static inline void
 brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
                                   struct brw_nir_rt_mem_hit_defs *defs,
@@ -495,19 +522,11 @@ brw_nir_rt_load_mem_hit_from_addr(nir_builder *b,
    if (devinfo->ver >= 30) {
       defs->aabb_hit_kind = nir_iand_imm(b, nir_channel(b, data, 1),
                                          0xffffff);
-      nir_def *u = nir_iand_imm(b, nir_channel(b, data, 1), 0xffffff);
-      nir_def *v = nir_iand_imm(b, nir_channel(b, data, 2), 0xffffff);
-      /* For Xe3+, barycentric coordinates are stored as 24 bit unorm */
-      const unsigned bits[1] = {24};
-      defs->tri_bary = nir_vec2(b,
-                                nir_format_unorm_to_float_precise(b, u, bits),
-                                nir_format_unorm_to_float_precise(b, v, bits));
       defs->prim_index_delta = nir_ubitfield_extract(b, bitfield,
                                                      nir_imm_int(b, 0),
                                                      nir_imm_int(b, 5));
    } else {
       defs->aabb_hit_kind = nir_channel(b, data, 1);
-      defs->tri_bary = nir_channels(b, data, 0x6);
       defs->prim_index_delta = nir_ubitfield_extract(b, bitfield,
                                                      nir_imm_int(b, 0),
                                                      nir_imm_int(b, 16));

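With the tri_bary field gone from brw_nir_rt_mem_hit_defs, callers that still
need triangle barycentrics fetch them explicitly at the point of use. A minimal
call-pattern sketch (hypothetical wrapper, assuming the surrounding brw_nir_rt
helpers are in scope, mirroring what the ray-query lowering above does):

   /* Decode the committed hit's barycentrics only when they are needed. */
   static inline nir_def *
   load_committed_tri_bary(nir_builder *b,
                           const struct intel_device_info *devinfo)
   {
      return brw_nir_rt_load_tri_bary_from_addr(b, brw_nir_rt_stack_addr(b),
                                                true /* committed */, devinfo);
   }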

@@ -3728,7 +3728,7 @@ anv_pipeline_compile_ray_tracing(struct anv_ray_tracing_pipeline *pipeline,
       nir_shader *nir = nir_shader_clone(tmp_stage_ctx, stages[i].nir);
       switch (stages[i].stage) {
       case MESA_SHADER_RAYGEN:
-         brw_nir_lower_raygen(nir);
+         brw_nir_lower_raygen(nir, devinfo);
          break;
       case MESA_SHADER_ANY_HIT:
@@ -3736,18 +3736,18 @@ anv_pipeline_compile_ray_tracing(struct anv_ray_tracing_pipeline *pipeline,
          break;
       case MESA_SHADER_CLOSEST_HIT:
-         brw_nir_lower_closest_hit(nir);
+         brw_nir_lower_closest_hit(nir, devinfo);
          break;
       case MESA_SHADER_MISS:
-         brw_nir_lower_miss(nir);
+         brw_nir_lower_miss(nir, devinfo);
          break;
       case MESA_SHADER_INTERSECTION:
         unreachable("These are handled later");
       case MESA_SHADER_CALLABLE:
-         brw_nir_lower_callable(nir);
+         brw_nir_lower_callable(nir, devinfo);
          break;
       default: