aco: implement rotate

Reviewed-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/27118>
This commit is contained in:
Georg Lehmann 2024-01-17 11:52:10 +01:00 committed by Marge Bot
parent b90ec971d7
commit 4c74077b62
3 changed files with 157 additions and 0 deletions

View file

@ -97,6 +97,13 @@ ds_pattern_bitmode(unsigned and_mask, unsigned or_mask, unsigned xor_mask)
return and_mask | (or_mask << 5) | (xor_mask << 10);
}
inline unsigned
ds_pattern_rotate(unsigned delta, unsigned mask)
{
assert(delta < 32 && mask < 32);
return mask | (delta << 5) | 0xc000;
}
aco_ptr<Instruction> create_s_mov(Definition dst, Operand src);
enum sendmsg {

View file

@ -7961,6 +7961,52 @@ inclusive_scan_to_exclusive(isel_context* ctx, ReduceOp op, Definition dst, Temp
}
}
bool
emit_rotate_by_constant(isel_context* ctx, Temp& dst, Temp src, unsigned cluster_size,
uint64_t delta)
{
Builder bld(ctx->program, ctx->block);
RegClass rc = src.regClass();
dst = Temp(0, rc);
delta %= cluster_size;
if (delta == 0) {
dst = bld.copy(bld.def(rc), src);
} else if (delta * 2 == cluster_size && cluster_size <= 32) {
dst = emit_masked_swizzle(ctx, bld, src, ds_pattern_bitmode(0x1f, 0, delta), true);
} else if (cluster_size == 4) {
unsigned res[4];
for (unsigned i = 0; i < 4; i++)
res[i] = (i + delta) & 0x3;
uint32_t dpp_ctrl = dpp_quad_perm(res[0], res[1], res[2], res[3]);
if (ctx->program->gfx_level >= GFX8)
dst = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(rc), src, dpp_ctrl);
else
dst = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl);
} else if (cluster_size == 8 && ctx->program->gfx_level >= GFX10) {
uint32_t lane_sel = 0;
for (unsigned i = 0; i < 8; i++)
lane_sel |= ((i + delta) & 0x7) << (i * 3);
dst = bld.vop1_dpp8(aco_opcode::v_mov_b32, bld.def(rc), src, lane_sel);
} else if (cluster_size == 16 && ctx->program->gfx_level >= GFX8) {
dst = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(rc), src, dpp_row_rr(16 - delta));
} else if (cluster_size <= 32 && ctx->program->gfx_level >= GFX9) {
uint32_t ctrl = ds_pattern_rotate(delta, ~(cluster_size - 1) & 0x1f);
dst = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, ctrl);
} else if (cluster_size == 64) {
bool has_wf_dpp = ctx->program->gfx_level >= GFX8 && ctx->program->gfx_level < GFX10;
if (delta == 32 && ctx->program->gfx_level >= GFX11) {
dst = bld.vop1(aco_opcode::v_permlane64_b32, bld.def(rc), src);
} else if (delta == 1 && has_wf_dpp) {
dst = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(rc), src, dpp_wf_rl1);
} else if (delta == 63 && has_wf_dpp) {
dst = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(rc), src, dpp_wf_rr1);
}
}
return dst.id() != 0;
}
void
emit_interp_center(isel_context* ctx, Temp dst, Temp bary, Temp pos1, Temp pos2)
{
@ -8432,6 +8478,109 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
}
break;
}
case nir_intrinsic_rotate: {
Temp src = get_ssa_temp(ctx, instr->src[0].ssa);
Temp delta = get_ssa_temp(ctx, instr->src[1].ssa);
Temp dst = get_ssa_temp(ctx, &instr->def);
assert(nir_intrinsic_execution_scope(instr) == SCOPE_SUBGROUP);
assert(instr->def.bit_size > 1 && instr->def.bit_size <= 32);
if (!nir_src_is_divergent(instr->src[0])) {
emit_uniform_subgroup(ctx, instr, src);
break;
}
unsigned cluster_size = nir_intrinsic_cluster_size(instr);
cluster_size = util_next_power_of_two(
MIN2(cluster_size ? cluster_size : ctx->program->wave_size, ctx->program->wave_size));
if (cluster_size == 1) {
bld.copy(Definition(dst), src);
break;
}
delta = bld.as_uniform(delta);
src = as_vgpr(ctx, src);
Temp tmp;
if (nir_src_is_const(instr->src[1]) &&
emit_rotate_by_constant(ctx, tmp, src, cluster_size, nir_src_as_uint(instr->src[1]))) {
} else if (cluster_size == 2) {
Temp noswap =
bld.sopc(aco_opcode::s_bitcmp0_b32, bld.def(s1, scc), delta, Operand::c32(0));
noswap = bool_to_vector_condition(ctx, noswap);
Temp swapped = emit_masked_swizzle(ctx, bld, src, ds_pattern_bitmode(0x1f, 0, 0x1), true);
tmp = bld.vop2(aco_opcode::v_cndmask_b32, bld.def(src.regClass()), swapped, src, noswap);
} else if (ctx->program->gfx_level >= GFX10 && cluster_size <= 16) {
if (cluster_size == 4) /* shift mask already does this for 8/16. */
delta = bld.sop2(aco_opcode::s_and_b32, bld.def(s1), bld.def(s1, scc), delta,
Operand::c32(0x3));
delta =
bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), delta, Operand::c32(2));
Temp lo = bld.copy(bld.def(s1), Operand::c32(cluster_size == 4 ? 0x32103210 : 0x76543210));
Temp hi;
if (cluster_size <= 8) {
Temp shr = bld.sop2(aco_opcode::s_lshr_b32, bld.def(s1), bld.def(s1, scc), lo, delta);
if (cluster_size == 4) {
Temp lotolohi = bld.copy(bld.def(s1), Operand::c32(0x4444));
Temp lohi =
bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), shr, lotolohi);
lo = bld.sop2(aco_opcode::s_pack_ll_b32_b16, bld.def(s1), shr, lohi);
} else {
delta = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc),
Operand::c32(32), delta);
Temp shl =
bld.sop2(aco_opcode::s_lshl_b32, bld.def(s1), bld.def(s1, scc), lo, delta);
lo = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), shr, shl);
}
Temp lotohi = bld.copy(bld.def(s1), Operand::c32(0x88888888));
hi = bld.sop2(aco_opcode::s_or_b32, bld.def(s1), bld.def(s1, scc), lo, lotohi);
} else {
hi = bld.copy(bld.def(s1), Operand::c32(0xfedcba98));
Temp lohi = bld.pseudo(aco_opcode::p_create_vector, bld.def(s2), lo, hi);
Temp shr = bld.sop2(aco_opcode::s_lshr_b64, bld.def(s2), bld.def(s1, scc), lohi, delta);
delta = bld.sop2(aco_opcode::s_sub_u32, bld.def(s1), bld.def(s1, scc), Operand::c32(64),
delta);
Temp shl = bld.sop2(aco_opcode::s_lshl_b64, bld.def(s2), bld.def(s1, scc), lohi, delta);
lohi = bld.sop2(aco_opcode::s_or_b64, bld.def(s2), bld.def(s1, scc), shr, shl);
lo = bld.tmp(s1);
hi = bld.tmp(s1);
bld.pseudo(aco_opcode::p_split_vector, Definition(lo), Definition(hi), lohi);
}
Builder::Result ret =
bld.vop3(aco_opcode::v_permlane16_b32, bld.def(src.regClass()), src, lo, hi);
ret->valu().opsel[0] = true; /* set FETCH_INACTIVE */
ret->valu().opsel[1] = true; /* set BOUND_CTRL */
tmp = ret;
} else {
/* Fallback to ds_bpermute if we can't find a special instruction. */
Temp tid = emit_mbcnt(ctx, bld.tmp(v1));
Temp src_lane = bld.vadd32(bld.def(v1), tid, delta);
if (ctx->program->gfx_level >= GFX10 && cluster_size == 32) {
/* ds_bpermute is restricted to 32 lanes on GFX10+. */
Temp index_x4 =
bld.vop2(aco_opcode::v_lshlrev_b32, bld.def(v1), Operand::c32(2u), src_lane);
tmp = bld.ds(aco_opcode::ds_bpermute_b32, bld.def(v1), index_x4, src);
} else {
/* Technically, full wave rotate doesn't need this, but it breaks the pseudo ops. */
src_lane = bld.vop3(aco_opcode::v_bfi_b32, bld.def(v1), Operand::c32(cluster_size - 1),
src_lane, tid);
tmp = emit_bpermute(ctx, bld, src_lane, src);
}
}
tmp = emit_extract_vector(ctx, tmp, 0, dst.regClass());
bld.copy(Definition(dst), tmp);
set_wqm(ctx);
break;
}
case nir_intrinsic_load_sample_id: {
bld.vop3(aco_opcode::v_bfe_u32, Definition(get_ssa_temp(ctx, &instr->def)),
get_arg(ctx, ctx->args->ancillary), Operand::c32(8u), Operand::c32(4u));

View file

@ -552,6 +552,7 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_quad_swap_diagonal:
case nir_intrinsic_quad_swizzle_amd:
case nir_intrinsic_masked_swizzle_amd:
case nir_intrinsic_rotate:
case nir_intrinsic_inclusive_scan:
case nir_intrinsic_exclusive_scan:
case nir_intrinsic_reduce: