diff --git a/src/poly/cl/restart.h b/src/poly/cl/restart.h index 384efc980e4..512f988549c 100644 --- a/src/poly/cl/restart.h +++ b/src/poly/cl/restart.h @@ -10,7 +10,75 @@ #include "poly/prim.h" #define POLY_DECL_UNROLL_RESTART_SCRATCH(__scratch, __wg_size) \ - local uint __scratch[MAX2(__wg_size / 32, sizeof(void *))] + local uchar __scratch[MAX2(__wg_size / 8, sizeof(void *))] + +static inline void +poly_store_local_ballot_arr(local void *dst, uint idx, uint4 ballot) +{ + const uint sg_size = get_sub_group_size(); + + if (sg_size == 8) + ((uchar *)dst)[idx] = ballot.x; + else if (sg_size == 16) + ((ushort *)dst)[idx] = ballot.x; + else if (sg_size == 32) + ((uint *)dst)[idx] = ballot.x; + else if (sg_size == 64) + ((uint2 *)dst)[idx] = ballot.xy; + else if (sg_size == 128) + ((uint4 *)dst)[idx] = ballot; +} + +static inline uint4 +poly_load_local_ballot_arr(local void *src, uint idx) +{ + const uint sg_size = get_sub_group_size(); + + uint4 ballot = (uint4)(0); + if (sg_size == 8) + ballot.x = ((uchar *)src)[idx]; + else if (sg_size == 16) + ballot.x = ((ushort *)src)[idx]; + else if (sg_size == 32) + ballot.x = ((uint *)src)[idx]; + else if (sg_size == 64) + ballot.xy = ((uint2 *)src)[idx]; + else if (sg_size == 128) + ballot = ((uint4 *)src)[idx]; + + return ballot; +} + +/* sub_group_ballot_find_lsb() doesn't have a defined return value when the + * ballot is empty so we need our own helper. 
+ */ +static uint +poly_ballot_ctz(uint4 ballot) +{ + const uint sg_size = get_sub_group_size(); + + if (ballot.x) + return ctz(ballot.x); + if (sg_size > 32 && ballot.y) + return 32 + ctz(ballot.y); + if (sg_size > 64 && ballot.z) + return 64 + ctz(ballot.z); + if (sg_size > 96 && ballot.w) + return 96 + ctz(ballot.w); + + return sg_size; +} + +static inline uint4 +poly_sub_group_broadcast_uint4(uint4 val, uint lane) +{ + uint4 bval; + bval.x = sub_group_broadcast(val.x, lane); + bval.y = sub_group_broadcast(val.y, lane); + bval.z = sub_group_broadcast(val.z, lane); + bval.w = sub_group_broadcast(val.w, lane); + return bval; +} /* * Return the ID of the first thread in the workgroup where cond is true, or @@ -18,16 +86,44 @@ * the workgroup. */ static inline uint -poly_work_group_first_true(bool cond, local uint *scratch) +poly_work_group_first_true(bool cond, local void *scratch) { - barrier(CLK_LOCAL_MEM_FENCE); - scratch[get_sub_group_id()] = sub_group_ballot(cond)[0]; + const uint sg_size = get_sub_group_size(); + const uint num_sg = get_num_sub_groups(); + + uint4 ballot = sub_group_ballot(cond); + if (num_sg == 1) + return poly_ballot_ctz(ballot); + barrier(CLK_LOCAL_MEM_FENCE); - uint first_group = - ctz(sub_group_ballot(scratch[get_sub_group_local_id()])[0]); - uint off = ctz(first_group < 32 ? scratch[first_group] : 0); - return (first_group * 32) + off; + if (get_sub_group_local_id() == 0) + poly_store_local_ballot_arr(scratch, get_sub_group_id(), ballot); + + barrier(CLK_LOCAL_MEM_FENCE); + + for (uint32_t i = 0; i < num_sg; i += sg_size) { + /* Read one subgroup worth per invocation */ + uint src_sg_id = i + get_sub_group_local_id(); + + /* Clamp src_sg_id so we don't read OOB if the number of subgroups is not + * a multiple of the subgroup size. It's safe to repeat the top index + * because the top indices will all be the same and we'll always take + * the first one. 
+ */ + src_sg_id = min(src_sg_id, num_sg - 1); + + ballot = poly_load_local_ballot_arr(scratch, src_sg_id); + uint4 wide_ballot = sub_group_ballot(any(ballot != (uint4)(0))); + if (all(wide_ballot == (uint4)(0))) + continue; + + uint first_sg = poly_ballot_ctz(wide_ballot); + uint4 first_ballot = poly_sub_group_broadcast_uint4(ballot, first_sg); + return (i + first_sg) * sg_size + poly_ballot_ctz(first_ballot); + } + + return num_sg * sg_size; } /* @@ -84,11 +180,16 @@ poly_unroll_restart(global uint32_t *out_draw, if (tid == 0) { out_ptr = (uintptr_t)poly_setup_unroll_for_draw(heap, in_draw, out_draw, mode, index_size_B); - *(uintptr_t *)scratch = out_ptr; + if (get_num_sub_groups() > 1) + *(uintptr_t *)scratch = out_ptr; } - barrier(CLK_LOCAL_MEM_FENCE); - out_ptr = *(uintptr_t *)scratch; + if (get_num_sub_groups() > 1) { + barrier(CLK_LOCAL_MEM_FENCE); + out_ptr = *(uintptr_t *)scratch; + } else { + out_ptr = sub_group_broadcast(out_ptr, 0); + } uintptr_t in_ptr = (uintptr_t)(poly_index_buffer( index_buffer, index_buffer_range_el, in_draw[2], index_size_B)); @@ -108,7 +209,7 @@ poly_unroll_restart(global uint32_t *out_draw, uint next_offs = poly_work_group_first_true(restart, scratch); next_restart += next_offs; - if (next_offs < 1024) + if (next_offs < cl_local_size.x) break; } @@ -116,10 +217,10 @@ poly_unroll_restart(global uint32_t *out_draw, uint subcount = next_restart - needle; uint subprims = u_decomposed_prims_for_vertices(mode, subcount); uint out_prims_base = out_prims; - for (uint i = tid; i < subprims; i += 1024) { + for (uint i = tid; i < subprims; i += cl_local_size.x) { for (uint vtx = 0; vtx < per_prim; ++vtx) { - uint id = - poly_vertex_id_for_topology(mode, flatshade_first, i, vtx, subprims); + uint id = poly_vertex_id_for_topology(mode, flatshade_first, i, + vtx, subprims); uint offset = needle + id; uint x = ((out_prims_base + i) * per_prim) + vtx;