nir,aco: Add ds_bvh_stack_rtn

This is a ds instruction that also overwrites its first input, so
introduce a new ds format with two outputs.

Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/35269>
This commit is contained in:
Natalie Vock 2023-12-19 21:10:41 +01:00 committed by Marge Bot
parent 8815845271
commit 9707b30965
11 changed files with 72 additions and 9 deletions

View file

@ -513,7 +513,7 @@ emit_ds_instruction(asm_context& ctx, std::vector<uint32_t>& out, const Instruct
out.push_back(encoding);
encoding = 0;
if (!instr->definitions.empty())
encoding |= reg(ctx, instr->definitions[0], 8) << 24;
encoding |= reg(ctx, instr->definitions.back(), 8) << 24;
for (unsigned i = 0; i < MIN2(instr->operands.size(), 3); i++) {
const Operand& op = instr->operands[i];
if (op.physReg() != m0 && !op.isUndefined())

View file

@ -569,7 +569,7 @@ formats = [("pseudo", [Format.PSEUDO], list(itertools.product(range(5), range(7)
("sopp", [Format.SOPP], [(0, 0), (0, 1)]),
("sopc", [Format.SOPC], [(1, 2)]),
("smem", [Format.SMEM], [(0, 4), (0, 3), (1, 0), (1, 3), (1, 2), (1, 1), (0, 0)]),
("ds", [Format.DS], [(1, 0), (1, 1), (1, 2), (1, 3), (0, 3), (0, 4)]),
("ds", [Format.DS], [(1, 0), (1, 1), (1, 2), (1, 3), (0, 3), (0, 4), (2, 3)]),
("ldsdir", [Format.LDSDIR], [(1, 1)]),
("mubuf", [Format.MUBUF], [(0, 4), (1, 3), (1, 4)]),
("mtbuf", [Format.MTBUF], [(0, 4), (1, 3)]),

View file

@ -696,8 +696,8 @@ gen(Instruction* instr, wait_ctx& ctx)
if (ds.gds)
update_counters(ctx, event_gds_gpr_lock);
if (!instr->definitions.empty())
insert_wait_entry(ctx, instr->definitions[0], ds.gds ? event_gds : event_lds);
for (auto& definition : instr->definitions)
insert_wait_entry(ctx, definition, ds.gds ? event_gds : event_lds);
if (ds.gds) {
for (const Operand& op : instr->operands)

View file

@ -1444,7 +1444,8 @@ get_tied_defs(Instruction* instr)
instr->opcode == aco_opcode::s_fmac_f16) {
ops.push_back(2);
} else if (instr->opcode == aco_opcode::s_addk_i32 || instr->opcode == aco_opcode::s_mulk_i32 ||
instr->opcode == aco_opcode::s_cmovk_i32) {
instr->opcode == aco_opcode::s_cmovk_i32 ||
instr->opcode == aco_opcode::ds_bvh_stack_push4_pop1_rtn_b32) {
ops.push_back(0);
} else if (instr->isMUBUF() && instr->definitions.size() == 1 && instr->operands.size() == 4) {
ops.push_back(3);

View file

@ -1598,6 +1598,14 @@ static_assert(sizeof(VINTRP_instruction) == sizeof(Instruction) + 4, "Unexpected
* Operand(n-1): M0 - LDS size.
* Definition(0): VDST - Destination VGPR when results returned to VGPRs.
*
* For ds_bvh_stack* instructions:
*
* Operand(0): ADDR - VGPR supplying the stack address (overwritten with stack address after push)
* Operand(1): LVADDR - VGPR supplying the last visited node ID
* Operand(2): DATA - VGPR supplying the result of bvh*_intersect_ray
* Definition(0) - new ADDR (tied to operand 0, contains new stack address)
* Definition(1): VDST - next node ID to test for intersection
*
*/
struct DS_instruction : public Instruction {
memory_sync_info sync;

View file

@ -1649,6 +1649,7 @@ DS = {
("ds_pk_add_rtn_f16", op(gfx12=0xaa)),
("ds_pk_add_bf16", op(gfx12=0x9b)),
("ds_pk_add_rtn_bf16", op(gfx12=0xab)),
("ds_bvh_stack_push4_pop1_rtn_b32", op(gfx11=0xad, gfx12=0xe0)), #ds_bvh_stack_rtn in GFX11
}
for (name, num) in DS:
insn(name, num, Format.DS, InstrClass.DS)

View file

@ -1518,7 +1518,8 @@ label_instruction(opt_ctx& ctx, aco_ptr<Instruction>& instr)
if (has_usable_ds_offset && i == 0 &&
parse_base_offset(ctx, instr.get(), i, &base, &offset, false) &&
base.regClass() == instr->operands[i].regClass() &&
instr->opcode != aco_opcode::ds_swizzle_b32) {
instr->opcode != aco_opcode::ds_swizzle_b32 &&
instr->opcode != aco_opcode::ds_bvh_stack_push4_pop1_rtn_b32) {
if (instr->opcode == aco_opcode::ds_write2_b32 ||
instr->opcode == aco_opcode::ds_read2_b32 ||
instr->opcode == aco_opcode::ds_write2_b64 ||

View file

@ -912,9 +912,10 @@ validate_ir(Program* program)
check(op.isOfType(RegType::vgpr) || op.physReg() == m0 || op.isUndefined(),
"Only VGPRs are valid DS instruction operands", instr.get());
}
if (!instr->definitions.empty())
check(instr->definitions[0].regClass().type() == RegType::vgpr,
"DS instruction must return VGPR", instr.get());
for (const Definition& def : instr->definitions) {
check(def.regClass().type() == RegType::vgpr, "DS instruction must return VGPR",
instr.get());
}
break;
}
case Format::EXP: {

View file

@ -4018,6 +4018,40 @@ pops_await_overlapped_waves(isel_context* ctx)
bld.reset(ctx->block);
}
uint16_t
ds_bvh_stack_offset1_gfx11(unsigned stack_size)
{
switch (stack_size) {
case 8: return 0x00;
case 16: return 0x10;
case 32: return 0x20;
case 64: return 0x30;
default: unreachable("invalid stack size");
}
}
void
emit_ds_bvh_stack_push4_pop1_rtn(isel_context* ctx, nir_intrinsic_instr* instr, Builder& bld)
{
Temp dst = get_ssa_temp(ctx, &instr->def);
Temp stack_addr = as_vgpr(ctx, get_ssa_temp(ctx, instr->src[0].ssa));
Temp last_node = as_vgpr(ctx, get_ssa_temp(ctx, instr->src[1].ssa));
Temp intersection_result = as_vgpr(ctx, get_ssa_temp(ctx, instr->src[2].ssa));
Temp dst_stack_addr = bld.tmp(v1);
Temp dst_node_pointer = bld.tmp(v1);
uint32_t offset0 = 0, offset1 = 0;
if (ctx->program->gfx_level >= GFX12)
offset0 = nir_intrinsic_stack_size(instr);
else
offset1 = ds_bvh_stack_offset1_gfx11(nir_intrinsic_stack_size(instr));
bld.ds(aco_opcode::ds_bvh_stack_push4_pop1_rtn_b32, Definition(dst_stack_addr),
Definition(dst_node_pointer), Operand(stack_addr), Operand(last_node),
Operand(intersection_result), offset0, offset1);
bld.pseudo(aco_opcode::p_create_vector, Definition(dst), Operand(dst_stack_addr),
Operand(dst_node_pointer));
}
} // namespace
void
@ -5056,6 +5090,13 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
bld.pseudo(aco_opcode::p_unit_test, Definition(get_ssa_temp(ctx, &instr->def)),
Operand::c32(nir_intrinsic_base(instr)));
break;
case nir_intrinsic_bvh_stack_rtn_amd: {
switch (instr->num_components) {
case 4: emit_ds_bvh_stack_push4_pop1_rtn(ctx, instr, bld); break;
default: unreachable("Invalid BVH stack component count!");
}
break;
}
default:
isel_err(&instr->instr, "Unimplemented intrinsic instr");
abort();

View file

@ -959,6 +959,7 @@ visit_intrinsic(nir_intrinsic_instr *instr, struct divergence_state *state)
case nir_intrinsic_load_agx:
case nir_intrinsic_load_shared_lock_nv:
case nir_intrinsic_store_shared_unlock_nv:
case nir_intrinsic_bvh_stack_rtn_amd:
is_divergent = true;
break;

View file

@ -1868,6 +1868,15 @@ intrinsic("bvh64_intersect_ray_amd", [4, 2, 1, 3, 3, 3], 4, flags=[CAN_ELIMINATE
#
intrinsic("bvh8_intersect_ray_amd", [4, 2, 1, 1, 3, 3, 1], 16, flags=[CAN_ELIMINATE, CAN_REORDER])
# operands:
# 1. stack address
# 2. previous node pointer
# 3. BVH node pointers
# returns:
# component 0: next stack address
# component 1: next node pointer
intrinsic("bvh_stack_rtn_amd", [1, 1, 0], 2, indices=[STACK_SIZE])
# Return of a callable in raytracing pipelines
intrinsic("rt_return_amd")