nak: Lower MSAA image load/store/atomic/size

This is designed to pair with the MSAA storage descriptor handling we
added to NIL.

Reviewed-by: Mary Guillemard <mary.guillemard@collabora.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/36207>
This commit is contained in:
Faith Ekstrand 2025-07-17 16:19:42 -04:00 committed by Marge Bot
parent 3a22117e56
commit 65d836fb26

View file

@ -423,12 +423,19 @@ lower_tex(nir_builder *b, nir_tex_instr *tex, const struct nak_compiler *nak)
return true;
}
static nir_def *
build_txq_samples_raw(nir_builder *b, nir_def *img_h,
const struct nak_compiler *nak)
{
nir_def *res = build_txq(b, nir_texop_tex_type_nv, img_h, NULL, nak);
return nir_channel(b, res, 2);
}
static nir_def *
build_txq_samples(nir_builder *b, nir_def *img_h,
const struct nak_compiler *nak)
{
nir_def *res = build_txq(b, nir_texop_tex_type_nv, img_h, NULL, nak);
res = nir_channel(b, res, 2);
nir_def *res = build_txq_samples_raw(b, img_h, nak);
if (!has_null_descriptors(nak)) {
res = nir_bcsel(b, build_img_is_null(b, img_h, nak),
@ -659,6 +666,87 @@ shrink_image_store(nir_builder *b, nir_intrinsic_instr *intrin,
return true;
}
static nir_def *
build_px_size_sa_log2(nir_builder *b, nir_def *samples)
{
nir_def *samples_log2 = nir_ufind_msb(b, samples);
/* Map from samples_log2 to pixels per sample (log2):
*
* 0 -> (0, 0)
* 1 -> (1, 0)
* 2 -> (1, 1)
* 3 -> (2, 1)
* 4 -> (2, 2)
*
* so
*
* h_log2 = samples_log2 / 2
* w_log2 = (samples_log2 + 1) / 2 = samples_log2 - h_log2
*/
nir_def *h_log2 = nir_udiv_imm(b, samples_log2, 2);
nir_def *w_log2 = nir_isub(b, samples_log2, h_log2);
return nir_vec2(b, w_log2, h_log2);
}
static bool
lower_msaa_image_access(nir_builder *b, nir_intrinsic_instr *intrin,
const struct nak_compiler *nak)
{
assert(nir_intrinsic_image_dim(intrin) == GLSL_SAMPLER_DIM_MS);
b->cursor = nir_before_instr(&intrin->instr);
nir_def *img_h = intrin->src[0].ssa;
nir_def *x = nir_channel(b, intrin->src[1].ssa, 0);
nir_def *y = nir_channel(b, intrin->src[1].ssa, 1);
nir_def *z = nir_channel(b, intrin->src[1].ssa, 2);
nir_def *w = nir_channel(b, intrin->src[1].ssa, 3);
nir_def *s = intrin->src[2].ssa;
nir_def *samples = build_txq_samples_raw(b, img_h, nak);
nir_def *px_size_sa_log2 = build_px_size_sa_log2(b, samples);
nir_def *px_w_log2 = nir_channel(b, px_size_sa_log2, 0);
nir_def *px_h_log2 = nir_channel(b, px_size_sa_log2, 1);
/* Compute the x/y offsets
*
* txq.sampler_pos gives us the sample coordinates as a signed 4.12 fixed
* point with x in the bottom 16 bits and y in the top 16 bits.
*/
nir_def *spos_sf = build_txq(b, nir_texop_sample_pos_nv, img_h, s, nak);
spos_sf = nir_trim_vector(b, spos_sf, 2);
/* Fortunately, the samples are laid out in the supersampled image the same
* as the sample locations, rounded to an integer sample offset. So we
* just have to figure out which samples each of those hits in the 2D grid.
*
* Add 0x0800 to convert from signed 4.12 fixed-point centered around 0 to
* unsigned 4.12 fixed point. Then shift by 12 - px_sz_log2 to divide off
* the extra, leaving an integer offset. It's safe to do it all in one add
* because we know a priori that the low 8 bits of each sample position are
* zero so any overflow in the low 16 bits will just set a 1 in bit 16
* which will get shifted away.
*/
nir_def *spos_uf = nir_iadd_imm(b, spos_sf, 0x08000800);
nir_def *sx = nir_ushr(b, nir_iand_imm(b, spos_uf, 0xffff),
nir_isub_imm(b, 12, px_w_log2));
nir_def *sy = nir_ushr(b, spos_uf, nir_isub_imm(b, 28, px_h_log2));
/* Add in the sample offsets */
x = nir_iadd(b, nir_ishl(b, x, px_w_log2), sx);
y = nir_iadd(b, nir_ishl(b, y, px_h_log2), sy);
/* Smash x negative if s > samples to get OOB behavior */
x = nir_bcsel(b, nir_ult(b, s, samples), x, nir_imm_int(b, -1));
nir_intrinsic_set_image_dim(intrin, GLSL_SAMPLER_DIM_2D);
nir_src_rewrite(&intrin->src[1], nir_vec4(b, x, y, z, w));
nir_src_rewrite(&intrin->src[2], nir_undef(b, 1, 32));
return true;
}
static bool
lower_image_txq(nir_builder *b, nir_intrinsic_instr *intrin,
const struct nak_compiler *nak)
@ -672,6 +760,18 @@ lower_image_txq(nir_builder *b, nir_intrinsic_instr *intrin,
case nir_intrinsic_bindless_image_size:
res = build_txq_size(b, intrin->def.num_components, img_h,
intrin->src[1].ssa /* lod */, nak);
if (nir_intrinsic_image_dim(intrin) == GLSL_SAMPLER_DIM_MS) {
/* When NIL sets up the MSAA image descriptor, it uses a width and
* height in samples, rather than pixels because sust/ld/atom ignore
* the sample count and blindly bounds check whatever x/y coordinates
* they're given. This means we need to divide back out the pixel
* size in order to get the size in pixels.
*/
nir_def *samples = build_txq_samples_raw(b, img_h, nak);
nir_def *px_size_sa_log2 = build_px_size_sa_log2(b, samples);
res = nir_ushr(b, res, px_size_sa_log2);
}
break;
case nir_intrinsic_bindless_image_samples:
res = build_txq_samples(b, img_h, nak);
@ -715,13 +815,32 @@ lower_tex_instr(nir_builder *b, nir_instr *instr, void *_data)
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
switch (intrin->intrinsic) {
case nir_intrinsic_bindless_image_load:
case nir_intrinsic_bindless_image_sparse_load:
return shrink_image_load(b, intrin, nak);
case nir_intrinsic_bindless_image_store:
return shrink_image_store(b, intrin, nak);
case nir_intrinsic_bindless_image_sparse_load: {
bool progress = false;
if (nir_intrinsic_image_dim(intrin) == GLSL_SAMPLER_DIM_MS)
progress |= lower_msaa_image_access(b, intrin, nak);
progress |= shrink_image_load(b, intrin, nak);
return progress;
}
case nir_intrinsic_bindless_image_store: {
bool progress = false;
if (nir_intrinsic_image_dim(intrin) == GLSL_SAMPLER_DIM_MS)
progress |= lower_msaa_image_access(b, intrin, nak);
progress |= shrink_image_store(b, intrin, nak);
return progress;
}
case nir_intrinsic_bindless_image_atomic:
case nir_intrinsic_bindless_image_atomic_swap:
if (nir_intrinsic_image_dim(intrin) == GLSL_SAMPLER_DIM_MS)
return lower_msaa_image_access(b, intrin, nak);
return false;
case nir_intrinsic_bindless_image_size:
case nir_intrinsic_bindless_image_samples:
return lower_image_txq(b, intrin, nak);
default:
return false;
}