amd: switch to derivative intrinsics

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/30565>
This commit is contained in:
Alyssa Rosenzweig 2024-07-25 09:44:36 -04:00 committed by Marge Bot
parent 048173a55a
commit daa97bb41a
6 changed files with 113 additions and 122 deletions

View file

@ -378,27 +378,14 @@ move_tex_coords(struct move_tex_coords_state *state, nir_function_impl *impl, ni
}
static bool
move_fddxy(struct move_tex_coords_state *state, nir_function_impl *impl, nir_alu_instr *instr)
move_ddxy(struct move_tex_coords_state *state, nir_function_impl *impl, nir_intrinsic_instr *instr)
{
switch (instr->op) {
case nir_op_fddx:
case nir_op_fddy:
case nir_op_fddx_fine:
case nir_op_fddy_fine:
case nir_op_fddx_coarse:
case nir_op_fddy_coarse:
break;
default:
return false;
}
unsigned num_components = instr->def.num_components;
nir_scalar components[NIR_MAX_VEC_COMPONENTS];
coord_info infos[NIR_MAX_VEC_COMPONENTS];
bool can_move_all = true;
for (unsigned i = 0; i < num_components; i++) {
components[i] = nir_scalar_chase_alu_src(nir_get_scalar(&instr->def, i), 0);
components[i] = nir_scalar_chase_movs(components[i]);
components[i] = nir_scalar_resolved(instr->src[0].ssa, i);
can_move_all &= can_move_coord(components[i], &infos[i]);
}
if (!can_move_all || state->num_wqm_vgprs + num_components > state->options->max_wqm_vgprs)
@ -410,7 +397,8 @@ move_fddxy(struct move_tex_coords_state *state, nir_function_impl *impl, nir_alu
}
nir_def *def = nir_vec_scalars(&state->toplevel_b, components, num_components);
def = nir_build_alu1(&state->toplevel_b, instr->op, def);
def = _nir_build_ddx(&state->toplevel_b, def->bit_size, def);
nir_instr_as_intrinsic(def->parent_instr)->intrinsic = instr->intrinsic;
nir_def_rewrite_uses(&instr->def, def);
state->num_wqm_vgprs += num_components;
@ -436,8 +424,6 @@ move_coords_from_divergent_cf(struct move_tex_coords_state *state, nir_function_
if (instr->type == nir_instr_type_tex && (divergent_cf || *divergent_discard)) {
progress |= move_tex_coords(state, impl, instr);
} else if (instr->type == nir_instr_type_alu && (divergent_cf || *divergent_discard)) {
progress |= move_fddxy(state, impl, nir_instr_as_alu(instr));
} else if (instr->type == nir_instr_type_intrinsic) {
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
switch (intrin->intrinsic) {
@ -449,6 +435,15 @@ move_coords_from_divergent_cf(struct move_tex_coords_state *state, nir_function_
if (divergent_cf || nir_src_is_divergent(intrin->src[0]))
*divergent_discard = true;
break;
case nir_intrinsic_ddx:
case nir_intrinsic_ddy:
case nir_intrinsic_ddx_fine:
case nir_intrinsic_ddy_fine:
case nir_intrinsic_ddx_coarse:
case nir_intrinsic_ddy_coarse:
if (divergent_cf || *divergent_discard)
progress |= move_ddxy(state, impl, intrin);
break;
default:
break;
}

View file

@ -101,6 +101,8 @@ void ac_set_nir_options(struct radeon_info *info, bool use_llvm,
nir_io_prefer_scalar_fs_inputs |
nir_io_mix_convergent_flat_with_interpolated |
nir_io_vectorizer_ignores_types;
options->has_ddx_intrinsics = true;
options->scalarize_ddx = true;
}
bool

View file

@ -4166,84 +4166,6 @@ visit_alu_instr(isel_context* ctx, nir_alu_instr* instr)
bld.vopc(op, Definition(dst), Operand::c32(0), res);
break;
}
case nir_op_fddx:
case nir_op_fddy:
case nir_op_fddx_fine:
case nir_op_fddy_fine:
case nir_op_fddx_coarse:
case nir_op_fddy_coarse: {
uint16_t dpp_ctrl1, dpp_ctrl2;
if (instr->op == nir_op_fddx_fine) {
dpp_ctrl1 = dpp_quad_perm(0, 0, 2, 2);
dpp_ctrl2 = dpp_quad_perm(1, 1, 3, 3);
} else if (instr->op == nir_op_fddy_fine) {
dpp_ctrl1 = dpp_quad_perm(0, 1, 0, 1);
dpp_ctrl2 = dpp_quad_perm(2, 3, 2, 3);
} else {
dpp_ctrl1 = dpp_quad_perm(0, 0, 0, 0);
if (instr->op == nir_op_fddx || instr->op == nir_op_fddx_coarse)
dpp_ctrl2 = dpp_quad_perm(1, 1, 1, 1);
else
dpp_ctrl2 = dpp_quad_perm(2, 2, 2, 2);
}
if (dst.regClass() == v1 && instr->def.bit_size == 16) {
assert(instr->def.num_components == 2);
Temp src = as_vgpr(ctx, get_alu_src_vop3p(ctx, instr->src[0]));
/* swizzle to opsel: all swizzles are either 0 (x) or 1 (y) */
unsigned opsel_lo = instr->src[0].swizzle[0] & 1;
unsigned opsel_hi = instr->src[0].swizzle[1] & 1;
opsel_lo |= opsel_lo << 1;
opsel_hi |= opsel_hi << 1;
Temp tl = src;
if (nir_src_is_divergent(instr->src[0].src))
tl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl1);
Builder::Result sub =
bld.vop3p(aco_opcode::v_pk_add_f16, bld.def(v1), src, tl, opsel_lo, opsel_hi);
sub->valu().neg_lo[1] = true;
sub->valu().neg_hi[1] = true;
if (nir_src_is_divergent(instr->src[0].src))
bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(dst), sub, dpp_ctrl2);
else
bld.copy(Definition(dst), sub);
emit_split_vector(ctx, dst, 2);
} else {
Temp src = as_vgpr(ctx, get_alu_src(ctx, instr->src[0]));
aco_opcode subrev =
instr->def.bit_size == 16 ? aco_opcode::v_subrev_f16 : aco_opcode::v_subrev_f32;
bool use_interp = dpp_ctrl1 == dpp_quad_perm(0, 0, 0, 0) && instr->def.bit_size == 32 &&
ctx->program->gfx_level >= GFX11_5;
if (!nir_src_is_divergent(instr->src[0].src)) {
bld.vop2(subrev, Definition(dst), src, src);
} else if (use_interp && dpp_ctrl2 == dpp_quad_perm(1, 1, 1, 1)) {
bld.vinterp_inreg(aco_opcode::v_interp_p10_f32_inreg, Definition(dst), src,
Operand::c32(0x3f800000), src)
->valu()
.neg[2] = true;
} else if (use_interp && dpp_ctrl2 == dpp_quad_perm(2, 2, 2, 2)) {
Builder::Result tmp = bld.vinterp_inreg(aco_opcode::v_interp_p10_f32_inreg, bld.def(v1),
Operand::c32(0), Operand::c32(0), src);
tmp->valu().neg = 0x6;
bld.vinterp_inreg(aco_opcode::v_interp_p2_f32_inreg, Definition(dst), src,
Operand::c32(0x3f800000), tmp);
} else if (ctx->program->gfx_level >= GFX8) {
Temp tmp = bld.vop2_dpp(subrev, bld.def(v1), src, src, dpp_ctrl1);
bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(dst), tmp, dpp_ctrl2);
} else {
Temp tl = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl1);
Temp tr = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl2);
bld.vop2(subrev, Definition(dst), tl, tr);
}
}
set_wqm(ctx, true);
break;
}
default: isel_err(&instr->instr, "Unknown NIR ALU instr");
}
}
@ -8573,6 +8495,83 @@ visit_intrinsic(isel_context* ctx, nir_intrinsic_instr* instr)
}
break;
}
case nir_intrinsic_ddx:
case nir_intrinsic_ddy:
case nir_intrinsic_ddx_fine:
case nir_intrinsic_ddy_fine:
case nir_intrinsic_ddx_coarse:
case nir_intrinsic_ddy_coarse: {
Temp src = as_vgpr(ctx, get_ssa_temp(ctx, instr->src[0].ssa));
Temp dst = get_ssa_temp(ctx, &instr->def);
uint16_t dpp_ctrl1, dpp_ctrl2;
if (instr->intrinsic == nir_intrinsic_ddx_fine) {
dpp_ctrl1 = dpp_quad_perm(0, 0, 2, 2);
dpp_ctrl2 = dpp_quad_perm(1, 1, 3, 3);
} else if (instr->intrinsic == nir_intrinsic_ddy_fine) {
dpp_ctrl1 = dpp_quad_perm(0, 1, 0, 1);
dpp_ctrl2 = dpp_quad_perm(2, 3, 2, 3);
} else {
dpp_ctrl1 = dpp_quad_perm(0, 0, 0, 0);
if (instr->intrinsic == nir_intrinsic_ddx ||
instr->intrinsic == nir_intrinsic_ddx_coarse)
dpp_ctrl2 = dpp_quad_perm(1, 1, 1, 1);
else
dpp_ctrl2 = dpp_quad_perm(2, 2, 2, 2);
}
if (dst.regClass() == v1 && instr->def.bit_size == 16) {
assert(instr->def.num_components == 2);
/* identify swizzle to opsel */
unsigned opsel_lo = 0b00;
unsigned opsel_hi = 0b11;
Temp tl = src;
if (nir_src_is_divergent(instr->src[0]))
tl = bld.vop1_dpp(aco_opcode::v_mov_b32, bld.def(v1), src, dpp_ctrl1);
Builder::Result sub =
bld.vop3p(aco_opcode::v_pk_add_f16, bld.def(v1), src, tl, opsel_lo, opsel_hi);
sub->valu().neg_lo[1] = true;
sub->valu().neg_hi[1] = true;
if (nir_src_is_divergent(instr->src[0]))
bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(dst), sub, dpp_ctrl2);
else
bld.copy(Definition(dst), sub);
emit_split_vector(ctx, dst, 2);
} else {
aco_opcode subrev =
instr->def.bit_size == 16 ? aco_opcode::v_subrev_f16 : aco_opcode::v_subrev_f32;
bool use_interp = dpp_ctrl1 == dpp_quad_perm(0, 0, 0, 0) && instr->def.bit_size == 32 &&
ctx->program->gfx_level >= GFX11_5;
if (!nir_src_is_divergent(instr->src[0])) {
bld.vop2(subrev, Definition(dst), src, src);
} else if (use_interp && dpp_ctrl2 == dpp_quad_perm(1, 1, 1, 1)) {
bld.vinterp_inreg(aco_opcode::v_interp_p10_f32_inreg, Definition(dst), src,
Operand::c32(0x3f800000), src)
->valu()
.neg[2] = true;
} else if (use_interp && dpp_ctrl2 == dpp_quad_perm(2, 2, 2, 2)) {
Builder::Result tmp = bld.vinterp_inreg(aco_opcode::v_interp_p10_f32_inreg, bld.def(v1),
Operand::c32(0), Operand::c32(0), src);
tmp->valu().neg = 0x6;
bld.vinterp_inreg(aco_opcode::v_interp_p2_f32_inreg, Definition(dst), src,
Operand::c32(0x3f800000), tmp);
} else if (ctx->program->gfx_level >= GFX8) {
Temp tmp = bld.vop2_dpp(subrev, bld.def(v1), src, src, dpp_ctrl1);
bld.vop1_dpp(aco_opcode::v_mov_b32, Definition(dst), tmp, dpp_ctrl2);
} else {
Temp tl = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl1);
Temp tr = bld.ds(aco_opcode::ds_swizzle_b32, bld.def(v1), src, (1 << 15) | dpp_ctrl2);
bld.vop2(subrev, Definition(dst), tl, tr);
}
}
set_wqm(ctx, true);
break;
}
case nir_intrinsic_load_subgroup_invocation: {
emit_mbcnt(ctx, get_ssa_temp(ctx, &instr->def));
break;

View file

@ -337,12 +337,6 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_op_pack_snorm_2x16:
case nir_op_pack_uint_2x16:
case nir_op_pack_sint_2x16:
case nir_op_fddx:
case nir_op_fddy:
case nir_op_fddx_fine:
case nir_op_fddy_fine:
case nir_op_fddx_coarse:
case nir_op_fddy_coarse:
case nir_op_ldexp:
case nir_op_frexp_sig:
case nir_op_frexp_exp:
@ -530,6 +524,14 @@ init_context(isel_context* ctx, nir_shader* shader)
case nir_intrinsic_load_global_amd:
type = intrinsic->def.divergent ? RegType::vgpr : RegType::sgpr;
break;
case nir_intrinsic_ddx:
case nir_intrinsic_ddy:
case nir_intrinsic_ddx_fine:
case nir_intrinsic_ddy_fine:
case nir_intrinsic_ddx_coarse:
case nir_intrinsic_ddy_coarse:
type = RegType::vgpr;
break;
case nir_intrinsic_load_view_index:
type = ctx->stage == fragment_fs ? RegType::vgpr : RegType::sgpr;
break;

View file

@ -462,13 +462,7 @@ aco_nir_op_supports_packed_math_16bit(const nir_alu_instr* alu)
case nir_op_imin:
case nir_op_imax:
case nir_op_umin:
case nir_op_umax:
case nir_op_fddx:
case nir_op_fddy:
case nir_op_fddx_fine:
case nir_op_fddy_fine:
case nir_op_fddx_coarse:
case nir_op_fddy_coarse: return true;
case nir_op_umax: return true;
case nir_op_ishl: /* TODO: in NIR, these have 32bit shift operands */
case nir_op_ishr: /* while Radeon needs 16bit operands when vectorized */
case nir_op_ushr:

View file

@ -404,21 +404,21 @@ static LLVMValueRef emit_unpack_half_2x16(struct ac_llvm_context *ctx, LLVMValue
return ac_build_gather_values(ctx, temps, 2);
}
static LLVMValueRef emit_ddxy(struct ac_nir_context *ctx, nir_op op, LLVMValueRef src0)
static LLVMValueRef emit_ddxy(struct ac_nir_context *ctx, nir_intrinsic_op op, LLVMValueRef src0)
{
unsigned mask;
int idx;
LLVMValueRef result;
if (op == nir_op_fddx_fine)
if (op == nir_intrinsic_ddx_fine)
mask = AC_TID_MASK_LEFT;
else if (op == nir_op_fddy_fine)
else if (op == nir_intrinsic_ddy_fine)
mask = AC_TID_MASK_TOP;
else
mask = AC_TID_MASK_TOP_LEFT;
/* for DDX we want to next X pixel, DDY next Y pixel. */
if (op == nir_op_fddx_fine || op == nir_op_fddx_coarse || op == nir_op_fddx)
if (op == nir_intrinsic_ddx_fine || op == nir_intrinsic_ddx_coarse || op == nir_intrinsic_ddx)
idx = 1;
else
idx = 2;
@ -1081,15 +1081,6 @@ static bool visit_alu(struct ac_nir_context *ctx, const nir_alu_instr *instr)
result = LLVMBuildExtractElement(ctx->ac.builder, tmp, ctx->ac.i32_1, "");
break;
}
case nir_op_fddx:
case nir_op_fddy:
case nir_op_fddx_fine:
case nir_op_fddy_fine:
case nir_op_fddx_coarse:
case nir_op_fddy_coarse:
result = emit_ddxy(ctx, instr->op, src[0]);
break;
case nir_op_unpack_64_4x16: {
result = LLVMBuildBitCast(ctx->ac.builder, src[0], ctx->ac.v4i16, "");
break;
@ -2884,6 +2875,14 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
LLVMValueRef result = NULL;
switch (instr->intrinsic) {
case nir_intrinsic_ddx:
case nir_intrinsic_ddy:
case nir_intrinsic_ddx_fine:
case nir_intrinsic_ddy_fine:
case nir_intrinsic_ddx_coarse:
case nir_intrinsic_ddy_coarse:
result = emit_ddxy(ctx, instr->intrinsic, get_src(ctx, instr->src[0]));
break;
case nir_intrinsic_ballot:
case nir_intrinsic_ballot_relaxed:
result = ac_build_ballot(&ctx->ac, get_src(ctx, instr->src[0]));