ac/llvm: let ring_offsets be accessed like a normal arg

Signed-off-by: Rhys Perry <pendingchaos02@gmail.com>
Reviewed-by: Samuel Pitoiset <samuel.pitoiset@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/19202>
This commit is contained in:
Rhys Perry 2022-10-20 13:17:11 +01:00 committed by Marge Bot
parent 24618721d3
commit be6f30a0db
3 changed files with 39 additions and 7 deletions

View file

@ -127,6 +127,8 @@ void ac_llvm_context_init(struct ac_llvm_context *ctx, struct ac_llvm_compiler *
ctx->empty_md = LLVMMDNodeInContext(ctx->context, NULL, 0);
ctx->flow = calloc(1, sizeof(*ctx->flow));
ctx->ring_offsets_index = INT32_MAX;
}
void ac_llvm_context_dispose(struct ac_llvm_context *ctx)
@ -229,6 +231,7 @@ LLVMTypeRef ac_to_integer_type(struct ac_llvm_context *ctx, LLVMTypeRef t)
if (LLVMGetTypeKind(t) == LLVMPointerTypeKind) {
switch (LLVMGetPointerAddressSpace(t)) {
case AC_ADDR_SPACE_GLOBAL:
case AC_ADDR_SPACE_CONST:
return ctx->i64;
case AC_ADDR_SPACE_CONST_32BIT:
case AC_ADDR_SPACE_LDS:
@ -4583,12 +4586,22 @@ struct ac_llvm_pointer ac_build_main(const struct ac_shader_args *args, struct a
LLVMTypeRef ret_type, LLVMModuleRef module)
{
LLVMTypeRef arg_types[AC_MAX_ARGS];
enum ac_arg_regfile arg_regfiles[AC_MAX_ARGS];
/* ring_offsets doesn't have a corresponding function parameter because LLVM can allocate it
* itself for scratch memory purposes and gives us access through llvm.amdgcn.implicit.buffer.ptr
*/
unsigned arg_count = 0;
for (unsigned i = 0; i < args->arg_count; i++) {
arg_types[i] = arg_llvm_type(args->args[i].type, args->args[i].size, ctx);
if (args->ring_offsets.used && i == args->ring_offsets.arg_index) {
ctx->ring_offsets_index = i;
continue;
}
arg_regfiles[arg_count] = args->args[i].file;
arg_types[arg_count++] = arg_llvm_type(args->args[i].type, args->args[i].size, ctx);
}
LLVMTypeRef main_function_type = LLVMFunctionType(ret_type, arg_types, args->arg_count, 0);
LLVMTypeRef main_function_type = LLVMFunctionType(ret_type, arg_types, arg_count, 0);
LLVMValueRef main_function = LLVMAddFunction(module, name, main_function_type);
LLVMBasicBlockRef main_function_body =
@ -4596,10 +4609,10 @@ struct ac_llvm_pointer ac_build_main(const struct ac_shader_args *args, struct a
LLVMPositionBuilderAtEnd(ctx->builder, main_function_body);
LLVMSetFunctionCallConv(main_function, convention);
for (unsigned i = 0; i < args->arg_count; ++i) {
for (unsigned i = 0; i < arg_count; ++i) {
LLVMValueRef P = LLVMGetParam(main_function, i);
if (args->args[i].file != AC_ARG_SGPR)
if (arg_regfiles[i] != AC_ARG_SGPR)
continue;
ac_add_function_attr(ctx->context, main_function, i + 1, "inreg");
@ -4611,6 +4624,14 @@ struct ac_llvm_pointer ac_build_main(const struct ac_shader_args *args, struct a
}
}
if (args->ring_offsets.used) {
ctx->ring_offsets =
ac_build_intrinsic(ctx, "llvm.amdgcn.implicit.buffer.ptr",
LLVMPointerType(ctx->i8, AC_ADDR_SPACE_CONST), NULL, 0, 0);
ctx->ring_offsets = LLVMBuildBitCast(ctx->builder, ctx->ring_offsets,
ac_array_in_const_addr_space(ctx->v4i32), "");
}
ctx->main_function = (struct ac_llvm_pointer) {
.value = main_function,
.pointee_type = main_function_type

View file

@ -160,6 +160,9 @@ struct ac_llvm_context {
bool exports_mrtz;
struct ac_llvm_pointer lds;
LLVMValueRef ring_offsets;
int ring_offsets_index;
};
void ac_llvm_context_init(struct ac_llvm_context *ctx, struct ac_llvm_compiler *compiler,
@ -607,7 +610,10 @@ LLVMTypeRef ac_arg_type_to_pointee_type(struct ac_llvm_context *ctx, enum ac_arg
/* Return the LLVM value for shader argument `arg`.
 *
 * ac_build_main leaves the ring_offsets argument out of the LLVM function's
 * parameter list and stores its index in ctx->ring_offsets_index (which
 * ac_llvm_context_init sets to INT32_MAX). Shader-arg indices therefore have
 * to be remapped before they are passed to LLVMGetParam.
 *
 * `arg.used` must be true; this is only checked with assert().
 */
static inline LLVMValueRef ac_get_arg(struct ac_llvm_context *ctx, struct ac_arg arg)
{
assert(arg.used);
/* NOTE(review): this listing is a diff with its +/- markers stripped. The
 * return directly below is the pre-change line that the commit removes; the
 * statements after it are its replacement. Read literally as C, everything
 * after this return would be unreachable.
 */
return LLVMGetParam(ctx->main_function.value, arg.arg_index);
/* ring_offsets is not an LLVM parameter: hand back the value that
 * ac_build_main stored in ctx->ring_offsets instead.
 */
if (arg.arg_index == ctx->ring_offsets_index)
return ctx->ring_offsets;
/* Arguments after the omitted ring_offsets slot sit one position earlier in
 * the LLVM parameter list. With ring_offsets_index == INT32_MAX no index
 * compares greater, so the offset stays 0.
 */
int offset = arg.arg_index > ctx->ring_offsets_index ? -1 : 0;
return LLVMGetParam(ctx->main_function.value, arg.arg_index + offset);
}
static inline struct ac_llvm_pointer
@ -615,7 +621,7 @@ ac_get_ptr_arg(struct ac_llvm_context *ctx, const struct ac_shader_args *args, s
{
struct ac_llvm_pointer ptr;
ptr.pointee_type = ac_arg_type_to_pointee_type(ctx, args->args[arg.arg_index].type);
ptr.value = LLVMGetParam(ctx->main_function.value, arg.arg_index);
ptr.value = ac_get_arg(ctx, arg);
return ptr;
}

View file

@ -4182,7 +4182,12 @@ static bool visit_intrinsic(struct ac_nir_context *ctx, nir_intrinsic_instr *ins
case nir_intrinsic_load_scalar_arg_amd:
case nir_intrinsic_load_vector_arg_amd: {
assert(nir_intrinsic_base(instr) < AC_MAX_ARGS);
result = ac_to_integer(&ctx->ac, LLVMGetParam(ctx->main_function, nir_intrinsic_base(instr)));
struct ac_arg arg;
arg.arg_index = nir_intrinsic_base(instr);
arg.used = true;
result = ac_to_integer(&ctx->ac, ac_get_arg(&ctx->ac, arg));
if (ac_get_elem_bits(&ctx->ac, LLVMTypeOf(result)) != 32)
result = LLVMBuildBitCast(ctx->ac.builder, result, get_def_type(ctx, &instr->dest.ssa), "");
break;
}
case nir_intrinsic_load_smem_amd: {