diff --git a/src/amd/compiler/aco_insert_waitcnt.cpp b/src/amd/compiler/aco_insert_waitcnt.cpp index fd5ae122973..40d1937639f 100644 --- a/src/amd/compiler/aco_insert_waitcnt.cpp +++ b/src/amd/compiler/aco_insert_waitcnt.cpp @@ -75,32 +75,25 @@ struct wait_entry { uint32_t logical_events; /* use wait_event notion */ uint8_t counters; /* use counter_type notion */ bool wait_on_read : 1; - bool logical : 1; uint8_t vmem_types : 4; /* use vmem_type notion. for counter_vm. */ uint8_t vm_mask : 2; /* which halves of the VGPR event_vmem uses */ - wait_entry(wait_event event_, wait_imm imm_, uint8_t counters_, bool logical_, - bool wait_on_read_) + wait_entry(wait_event event_, wait_imm imm_, uint8_t counters_, bool wait_on_read_) : imm(imm_), events(event_), logical_events(event_), counters(counters_), - wait_on_read(wait_on_read_), logical(logical_), vmem_types(0), vm_mask(0) + wait_on_read(wait_on_read_), vmem_types(0), vm_mask(0) {} - bool join(const wait_entry& other, bool logical_edge) + bool join(const wait_entry& other) { bool changed = (other.events & ~events) || (other.counters & ~counters) || (other.wait_on_read && !wait_on_read) || (other.vmem_types & ~vmem_types) || - (other.vm_mask & ~vm_mask) || (!other.logical && logical); + (other.vm_mask & ~vm_mask); events |= other.events; - if (logical_edge) { - changed |= other.logical_events & ~logical_events; - logical_events |= other.logical_events; - } counters |= other.counters; changed |= imm.combine(other.imm); wait_on_read |= other.wait_on_read; vmem_types |= other.vmem_types; vm_mask |= other.vm_mask; - logical &= other.logical; return changed; } @@ -122,7 +115,6 @@ struct wait_entry { UNUSED void print(FILE* output) const { - fprintf(output, "logical: %u\n", logical); imm.print(output); if (events) fprintf(output, "events: %u\n", events); @@ -132,8 +124,6 @@ struct wait_entry { fprintf(output, "counters: %u\n", counters); if (!wait_on_read) fprintf(output, "wait_on_read: %u\n", wait_on_read); - if (!logical) 
- fprintf(output, "logical: %u\n", logical); if (vmem_types) fprintf(output, "vmem_types: %u\n", vmem_types); if (vm_mask) @@ -211,29 +201,27 @@ struct wait_ctx { using iterator = std::map<PhysReg, wait_entry>::iterator; - for (const auto& entry : other->gpr_map) { - if (logical_merge ? !logical : (entry.second.logical != logical)) { - if (logical) { - iterator it = gpr_map.find(entry.first); - if (it != gpr_map.end()) { - changed |= entry.second.logical_events & ~it->second.logical_events; - it->second.logical_events |= entry.second.logical_events; - } - } - continue; - } - - const std::pair<iterator, bool> insert_pair = gpr_map.insert(entry); - if (insert_pair.second) { - if (!logical) + if (logical == logical_merge) { + for (const auto& entry : other->gpr_map) { + const std::pair<iterator, bool> insert_pair = gpr_map.insert(entry); + if (insert_pair.second) { insert_pair.first->second.logical_events = 0; - changed = true; - } else { - changed |= insert_pair.first->second.join(entry.second, logical); + changed = true; + } else { + changed |= insert_pair.first->second.join(entry.second); + } } } if (logical) { + for (const auto& entry : other->gpr_map) { + iterator it = gpr_map.find(entry.first); + if (it != gpr_map.end()) { + changed |= (entry.second.logical_events & ~it->second.logical_events) != 0; + it->second.logical_events |= entry.second.logical_events; + } + } + for (unsigned i = 0; i < storage_count; i++) { changed |= barrier_imm[i].combine(other->barrier_imm[i]); changed |= (other->barrier_events[i] & ~barrier_events[i]) != 0; @@ -318,6 +306,35 @@ get_vmem_mask(wait_ctx& ctx, Instruction* instr) } } +wait_imm +get_imm(wait_ctx& ctx, PhysReg reg, wait_entry& entry) +{ + if (reg.reg() >= 256) { + uint32_t events = entry.logical_events; + + /* ALU can't safely write to unwritten destination VGPR lanes with DS/VMEM on GFX11+ without + * waiting for the load to finish, even if none of the lanes are involved in the load. 
+ */ + if (ctx.gfx_level >= GFX11) { + uint32_t ds_vmem_events = + event_lds | event_gds | event_vmem | event_vmem_sample | event_vmem_bvh | event_flat; + events |= ds_vmem_events; + } + + uint32_t counters = 0; + u_foreach_bit (i, entry.events & events) + counters |= ctx.info->get_counters_for_event((wait_event)(1 << i)); + + wait_imm imm; + u_foreach_bit (i, entry.counters & counters) + imm[i] = entry.imm[i]; + + return imm; + } else { + return entry.imm; + } +} + void check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr) { @@ -329,7 +346,7 @@ check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr) for (unsigned j = 0; j < op.size(); j++) { std::map<PhysReg, wait_entry>::iterator it = ctx.gpr_map.find(PhysReg{op.physReg() + j}); if (it != ctx.gpr_map.end() && it->second.wait_on_read) - wait.combine(it->second.imm); + wait.combine(get_imm(ctx, PhysReg{op.physReg() + j}, it->second)); } } @@ -342,7 +359,7 @@ check_instr(wait_ctx& ctx, wait_imm& wait, Instruction* instr) if (it == ctx.gpr_map.end()) continue; - wait_imm reg_imm = it->second.imm; + wait_imm reg_imm = get_imm(ctx, reg, it->second); /* Vector Memory reads and writes decrease the counter in the order they were issued. * Before GFX12, they also write VGPRs in order if they're of the same type. 
@@ -611,22 +628,24 @@ update_counters_for_flat_load(wait_ctx& ctx, memory_sync_info sync = memory_sync void insert_wait_entry(wait_ctx& ctx, PhysReg reg, RegClass rc, wait_event event, bool wait_on_read, - uint8_t vmem_types = 0, uint32_t vm_mask = 0, bool force_linear = false) + uint8_t vmem_types = 0, uint32_t vm_mask = 0) { uint16_t counters = ctx.info->get_counters_for_event(event); wait_imm imm; u_foreach_bit (i, counters) imm[i] = 0; - wait_entry new_entry(event, imm, counters, !rc.is_linear() && !force_linear, wait_on_read); + wait_entry new_entry(event, imm, counters, wait_on_read); if (counters & counter_vm) new_entry.vmem_types |= vmem_types; for (unsigned i = 0; i < rc.size(); i++, vm_mask >>= 2) { new_entry.vm_mask = vm_mask & 0x3; auto it = ctx.gpr_map.emplace(PhysReg{reg.reg() + i}, new_entry); - if (!it.second) - it.first->second.join(new_entry, true); + if (!it.second) { + it.first->second.join(new_entry); + it.first->second.logical_events |= event; + } } } @@ -642,15 +661,7 @@ void insert_wait_entry(wait_ctx& ctx, Definition def, wait_event event, uint8_t vmem_types = 0, uint32_t vm_mask = 0) { - /* We can't safely write to unwritten destination VGPR lanes with DS/VMEM on GFX11 without - * waiting for the load to finish. 
- */ - uint32_t ds_vmem_events = - event_lds | event_gds | event_vmem | event_vmem_sample | event_vmem_bvh | event_flat; - bool force_linear = ctx.gfx_level >= GFX11 && (event & ds_vmem_events); - - insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, vmem_types, vm_mask, - force_linear); + insert_wait_entry(ctx, def.physReg(), def.regClass(), event, true, vmem_types, vm_mask); } void diff --git a/src/amd/compiler/tests/test_insert_waitcnt.cpp b/src/amd/compiler/tests/test_insert_waitcnt.cpp index f1f48794765..e06b21fcb2e 100644 --- a/src/amd/compiler/tests/test_insert_waitcnt.cpp +++ b/src/amd/compiler/tests/test_insert_waitcnt.cpp @@ -613,6 +613,70 @@ BEGIN_TEST(insert_waitcnt.vmem_ds) finish_waitcnt_test(); END_TEST +BEGIN_TEST(insert_waitcnt.waw.vmem_ds_valu) + for (amd_gfx_level gfx : {GFX10_3, GFX11, GFX12}) { + if (!setup_cs(NULL, gfx)) + continue; + + Definition def_v4(PhysReg(260), v1); + Operand op_v0(PhysReg(256), v1); + Operand desc_s4(PhysReg(0), s4); + + emit_divergent_if_else( + program.get(), bld, Operand::c64(1), + [&]() + { + //>> p_unit_test 1 + //! v1: %0:v[4] = buffer_load_dword %0:s[0-3], %0:v[0], 0 + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1)); + bld.mubuf(aco_opcode::buffer_load_dword, def_v4, desc_s4, op_v0, Operand::zero(), 0, + false); + }, + [&]() + { + //>> p_unit_test 2 + //~gfx11! s_waitcnt vmcnt(0) + //~gfx12! s_wait_loadcnt imm:0 + //! v1: %0:v[4] = v_mov_b32 0 + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2)); + bld.vop1(aco_opcode::v_mov_b32, def_v4, Operand::zero()); + }); + //>> p_unit_test 3 + //~gfx(10_3|11)! s_waitcnt vmcnt(0) + //~gfx12! s_wait_loadcnt imm:0 + //! p_unit_test %0:v[4] + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(3)); + bld.pseudo(aco_opcode::p_unit_test, Operand(PhysReg(260), v1)); + + emit_divergent_if_else( + program.get(), bld, Operand::c64(1), + [&]() + { + //>> p_unit_test 4 + //! 
v1: %0:v[4] = ds_read_b32 %0:v[0] + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(4)); + bld.ds(aco_opcode::ds_read_b32, def_v4, op_v0); + }, + [&]() + { + //>> p_unit_test 5 + //~gfx11! s_waitcnt lgkmcnt(0) + //~gfx12! s_wait_dscnt imm:0 + //! v1: %0:v[4] = v_mov_b32 0 + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(5)); + bld.vop1(aco_opcode::v_mov_b32, def_v4, Operand::zero()); + }); + //>> p_unit_test 6 + //~gfx(10_3|11)! s_waitcnt lgkmcnt(0) + //~gfx12! s_wait_dscnt imm:0 + //! p_unit_test %0:v[4] + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(6)); + bld.pseudo(aco_opcode::p_unit_test, Operand(PhysReg(260), v1)); + + finish_waitcnt_test(); + } +END_TEST + BEGIN_TEST(insert_waitcnt.waw.vmem_different_halves) if (!setup_cs(NULL, GFX12)) return; @@ -799,3 +863,39 @@ BEGIN_TEST(insert_waitcnt.divergent_branch.inc_counter) finish_waitcnt_test(); } END_TEST + +BEGIN_TEST(insert_waitcnt.divergent_branch.no_skip) + for (amd_gfx_level gfx : {GFX10_3, GFX11, GFX12}) { + if (!setup_cs(NULL, gfx)) + continue; + + Definition def_v4(PhysReg(260), v1); + Operand op_v0(PhysReg(256), v1); + Operand desc_s4(PhysReg(0), s4); + Operand desc_s8(PhysReg(8), s8); + + //>> v1: %0:v[4] = buffer_load_dword %0:s[0-3], %0:v[0], 0 + bld.mubuf(aco_opcode::buffer_load_dword, def_v4, desc_s4, op_v0, Operand::zero(), 0, false); + + //>> p_unit_test 1 + //~gfx(10_3|11)! s_waitcnt vmcnt(0) + //~gfx12! s_wait_loadcnt imm:0 + //! p_unit_test %0:v[4] + bld.reset(program->create_and_insert_block()); + program->blocks[1].linear_preds.push_back(0); + program->blocks[1].logical_preds.push_back(0); + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(1)); + bld.pseudo(aco_opcode::p_unit_test, Operand(PhysReg(260), v1)); + + //>> p_unit_test 2 + //! 
p_unit_test %0:v[4] + bld.reset(program->create_and_insert_block()); + program->blocks[2].linear_preds.push_back(1); + program->blocks[2].logical_preds.push_back(1); + program->blocks[2].logical_preds.push_back(0); + bld.pseudo(aco_opcode::p_unit_test, Operand::c32(2)); + bld.pseudo(aco_opcode::p_unit_test, Operand(PhysReg(260), v1)); + + finish_waitcnt_test(); + } +END_TEST