aco: refactor waitcnt pass to use barrier_info

Currently there's just barrier_info_all, but more will be added later.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36491>
This commit is contained in:
Rhys Perry 2025-07-17 12:29:41 +01:00 committed by Marge Bot
parent 21332609b9
commit 6c446c2f83

View file

@ -166,6 +166,41 @@ private:
uint8_t counters[num_events] = {};
};
/* Selects one of the barrier_info slots in wait_ctx::bar. Currently only
 * barrier_info_all exists; additional kinds are expected to be added later. */
enum barrier_info_kind {
/* Waits for all non-private accesses and all scratch/vgpr-spill accesses */
barrier_info_all,
/* Number of barrier_info slots; sizes wait_ctx::bar and bounds bit indices
 * in wait_ctx::bar_nonempty. */
num_barrier_infos,
};
/* Used to keep track of wait imms that are yet to be emitted. */
struct barrier_info {
/* Per storage class: the counter values a barrier/wait would have to wait on. */
wait_imm imm[storage_count];
uint16_t events[storage_count] = {}; /* use wait_event notion */
/* Bitmask of storage classes that have any pending imm/events entry. */
uint8_t storage = 0;
/* Merges another barrier_info into this one (used when joining control-flow
 * edges). Returns true if this barrier_info changed. */
bool join(const barrier_info& other)
{
bool changed = false;
for (unsigned i = 0; i < storage_count; i++) {
/* combine() keeps the stricter (smaller) counter values. */
changed |= imm[i].combine(other.imm[i]);
/* Changed if the other side tracks any event we did not. */
changed |= (other.events[i] & ~events[i]) != 0;
events[i] |= other.events[i];
}
storage |= other.storage;
return changed;
}
/* Debug dump of all non-empty storage-class entries. */
UNUSED void print(FILE* output) const
{
u_foreach_bit (i, storage) {
fprintf(output, "storage[%u] = {\n", i);
imm[i].print(output);
fprintf(output, "events: %u\n", events[i]);
fprintf(output, "}\n");
}
}
};
struct wait_ctx {
Program* program;
enum amd_gfx_level gfx_level;
@ -176,8 +211,8 @@ struct wait_ctx {
bool pending_flat_vm = false;
bool pending_s_buffer_store = false; /* GFX10 workaround */
wait_imm barrier_imm[storage_count];
uint16_t barrier_events[storage_count] = {}; /* use wait_event notion */
barrier_info bar[num_barrier_infos];
uint8_t bar_nonempty = 0;
std::map<PhysReg, wait_entry> gpr_map;
@ -219,11 +254,9 @@ struct wait_ctx {
}
}
for (unsigned i = 0; i < storage_count; i++) {
changed |= barrier_imm[i].combine(other->barrier_imm[i]);
changed |= (other->barrier_events[i] & ~barrier_events[i]) != 0;
barrier_events[i] |= other->barrier_events[i];
}
u_foreach_bit (i, other->bar_nonempty)
changed |= bar[i].join(other->bar[i]);
bar_nonempty |= other->bar_nonempty;
}
return changed;
@ -242,13 +275,10 @@ struct wait_ctx {
fprintf(output, "}\n");
}
for (unsigned i = 0; i < storage_count; i++) {
if (!barrier_imm[i].empty() || barrier_events[i]) {
fprintf(output, "barriers[%u] = {\n", i);
barrier_imm[i].print(output);
fprintf(output, "events: %u\n", barrier_events[i]);
fprintf(output, "}\n");
}
u_foreach_bit (i, bar_nonempty) {
fprintf(output, "barriers[%u] = {\n", i);
bar[i].print(output);
fprintf(output, "}\n");
}
}
};
@ -402,24 +432,22 @@ perform_barrier(wait_ctx& ctx, wait_imm& imm, memory_sync_info sync, unsigned se
sync_scope subgroup_scope =
ctx.program->workgroup_size <= ctx.program->wave_size ? scope_workgroup : scope_subgroup;
if ((sync.semantics & semantics) && sync.scope > subgroup_scope) {
unsigned storage = sync.storage;
while (storage) {
unsigned idx = u_bit_scan(&storage);
barrier_info& bar = ctx.bar[barrier_info_all];
u_foreach_bit (i, sync.storage & bar.storage) {
uint16_t events = bar.events[i];
/* LDS is private to the workgroup */
sync_scope bar_scope_lds = MIN2(sync.scope, scope_workgroup);
uint16_t events = ctx.barrier_events[idx];
if (bar_scope_lds <= subgroup_scope)
if (MIN2(sync.scope, scope_workgroup) <= subgroup_scope)
events &= ~event_lds;
/* Until GFX11, in non-WGP, the L1 (L0 on GFX10+) cache keeps all memory operations
* in-order for the same workgroup */
if (ctx.gfx_level < GFX11 && !ctx.program->wgp_mode && sync.scope <= scope_workgroup)
events &= ~(event_vmem | event_vmem_store | event_smem);
events &= ~(event_vmem | event_vmem_store);
if (events)
imm.combine(ctx.barrier_imm[idx]);
imm.combine(bar.imm[i]);
}
}
}
@ -431,6 +459,31 @@ force_waitcnt(wait_ctx& ctx, wait_imm& imm)
imm[i] = 0;
}
/* Updates barrier_info ctx.bar[idx] after a wait with counters "imm" has been
 * emitted: any tracked counter that the emitted wait already satisfies
 * (imm[i] <= bar[i]) no longer needs to be waited on and is cleared. When an
 * entry runs out of events/storage, the corresponding bits in info.storage and
 * ctx.bar_nonempty are cleared so the empty info is skipped in the future. */
void
update_barrier_info_for_wait(wait_ctx& ctx, unsigned idx, wait_imm imm)
{
barrier_info& info = ctx.bar[idx];
for (unsigned i = 0; i < wait_type_num; i++) {
if (imm[i] == wait_imm::unset_counter)
continue;
/* NOTE: info.storage is modified inside the loop; u_foreach_bit iterates a
 * snapshot of the mask taken at loop entry, so this is safe. */
u_foreach_bit (j, info.storage) {
wait_imm& bar = info.imm[j];
if (bar[i] != wait_imm::unset_counter && imm[i] <= bar[i]) {
/* Clear this counter */
bar[i] = wait_imm::unset_counter;
info.events[j] &= ~ctx.info->events[i];
if (!info.events[j]) {
/* No events left for this storage class: drop it entirely. */
info.storage &= ~(1 << j);
if (!info.storage)
ctx.bar_nonempty &= ~(1 << idx);
}
}
}
}
}
void
kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_info)
{
@ -447,8 +500,9 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
*/
if (ctx.gfx_level >= GFX11 && instr->opcode == aco_opcode::s_sendmsg &&
instr->salu().imm == sendmsg_dealloc_vgprs) {
imm.combine(ctx.barrier_imm[ffs(storage_scratch) - 1]);
imm.combine(ctx.barrier_imm[ffs(storage_vgpr_spill) - 1]);
barrier_info& bar = ctx.bar[barrier_info_all];
imm.combine(bar.imm[ffs(storage_scratch) - 1]);
imm.combine(bar.imm[ffs(storage_vgpr_spill) - 1]);
}
/* Make sure POPS coherent memory accesses have reached the L2 cache before letting the
@ -495,7 +549,8 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
if (instr->opcode == aco_opcode::ds_ordered_count &&
((instr->ds().offset1 | (instr->ds().offset0 >> 8)) & 0x1)) {
imm.combine(ctx.barrier_imm[ffs(storage_gds) - 1]);
barrier_info& bar = ctx.bar[barrier_info_all];
imm.combine(bar.imm[ffs(storage_gds) - 1]);
}
if (instr->opcode == aco_opcode::p_barrier)
@ -513,17 +568,8 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
for (unsigned i = 0; i < wait_type_num; i++)
ctx.nonzero &= imm[i] == 0 ? ~BITFIELD_BIT(i) : UINT32_MAX;
/* update barrier wait imms */
for (unsigned i = 0; i < storage_count; i++) {
wait_imm& bar = ctx.barrier_imm[i];
uint16_t& bar_ev = ctx.barrier_events[i];
for (unsigned j = 0; j < wait_type_num; j++) {
if (bar[j] != wait_imm::unset_counter && imm[j] <= bar[j]) {
bar[j] = wait_imm::unset_counter;
bar_ev &= ~ctx.info->events[j];
}
}
}
u_foreach_bit (i, ctx.bar_nonempty)
update_barrier_info_for_wait(ctx, i, imm);
/* remove all gprs with higher counter from map */
std::map<PhysReg, wait_entry>::iterator it = ctx.gpr_map.begin();
@ -548,20 +594,28 @@ kill(wait_imm& imm, Instruction* instr, wait_ctx& ctx, memory_sync_info sync_inf
}
void
update_barrier_imm(wait_ctx& ctx, uint8_t counters, wait_event event, memory_sync_info sync)
update_barrier_info_for_event(wait_ctx& ctx, uint8_t counters, wait_event event,
barrier_info_kind idx, uint16_t storage)
{
for (unsigned i = 0; i < storage_count; i++) {
wait_imm& bar = ctx.barrier_imm[i];
uint16_t& bar_ev = ctx.barrier_events[i];
barrier_info& info = ctx.bar[idx];
if (storage) {
info.storage |= storage;
ctx.bar_nonempty |= 1 << idx;
}
/* We re-use barrier_imm/barrier_events to wait for all scratch stores to finish. */
bool ignore_private = i == (ffs(storage_scratch) - 1) || i == (ffs(storage_vgpr_spill) - 1);
unsigned storage_tmp = info.storage;
while (storage_tmp) {
unsigned i = u_bit_scan(&storage_tmp);
wait_imm& bar = info.imm[i];
uint16_t& bar_ev = info.events[i];
if (sync.storage & (1 << i) && (!(sync.semantics & semantic_private) || ignore_private)) {
if (storage & (1 << i)) {
/* Reset counters to zero so that this instruction is waited on. */
bar_ev |= event;
u_foreach_bit (j, counters)
bar[j] = 0;
} else if (!(bar_ev & ctx.info->unordered_events) && !(ctx.info->unordered_events & event)) {
/* Increase counters so that this instruction is ignored when waiting. */
u_foreach_bit (j, counters) {
if (bar[j] != wait_imm::unset_counter && (bar_ev & ctx.info->events[j]) == event)
bar[j] = std::min<uint16_t>(bar[j] + 1, ctx.info->max_cnt[j]);
@ -570,6 +624,18 @@ update_barrier_imm(wait_ctx& ctx, uint8_t counters, wait_event event, memory_syn
}
}
/* This resets or increases the counters for the barrier infos in response to an instruction. */
void
update_barriers(wait_ctx& ctx, uint8_t counters, wait_event event, memory_sync_info sync)
{
   const bool private_access = (sync.semantics & semantic_private) != 0;
   /* barrier_info_all is re-used to wait for all scratch stores to finish, so
    * scratch/vgpr-spill storage stays tracked even for private accesses;
    * everything else private is dropped from the mask. */
   uint16_t mask = sync.storage;
   if (private_access)
      mask &= (uint16_t)(storage_scratch | storage_vgpr_spill);
   update_barrier_info_for_event(ctx, counters, event, barrier_info_all, mask);
}
void
update_counters(wait_ctx& ctx, wait_event event, memory_sync_info sync = memory_sync_info())
{
@ -577,7 +643,7 @@ update_counters(wait_ctx& ctx, wait_event event, memory_sync_info sync = memory_
ctx.nonzero |= counters;
update_barrier_imm(ctx, counters, event, sync);
update_barriers(ctx, counters, event, sync);
if (ctx.info->unordered_events & event)
return;
@ -837,7 +903,7 @@ handle_block(Program* program, Block& block, wait_ctx& ctx)
perform_barrier(ctx, queued_imm, sync_info, semantic_acquire);
if (is_ordered_count_acquire)
queued_imm.combine(ctx.barrier_imm[ffs(storage_gds) - 1]);
queued_imm.combine(ctx.bar[barrier_info_all].imm[ffs(storage_gds) - 1]);
}
}
@ -869,8 +935,8 @@ insert_waitcnt(Program* program)
unsigned loop_progress = 0;
if (program->pending_lds_access) {
update_barrier_imm(in_ctx[0], info.get_counters_for_event(event_lds), event_lds,
memory_sync_info(storage_shared));
update_barriers(in_ctx[0], info.get_counters_for_event(event_lds), event_lds,
memory_sync_info(storage_shared));
}
for (Definition def : program->args_pending_vmem) {