poly: Generalize unroll_restart() to arbitrary workgroup/subgroup sizes

The original asahi code assumed a subgroup size of 32 and a workgroup
size of 32 * 32 = 1024.  This makes doing ctz(ballot(b)) across an
entire workgroup an almost trivial operation.  On panfrost, we won't be
so blessed unless we choose a workgroup size of 16 * 16 = 256.  It's
also not clear that we want to use workgroups at all; we may be better
off sticking to just subgroup parallelism and cutting memory bandwidth
by more than half.  With the new code, the only requirement
should be that the subgroup size is a power of two (this is always true)
and that the workgroup size is an even multiple of the subgroup size.
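
For illustration, assuming a hypothetical panfrost configuration with a
16-wide subgroup and a 16 * 16 = 256-invocation workgroup, the numbers
work out to:

  sg_size = 16, num_sg = 256 / 16 = 16
  scratch = 16 ballots * sizeof(ushort) = 32 bytes of local memory
  scan    = num_sg / sg_size = 1 pass of the workgroup scan loop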

Even though the new code looks way more complicated, thanks to the magic
of NIR constant folding, it should all fold down to the original code on
asahi, and to something even smaller if one opts for a single subgroup.
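
As a rough sketch (hypothetical helper, not part of the change) of what
the single-subgroup case should fold down to with a 32-wide subgroup:

  static inline uint
  poly_work_group_first_true_single_sg(bool cond)
  {
     /* With num_sg == 1 and sg_size == 32, only the .x word of the
      * ballot matters and no local memory traffic is needed.
      */
     uint ballot = sub_group_ballot(cond).x;
     return ballot ? ctz(ballot) : 32;
  }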

Acked-by: Alyssa Rosenzweig <alyssa.rosenzweig@intel.com>
Reviewed-by: Mary Guillemard <mary@mary.zone>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38404>
Faith Ekstrand 2025-11-14 16:46:24 -05:00 committed by Marge Bot
parent d9f795e6d0
commit ed0998ca98


@@ -10,7 +10,75 @@
#include "poly/prim.h"
#define POLY_DECL_UNROLL_RESTART_SCRATCH(__scratch, __wg_size) \
local uint __scratch[MAX2(__wg_size / 32, sizeof(void *))]
local uchar __scratch[MAX2(__wg_size / 8, sizeof(void *))]
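/* Store or load a single subgroup's ballot in a local array packed at one
 * ballot bit per invocation: the element width matches the subgroup size, so
 * a 16-wide subgroup uses ushort entries and a 64-wide one uses uint2.
 */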
static inline void
poly_store_local_ballot_arr(local void *dst, uint idx, uint4 ballot)
{
const uint sg_size = get_sub_group_size();
if (sg_size == 8)
((uchar *)dst)[idx] = ballot.x;
else if (sg_size == 16)
((ushort *)dst)[idx] = ballot.x;
else if (sg_size == 32)
((uint *)dst)[idx] = ballot.x;
else if (sg_size == 64)
((uint2 *)dst)[idx] = ballot.xy;
else if (sg_size == 128)
((uint4 *)dst)[idx] = ballot;
}
static inline uint4
poly_load_local_ballot_arr(local void *src, uint idx)
{
const uint sg_size = get_sub_group_size();
uint4 ballot = (uint4)(0);
if (sg_size == 8)
ballot.x = ((uchar *)src)[idx];
else if (sg_size == 16)
ballot.x = ((ushort *)src)[idx];
else if (sg_size == 32)
ballot.x = ((uint *)src)[idx];
else if (sg_size == 64)
ballot.xy = ((uint2 *)src)[idx];
else if (sg_size == 128)
ballot = ((uint4 *)src)[idx];
return ballot;
}
/* sub_group_ballot_find_lsb() doesn't have a defined return value when the
* ballot is empty so we need our own helper.
*/
static uint
poly_ballot_ctz(uint4 ballot)
{
const uint sg_size = get_sub_group_size();
if (ballot.x)
return ctz(ballot.x);
if (sg_size > 32 && ballot.y)
return 32 + ctz(ballot.y);
if (sg_size > 64 && ballot.z)
return 64 + ctz(ballot.z);
if (sg_size > 96 && ballot.w)
return 96 + ctz(ballot.w);
return sg_size;
}
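/* Broadcast all four words of a ballot from the given subgroup lane. */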
static inline uint4
poly_sub_group_broadcast_uint4(uint4 val, uint lane)
{
uint4 bval;
bval.x = sub_group_broadcast(val.x, lane);
bval.y = sub_group_broadcast(val.y, lane);
bval.z = sub_group_broadcast(val.z, lane);
bval.w = sub_group_broadcast(val.w, lane);
return bval;
}
/*
* Return the ID of the first thread in the workgroup where cond is true, or
@@ -18,16 +86,44 @@
* the workgroup.
*/
static inline uint
poly_work_group_first_true(bool cond, local uint *scratch)
poly_work_group_first_true(bool cond, local void *scratch)
{
barrier(CLK_LOCAL_MEM_FENCE);
scratch[get_sub_group_id()] = sub_group_ballot(cond)[0];
const uint sg_size = get_sub_group_size();
const uint num_sg = get_num_sub_groups();
uint4 ballot = sub_group_ballot(cond);
if (num_sg == 1)
return poly_ballot_ctz(ballot);
barrier(CLK_LOCAL_MEM_FENCE);
uint first_group =
ctz(sub_group_ballot(scratch[get_sub_group_local_id()])[0]);
uint off = ctz(first_group < 32 ? scratch[first_group] : 0);
return (first_group * 32) + off;
if (get_sub_group_local_id() == 0)
poly_store_local_ballot_arr(scratch, get_sub_group_id(), ballot);
barrier(CLK_LOCAL_MEM_FENCE);
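/* Scan the per-subgroup ballots sg_size entries at a time: each invocation
 * loads one subgroup's stored ballot and a second-level ballot picks out the
 * first subgroup that has any bit set.
 */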
for (uint32_t i = 0; i < num_sg; i += sg_size) {
/* Read one subgroup worth per invocation */
uint src_sg_id = i + get_sub_group_local_id();
/* Clamp src_sg_id so we don't read OOB if the number of subgroups is not
* a multiple of the subgroup size. It's safe to repeat the top index
* because the top indices will all be the same and we'll always take
* the first one.
*/
src_sg_id = min(src_sg_id, num_sg - 1);
ballot = poly_load_local_ballot_arr(scratch, src_sg_id);
uint4 wide_ballot = sub_group_ballot(any(ballot != (uint4)(0)));
if (all(wide_ballot == (uint4)(0)))
continue;
uint first_sg = poly_ballot_ctz(wide_ballot);
uint4 first_ballot = poly_sub_group_broadcast_uint4(ballot, first_sg);
return (i + first_sg) * sg_size + poly_ballot_ctz(first_ballot);
}
return num_sg * sg_size;
}
/*
@@ -84,11 +180,16 @@ poly_unroll_restart(global uint32_t *out_draw,
if (tid == 0) {
out_ptr = (uintptr_t)poly_setup_unroll_for_draw(heap, in_draw, out_draw,
mode, index_size_B);
if (get_num_sub_groups() > 1)
*(uintptr_t *)scratch = out_ptr;
}
if (get_num_sub_groups() > 1) {
barrier(CLK_LOCAL_MEM_FENCE);
out_ptr = *(uintptr_t *)scratch;
} else {
out_ptr = sub_group_broadcast(out_ptr, 0);
}
uintptr_t in_ptr = (uintptr_t)(poly_index_buffer(
index_buffer, index_buffer_range_el, in_draw[2], index_size_B));
@@ -108,7 +209,7 @@ poly_unroll_restart(global uint32_t *out_draw,
uint next_offs = poly_work_group_first_true(restart, scratch);
next_restart += next_offs;
if (next_offs < 1024)
if (next_offs < cl_local_size.x)
break;
}
@@ -116,10 +217,10 @@ poly_unroll_restart(global uint32_t *out_draw,
uint subcount = next_restart - needle;
uint subprims = u_decomposed_prims_for_vertices(mode, subcount);
uint out_prims_base = out_prims;
for (uint i = tid; i < subprims; i += 1024) {
for (uint i = tid; i < subprims; i += cl_local_size.x) {
for (uint vtx = 0; vtx < per_prim; ++vtx) {
uint id =
poly_vertex_id_for_topology(mode, flatshade_first, i, vtx, subprims);
uint id = poly_vertex_id_for_topology(mode, flatshade_first, i,
vtx, subprims);
uint offset = needle + id;
uint x = ((out_prims_base + i) * per_prim) + vtx;