ac/llvm: port load_smem_amd behavior to load_global_amd

Reviewed-by: Daniel Schürmann <daniel@schuermann.dev>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37101>
This commit is contained in:
Marek Olšák 2025-08-08 16:50:46 -04:00
parent f5ba2d3e8f
commit 9a33c03654

View file

@ -1890,9 +1890,24 @@ static LLVMValueRef get_global_address(struct ac_nir_context *ctx,
nir_intrinsic_instr *instr)
{
bool is_store = instr->intrinsic == nir_intrinsic_store_global_amd;
LLVMValueRef addr = get_src(ctx, instr->src[is_store ? 1 : 0]);
nir_src addr_src = instr->src[is_store ? 1 : 0];
LLVMValueRef addr = get_src(ctx, addr_src);
LLVMTypeRef ptr_type = LLVMPointerTypeInContext(ctx->ac.context, AC_ADDR_SPACE_GLOBAL);
unsigned address_space;
if (nir_intrinsic_access(instr) & ACCESS_SMEM_AMD) {
assert(!is_store);
if (addr_src.ssa->bit_size == 64)
address_space = AC_ADDR_SPACE_CONST;
else if (addr_src.ssa->bit_size == 32)
address_space = AC_ADDR_SPACE_CONST_32BIT;
else
UNREACHABLE("invalid global address bit size");
} else {
assert(addr_src.ssa->bit_size == 64);
address_space = AC_ADDR_SPACE_GLOBAL;
}
LLVMTypeRef ptr_type = LLVMPointerTypeInContext(ctx->ac.context, address_space);
uint32_t base = nir_intrinsic_base(instr);
unsigned num_src = nir_intrinsic_infos[instr->intrinsic].num_srcs;
@ -1900,7 +1915,15 @@ static LLVMValueRef get_global_address(struct ac_nir_context *ctx,
offset = LLVMBuildAdd(ctx->ac.builder, offset, LLVMConstInt(ctx->ac.i32, base, false), "");
addr = LLVMBuildIntToPtr(ctx->ac.builder, addr, ptr_type, "");
return LLVMBuildGEP2(ctx->ac.builder, ctx->ac.i8, addr, &offset, 1, "");
if (addr_src.ssa->bit_size == 32)
addr = LLVMBuildInBoundsGEP2(ctx->ac.builder, ctx->ac.i8, addr, &offset, 1, "");
else
addr = LLVMBuildGEP2(ctx->ac.builder, ctx->ac.i8, addr, &offset, 1, "");
if (nir_intrinsic_access(instr) & ACCESS_SMEM_AMD)
LLVMSetMetadata(addr, ctx->ac.uniform_md_kind, ctx->ac.empty_md);
return addr;
}
static LLVMValueRef get_memory_addr(struct ac_nir_context *ctx, nir_intrinsic_instr *intr)
@ -1929,12 +1952,17 @@ static void set_mem_op_alignment(LLVMValueRef instr, nir_intrinsic_instr *intr,
LLVMSetAlignment(instr, MIN2(align, pot_size));
}
static void set_coherent_volatile(LLVMValueRef instr, nir_intrinsic_instr *intr)
static void set_access_flags(struct ac_nir_context *ctx, LLVMValueRef instr,
nir_intrinsic_instr *intr)
{
if ((intr->intrinsic == nir_intrinsic_load_global_amd ||
intr->intrinsic == nir_intrinsic_store_global_amd) &&
nir_intrinsic_access(intr) & (ACCESS_COHERENT | ACCESS_VOLATILE))
LLVMSetOrdering(instr, LLVMAtomicOrderingMonotonic);
if (intr->intrinsic == nir_intrinsic_load_global_amd &&
nir_intrinsic_access(intr) & ACCESS_CAN_REORDER)
LLVMSetMetadata(instr, ctx->ac.invariant_load_md_kind, ctx->ac.empty_md);
}
static void visit_store(struct ac_nir_context *ctx, nir_intrinsic_instr *intr)
@ -1948,7 +1976,7 @@ static void visit_store(struct ac_nir_context *ctx, nir_intrinsic_instr *intr)
LLVMValueRef store = LLVMBuildStore(builder, data, ptr);
set_mem_op_alignment(store, intr, intr->src[0].ssa);
set_coherent_volatile(store, intr);
set_access_flags(ctx, store, intr);
}
static LLVMValueRef visit_load(struct ac_nir_context *ctx, nir_intrinsic_instr *intr)
@ -1958,7 +1986,7 @@ static LLVMValueRef visit_load(struct ac_nir_context *ctx, nir_intrinsic_instr *
LLVMValueRef value = LLVMBuildLoad2(ctx->ac.builder, result_type, ptr, "");
set_mem_op_alignment(value, intr, &intr->def);
set_coherent_volatile(value, intr);
set_access_flags(ctx, value, intr);
return value;
}