agx: Add interpolateAtOffset lowering pass

Add a lowering pass that lowers interpolation to math on the coefficient
registers. This handles interpolateAtOffset, as well as flat shading as an easy
special case.

Signed-off-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24498>
This commit is contained in:
Alyssa Rosenzweig 2023-08-04 09:39:55 -04:00 committed by Marge Bot
parent 48029548f3
commit ff0e25d293
3 changed files with 166 additions and 0 deletions

View file

@ -9,6 +9,7 @@
struct nir_shader;
bool agx_nir_lower_interpolation(struct nir_shader *s);
bool agx_nir_opt_ixor_bcsel(struct nir_shader *shader);
bool agx_nir_lower_algebraic_late(struct nir_shader *shader);
bool agx_nir_fuse_algebraic_late(struct nir_shader *shader);

View file

@ -0,0 +1,164 @@
/*
* Copyright 2023 Valve Corporation
* SPDX-License-Identifier: MIT
*/
#include "compiler/shader_enums.h"
#include "agx_nir.h"
#include "nir.h"
#include "nir_builder.h"
#include "nir_builder_opcodes.h"
#include "nir_intrinsics.h"
#include "nir_intrinsics_indices.h"
/*
* In AGX, the values of fragment shader inputs are represented as coefficient
* vectors <A, B, C>, which are dotted with <x, y, 1> to perform interpolation.
* x and y are relative to the tile. In other words, A and B are the
* screen-space partial derivatives of the input, and C is the value at the
* corner of the tile.
*
* For some interpolation modes, the dot product happens in the iterator
* hardware. Other modes are implemented in this file, by lowering to math on
* the coefficient vectors.
*/
/* XXX: It's not clear what this is for, but seems necessary */
static nir_ssa_def *
cf_valid(nir_builder *b, nir_ssa_def *cf)
{
nir_ssa_def *bit =
nir_ieq_imm(b, nir_iand_imm(b, nir_channel(b, cf, 0), 1), 0);
/* XXX: Apple's compiler actually checks that the significand is nonzero and
* the exponent is 0 or 1. This is probably a typo -- it doesn't make any
* logical sense. Presumably they just meant to check for denorms, so let's
* do that. Either way the tests pass.
*/
nir_ssa_def *cf01 = nir_trim_vector(b, cf, 2);
return nir_ior(b, bit, nir_fisnormal(b, cf01));
}
static nir_ssa_def *
interpolate_at_offset(nir_builder *b, nir_ssa_def *cf, nir_ssa_def *offset,
bool perspective)
{
/* Get the coordinate of the pixel within the tile */
nir_ssa_def *pixel_coords = nir_load_pixel_coord(b);
nir_ssa_def *tile_offs = nir_umod_imm(b, pixel_coords, 32);
/* Convert to float, getting the center of the pixel */
nir_ssa_def *center = nir_fadd_imm(b, nir_u2f32(b, tile_offs), 0.5);
/* Calculate the location to interpolate. offset is defined relative to the
* center of the pixel and is a float.
*/
nir_ssa_def *pos = nir_fadd(b, center, nir_f2f32(b, offset));
/* Interpolate with the given coefficients */
nir_ssa_def *interp = nir_ffma(b, nir_channel(b, pos, 1),
nir_channel(b, cf, 1), nir_channel(b, cf, 2));
interp = nir_ffma(b, nir_channel(b, pos, 0), nir_channel(b, cf, 0), interp);
/* Divide by RHW. This load will be lowered recursively. */
if (perspective) {
nir_ssa_def *bary = nir_load_barycentric_at_offset(
b, 32, offset, .interp_mode = INTERP_MODE_NOPERSPECTIVE);
nir_ssa_def *rhw = nir_load_interpolated_input(
b, 1, 32, bary, nir_imm_int(b, 0), .component = 3,
.io_semantics = {
.location = VARYING_SLOT_POS,
.num_slots = 1,
});
interp = nir_fdiv(b, interp, rhw);
}
/* Replace invalid interpolations with the constant channel */
return nir_bcsel(b, cf_valid(b, cf), interp, nir_channel(b, cf, 2));
}
static nir_ssa_def *
interpolate_flat(nir_builder *b, nir_ssa_def *coefficients)
{
/* Same value anywhere, so just take the constant (affine) component */
return nir_channel(b, coefficients, 2);
}
static enum glsl_interp_mode
interp_mode_for_load(nir_intrinsic_instr *load)
{
if (load->intrinsic == nir_intrinsic_load_input)
return INTERP_MODE_FLAT;
else
return nir_intrinsic_interp_mode(nir_src_as_intrinsic(load->src[0]));
}
static bool
needs_lower(const nir_instr *instr, UNUSED const void *_)
{
if (instr->type != nir_instr_type_intrinsic)
return false;
const nir_intrinsic_instr *load = nir_instr_as_intrinsic(instr);
/* at_offset barycentrics need to be lowered */
if (load->intrinsic == nir_intrinsic_load_interpolated_input) {
return (nir_src_as_intrinsic(load->src[0])->intrinsic ==
nir_intrinsic_load_barycentric_at_offset);
}
/* Flat shading always lowered */
return (load->intrinsic == nir_intrinsic_load_input);
}
static nir_ssa_def *
interpolate_channel(nir_builder *b, nir_intrinsic_instr *load, unsigned channel)
{
nir_io_semantics sem = nir_intrinsic_io_semantics(load);
/* Indirect varyings not supported, just bias the location */
sem.location += nir_src_as_uint(*nir_get_io_offset_src(load));
sem.num_slots = 1;
nir_ssa_def *coefficients = nir_load_coefficients_agx(
b, .component = nir_intrinsic_component(load) + channel,
.interp_mode = interp_mode_for_load(load), .io_semantics = sem);
if (load->intrinsic == nir_intrinsic_load_input) {
assert(nir_dest_bit_size(load->dest) == 32);
return interpolate_flat(b, coefficients);
} else {
nir_intrinsic_instr *bary = nir_src_as_intrinsic(load->src[0]);
nir_ssa_def *interp = interpolate_at_offset(
b, coefficients, bary->src[0].ssa,
nir_intrinsic_interp_mode(bary) != INTERP_MODE_NOPERSPECTIVE);
return nir_f2fN(b, interp, nir_dest_bit_size(load->dest));
}
}
static nir_ssa_def *
lower(nir_builder *b, nir_instr *instr, void *data)
{
nir_intrinsic_instr *intr = nir_instr_as_intrinsic(instr);
/* Each component is loaded separated */
nir_ssa_def *values[NIR_MAX_VEC_COMPONENTS] = {NULL};
for (unsigned i = 0; i < nir_dest_num_components(intr->dest); ++i) {
values[i] = interpolate_channel(b, intr, i);
}
return nir_vec(b, values, nir_dest_num_components(intr->dest));
}
bool
agx_nir_lower_interpolation(nir_shader *s)
{
assert(s->info.stage == MESA_SHADER_FRAGMENT);
return nir_shader_lower_instructions(s, needs_lower, lower, NULL);
}

View file

@ -13,6 +13,7 @@ libasahi_agx_files = files(
'agx_nir_lower_sample_mask.c',
'agx_nir_lower_discard_zs_emit.c',
'agx_nir_lower_texture.c',
'agx_nir_lower_interpolation.c',
'agx_nir_lower_load_mask.c',
'agx_nir_lower_shared_bitsize.c',
'agx_nir_lower_ubo.c',