ac/nir: Add pass to fixup SMEM on GFX6-7

The pass implements two mitigations for the GFX6-7 SMEM bug:

1. To mitigate VM faults by NULL descriptors:

Make sure that SMEM buffer loads always access a mapped BO.
Use either the descriptor BO (or compute scratch BO),
or otherwise use the zero-filled BO in their place.

2. To mitigate VM faults by OOB robust buffer access:

Add an instruction to clamp the offset source to the
num_records field of the descriptor. The access will
still be out of bounds, but the VM fault can be completely
mitigated if the driver adds padding to each memory allocation.

Signed-off-by: Timur Kristóf <timur.kristof@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/38769>
This commit is contained in:
Timur Kristóf 2025-12-19 12:35:28 -06:00 committed by Marge Bot
parent 7023ef4b8b
commit 18b8543026
2 changed files with 155 additions and 0 deletions

View file

@ -414,6 +414,13 @@ ac_nir_opt_shared_append(nir_shader *shader);
bool
ac_nir_flag_smem_for_loads(nir_shader *shader, enum amd_gfx_level gfx_level, bool use_llvm);
/* Mitigate GFX6-7 SMEM VM faults caused by NULL descriptors and by
 * out-of-bounds robust buffer access. Requires the driver to add at least
 * padding_bytes of padding after each memory allocation.
 */
bool
ac_nir_fixup_mem_access_gfx6(nir_shader *shader,
                             struct ac_shader_args *args,
                             const uint32_t padding_bytes,
                             const bool fixup_smem_null_desc,
                             const bool fixup_smem_oob);
bool
ac_nir_lower_mem_access_bit_sizes(nir_shader *shader, enum amd_gfx_level gfx_level, bool use_llvm);

View file

@ -16,6 +16,14 @@ typedef struct {
bool had_terminate;
} mem_access_cb_data;
/* State shared by the GFX6-7 SMEM fixup callbacks. */
typedef struct {
struct ac_shader_args *args; /* Provides ring_offsets, used as the dummy (always-mapped) address. */
struct hash_table *range_ht; /* Cache for nir_unsigned_upper_bound queries. */
uint32_t padding_bytes; /* Padding the driver guarantees after each BO allocation. */
bool fixup_smem_null_desc; /* Replace the base address of NULL descriptors with a mapped BO. */
bool fixup_smem_oob; /* Clamp SMEM offsets to the descriptor's num_records field. */
} fixup_mem_access_state;
static bool
set_smem_access_flags(nir_builder *b, nir_intrinsic_instr *intrin, void *cb_data_)
{
@ -98,6 +106,146 @@ ac_nir_flag_smem_for_loads(nir_shader *shader, enum amd_gfx_level gfx_level, boo
return nir_shader_intrinsics_pass(shader, &set_smem_access_flags, nir_metadata_all, &cb_data);
}
/* Mitigate out of bounds SMEM access from NULL and mutable descriptors.
 * This is necessary because VKD3D-Proton assumes that all descriptor
 * types use the same bit pattern for NULL descriptors.
 *
 * All of this mess is compiled into two SALU instructions per descriptor:
 *    s_cmp_eq_u32 <dw2>, 0
 *    s_cselect_b64 <dw0:1>, s[0:1], <dw0:1>
 */
static bool
fixup_smem_null_descriptor_gfx6(nir_builder *b, nir_intrinsic_instr *intrin, fixup_mem_access_state *state)
{
   nir_def *desc = intrin->src[0].ssa;

   b->cursor = nir_after_def(desc);

   /* Fallback address: the descriptor BO (or compute scratch BO), which is always mapped. */
   nir_def *fallback_addr = nir_pack_64_2x32(b, ac_nir_load_arg(b, state->args, state->args->ring_offsets));

   /* Split the descriptor into its four dwords. */
   nir_def *addr_lo = nir_channel(b, desc, 0);
   nir_def *addr_hi = nir_channel(b, desc, 1);
   nir_def *num_records = nir_channel(b, desc, 2);
   nir_def *word3 = nir_channel(b, desc, 3);

   /* A descriptor with a zero num_records field is treated as NULL;
    * substitute the fallback address for its base address in that case.
    */
   nir_def *addr = nir_pack_64_2x32_split(b, addr_lo, addr_hi);
   nir_def *is_null = nir_ieq_imm(b, num_records, 0);
   addr = nir_bcsel(b, is_null, fallback_addr, addr);

   /* Reassemble the descriptor with the possibly-replaced address. */
   nir_def *fixed_desc = nir_vec4(b,
                                  nir_unpack_64_2x32_split_x(b, addr),
                                  nir_unpack_64_2x32_split_y(b, addr),
                                  num_records, word3);

   /* Rewrite all uses of the descriptor (not just SMEM), to reduce SGPR use. */
   nir_def_rewrite_uses_after(desc, fixed_desc);
   return true;
}
/* Clamp the offset of an SMEM load so that an OOB robust buffer access
 * stays within the padding the driver adds after each BO.
 *
 * Returns true (the pass made "progress") whether or not a clamp was
 * actually emitted, matching the original behavior.
 */
static bool
fixup_smem_robust_oob_gfx6(nir_builder *b, nir_intrinsic_instr *intrin, fixup_mem_access_state *state)
{
   nir_def *desc = intrin->src[0].ssa;
   nir_def *offs = intrin->src[1].ssa;

   /* Bytes loaded by the SMEM instruction */
   const uint32_t bytes = (intrin->def.num_components * intrin->def.bit_size) / 8;

   /* Find the unsigned upper bound of the offset. This is the
    * highest possible offset that the current SMEM instruction
    * can use. We know for sure it will not go beyond that.
    */
   const uint32_t offset_uub =
      nir_unsigned_upper_bound(b->shader, state->range_ht,
                               nir_scalar_chase_movs(nir_get_scalar(offs, 0)));

   /* We allow the SMEM instruction to read beyond
    * the allocated BO (so they might read from the padding).
    * Verify that the SMEM instruction doesn't read past the
    * padding that we add after the virtual address of the BO.
    *
    * Guard against unsigned underflow: if the load is larger than the
    * padding itself (bytes > padding_bytes), "padding_bytes - bytes"
    * would wrap around and the check would wrongly skip the clamp.
    */
   if (bytes <= state->padding_bytes && offset_uub <= state->padding_bytes - bytes)
      return true;

   b->cursor = nir_before_instr(&intrin->instr);

   /* Number of elements in the buffer (from the descriptor) */
   nir_def *num_records = nir_channel(b, desc, 2);

   /* Prevent the SMEM instruction from reading past the padding:
    * clamp the offset to the number of elements in the buffer.
    * This will still be OOB from the perspective of the application,
    * but in reality it will just read from the padding.
    */
   offs = nir_umin(b, num_records, offs);
   nir_src_rewrite(&intrin->src[1], offs);
   return true;
}
/* Fixup out of bounds behaviour of SMEM on GFX6-7.
 *
 * On GFX6-7, SMEM accesses memory even when the access would be out of bounds.
 * To mitigate the VM fault, we add an instruction to clamp the offset source to the
 * num_records field from the descriptor. The access will be still out of bounds, but
 * this way the VM fault can be completely mitigated if the driver adds a padding
 * to each memory allocation. The padding needs to be at least 1 page.
 */
static bool
fixup_smem_oob_access_gfx6(nir_builder *b, nir_intrinsic_instr *intrin, void *state_ptr)
{
   /* Only buffer load intrinsics are affected. */
   switch (intrin->intrinsic) {
   case nir_intrinsic_load_ssbo:
   case nir_intrinsic_load_ubo:
   case nir_intrinsic_load_buffer_amd:
      break;
   default:
      return false;
   }

   fixup_mem_access_state *state = (fixup_mem_access_state *)state_ptr;

   /* Only SMEM loads that use a full (vec4) buffer descriptor are handled. */
   if (!(nir_intrinsic_access(intrin) & ACCESS_SMEM_AMD) ||
       intrin->src[0].ssa->num_components != 4)
      return false;

   bool progress = false;

   if (state->fixup_smem_null_desc)
      progress |= fixup_smem_null_descriptor_gfx6(b, intrin, state);
   if (state->fixup_smem_oob)
      progress |= fixup_smem_robust_oob_gfx6(b, intrin, state);

   return progress;
}
/**
 * Fixup memory access issues on old GPUs.
 *
 * Walks all intrinsics in the shader and applies the requested GFX6-7
 * SMEM mitigations (NULL descriptor replacement and/or OOB offset clamp).
 * Returns whether any instruction was changed.
 */
bool
ac_nir_fixup_mem_access_gfx6(nir_shader *shader,
                             struct ac_shader_args *args,
                             const uint32_t padding_bytes,
                             const bool fixup_smem_null_desc,
                             const bool fixup_smem_oob)
{
   fixup_mem_access_state s = {
      .args = args,
      .range_ht = _mesa_pointer_hash_table_create(NULL),
      .padding_bytes = padding_bytes,
      .fixup_smem_null_desc = fixup_smem_null_desc,
      .fixup_smem_oob = fixup_smem_oob,
   };

   const bool progress =
      nir_shader_intrinsics_pass(shader, &fixup_smem_oob_access_gfx6, nir_metadata_all, &s);

   /* The range hash table is only needed while the pass runs. */
   _mesa_hash_table_destroy(s.range_ht, NULL);

   return progress;
}
static nir_mem_access_size_align
lower_mem_access_cb(nir_intrinsic_op intrin, uint8_t bytes, uint8_t bit_size, uint32_t align_mul, uint32_t align_offset,
bool offset_is_const, enum gl_access_qualifier access, const void *cb_data_)