aco,ac/nir: flag loads to use smem in NIR

This pass will be re-used later.

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Georg Lehmann <dadschoorse@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31904>
Rhys Perry 2024-10-22 14:37:19 +01:00 committed by Marge Bot
parent 7fe4f4c14c
commit d3ae1842a2
7 changed files with 73 additions and 21 deletions


@@ -1706,3 +1706,53 @@ ac_nir_varying_estimate_instr_cost(nir_instr *instr)
       unreachable("unexpected instr type");
    }
 }
+
+typedef struct {
+   enum amd_gfx_level gfx_level;
+   bool use_llvm;
+   bool after_lowering;
+} mem_access_cb_data;
+
+static bool
+use_smem_for_load(nir_builder *b, nir_intrinsic_instr *intrin, void *cb_data_)
+{
+   const mem_access_cb_data *cb_data = (mem_access_cb_data *)cb_data_;
+
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_load_ssbo:
+   case nir_intrinsic_load_global:
+   case nir_intrinsic_load_global_constant:
+   case nir_intrinsic_load_global_amd:
+   case nir_intrinsic_load_constant:
+      if (cb_data->use_llvm)
+         return false;
+      break;
+   case nir_intrinsic_load_ubo:
+      break;
+   default:
+      return false;
+   }
+
+   if (intrin->def.divergent || (cb_data->after_lowering && intrin->def.bit_size < 32))
+      return false;
+
+   enum gl_access_qualifier access = nir_intrinsic_access(intrin);
+   bool glc = access & (ACCESS_VOLATILE | ACCESS_COHERENT);
+   bool reorder = nir_intrinsic_can_reorder(intrin) || ((access & ACCESS_NON_WRITEABLE) && !(access & ACCESS_VOLATILE));
+   if (!reorder || (glc && cb_data->gfx_level < GFX8))
+      return false;
+
+   nir_intrinsic_set_access(intrin, access | ACCESS_SMEM_AMD);
+   return true;
+}
+
+bool
+ac_nir_flag_smem_for_loads(nir_shader *shader, enum amd_gfx_level gfx_level, bool use_llvm, bool after_lowering)
+{
+   mem_access_cb_data cb_data = {
+      .gfx_level = gfx_level,
+      .use_llvm = use_llvm,
+      .after_lowering = after_lowering,
+   };
+   return nir_shader_intrinsics_pass(shader, &use_smem_for_load, nir_metadata_all, &cb_data);
+}
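(Not part of the diff.) A minimal sketch of how a backend might drive this pass, assuming divergence information has already been computed, since use_smem_for_load() reads intrin->def.divergent. The wrapper below is hypothetical; the only call site added by this commit is the one in ACO's init_context() shown further down.

static bool
flag_smem_loads_sketch(nir_shader *nir, enum amd_gfx_level gfx_level, bool use_llvm)
{
   /* The pass checks intrin->def.divergent, so divergence analysis must be up to date. */
   nir_divergence_analysis(nir);

   /* use_llvm=true limits flagging to UBO loads; after_lowering=true (the last
    * argument) additionally skips loads with a destination narrower than 32 bits. */
   return ac_nir_flag_smem_for_loads(nir, gfx_level, use_llvm, true);
}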


@@ -323,6 +323,9 @@ ac_nir_varying_estimate_instr_cost(nir_instr *instr);
 bool
 ac_nir_opt_shared_append(nir_shader *shader);
 
+bool
+ac_nir_flag_smem_for_loads(nir_shader *shader, enum amd_gfx_level gfx_level, bool use_llvm, bool after_lowering);
+
 #ifdef __cplusplus
 }
 #endif


@@ -6020,13 +6020,10 @@ load_buffer(isel_context* ctx, unsigned num_components, unsigned component_size,
 {
    Builder bld(ctx->program, ctx->block);
 
-   bool glc = access & (ACCESS_VOLATILE | ACCESS_COHERENT);
-
-   bool use_smem = dst.type() != RegType::vgpr && (ctx->options->gfx_level >= GFX8 || !glc) &&
-                   (access & ACCESS_CAN_REORDER);
-   if (use_smem)
+   bool use_smem = access & ACCESS_SMEM_AMD;
+   if (use_smem) {
       offset = bld.as_uniform(offset);
-   else {
+   } else {
       /* GFX6-7 are affected by a hw bug that prevents address clamping to
        * work correctly when the SGPR offset is used.
        */
@@ -6054,7 +6051,8 @@ visit_load_ubo(isel_context* ctx, nir_intrinsic_instr* instr)
    unsigned size = instr->def.bit_size / 8;
    load_buffer(ctx, instr->num_components, size, dst, rsrc, get_ssa_temp(ctx, instr->src[1].ssa),
-               nir_intrinsic_align_mul(instr), nir_intrinsic_align_offset(instr));
+               nir_intrinsic_align_mul(instr), nir_intrinsic_align_offset(instr),
+               nir_intrinsic_access(instr) | ACCESS_CAN_REORDER);
 }
 
 void
@@ -6084,7 +6082,7 @@ visit_load_constant(isel_context* ctx, nir_intrinsic_instr* instr)
                      Operand::c32(desc[3]));
    unsigned size = instr->def.bit_size / 8;
    load_buffer(ctx, instr->num_components, size, dst, rsrc, offset, nir_intrinsic_align_mul(instr),
-               nir_intrinsic_align_offset(instr));
+               nir_intrinsic_align_offset(instr), nir_intrinsic_access(instr) | ACCESS_CAN_REORDER);
 }
 
 /* Packs multiple Temps of different sizes in to a vector of v1 Temps.
@@ -6921,23 +6919,17 @@ visit_load_global(isel_context* ctx, nir_intrinsic_instr* instr)
                               num_components, component_size, align, false);
 
    unsigned access = nir_intrinsic_access(instr) | ACCESS_TYPE_LOAD;
-   bool glc = access & (ACCESS_VOLATILE | ACCESS_COHERENT);
-
-   /* VMEM stores don't update the SMEM cache and it's difficult to prove that
-    * it's safe to use SMEM */
-   bool can_use_smem = (access & ACCESS_NON_WRITEABLE) && byte_align_for_smem;
-   if (info.dst.type() == RegType::vgpr || (ctx->options->gfx_level < GFX8 && glc) ||
-       !can_use_smem) {
-      EmitLoadParameters params = global_load_params;
-      params.byte_align_loads = byte_align_for_vmem;
-      info.cache = get_cache_flags(ctx, access);
-      emit_load(ctx, bld, info, params);
-   } else {
+   if ((access & ACCESS_SMEM_AMD) && byte_align_for_smem) {
       if (info.resource.id())
          info.resource = bld.as_uniform(info.resource);
       info.offset = Operand(bld.as_uniform(info.offset));
       info.cache = get_cache_flags(ctx, access | ACCESS_TYPE_SMEM);
       emit_load(ctx, bld, info, smem_load_params);
+   } else {
+      EmitLoadParameters params = global_load_params;
+      params.byte_align_loads = byte_align_for_vmem;
+      info.cache = get_cache_flags(ctx, access);
+      emit_load(ctx, bld, info, params);
    }
 }


@@ -376,6 +376,7 @@ init_context(isel_context* ctx, nir_shader* shader)
    }
 
    apply_nuw_to_offsets(ctx, impl);
+   ac_nir_flag_smem_for_loads(shader, ctx->program->gfx_level, false, true);
 
    /* sanitize control flow */
    sanitize_cf_list(impl, &impl->body);


@@ -1186,7 +1186,7 @@ load("task_payload", [1], [BASE, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])
 # src[] = { offset }.
 load("push_constant", [1], [BASE, RANGE, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE, CAN_REORDER])
 # src[] = { offset }.
-load("constant", [1], [BASE, RANGE, ALIGN_MUL, ALIGN_OFFSET],
+load("constant", [1], [BASE, RANGE, ACCESS, ALIGN_MUL, ALIGN_OFFSET],
      [CAN_ELIMINATE, CAN_REORDER])
 # src[] = { address }.
 load("global", [1], [ACCESS, ALIGN_MUL, ALIGN_OFFSET], [CAN_ELIMINATE])


@@ -810,6 +810,7 @@ print_access(enum gl_access_qualifier access, print_state *state, const char *se
       { ACCESS_CP_GE_COHERENT_AMD, "cp-ge-coherent-amd" },
       { ACCESS_IN_BOUNDS_AGX, "in-bounds-agx" },
       { ACCESS_KEEP_SCALAR, "keep-scalar" },
+      { ACCESS_SMEM_AMD, "smem-amd" },
    };
 
    bool first = true;


@@ -1159,6 +1159,11 @@ enum gl_access_qualifier
     * shader where the API wants to copy all bytes that are resident.
     */
    ACCESS_KEEP_SCALAR = (1 << 15),
+
+   /**
+    * Indicates that this load will use SMEM.
+    */
+   ACCESS_SMEM_AMD = (1 << 16),
 };
 
 /**
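(Not part of the diff.) ACCESS_SMEM_AMD is an ordinary gl_access_qualifier bit, so a backend can test it through the generated intrinsic accessors. A minimal illustrative helper, not taken from this commit:

static inline bool
load_is_flagged_for_smem(const nir_intrinsic_instr *intrin)
{
   /* Loads flagged by ac_nir_flag_smem_for_loads() carry ACCESS_SMEM_AMD in
    * their ACCESS index alongside the existing qualifier bits. */
   return nir_intrinsic_access(intrin) & ACCESS_SMEM_AMD;
}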