mirror of
https://gitlab.freedesktop.org/mesa/mesa.git
synced 2026-05-08 04:48:08 +02:00
gallivm: add support for function calling
This adds support for calling functions in compute shaders. Functions are passed two implicit arguments: (1) the current exec mask, and (2) a context containing all the info needed for intrinsics to work when not at the top level. Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com> Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24687>
This commit is contained in:
parent
14a6668964
commit
3704f158a2
4 changed files with 209 additions and 2 deletions
|
|
@ -2100,6 +2100,20 @@ visit_payload_atomic(struct lp_build_nir_context *bld_base,
|
|||
offset, val, val2, &result[0]);
|
||||
}
|
||||
|
||||
static void visit_load_param(struct lp_build_nir_context *bld_base,
|
||||
nir_intrinsic_instr *instr,
|
||||
LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
|
||||
{
|
||||
LLVMValueRef param = LLVMGetParam(bld_base->func, nir_intrinsic_param_idx(instr) + LP_RESV_FUNC_ARGS);
|
||||
struct gallivm_state *gallivm = bld_base->base.gallivm;
|
||||
if (instr->num_components == 1)
|
||||
result[0] = param;
|
||||
else {
|
||||
for (unsigned i = 0; i < instr->num_components; i++)
|
||||
result[i] = LLVMBuildExtractValue(gallivm->builder, param, i, "");
|
||||
}
|
||||
}
|
||||
|
||||
static void
|
||||
visit_intrinsic(struct lp_build_nir_context *bld_base,
|
||||
nir_intrinsic_instr *instr)
|
||||
|
|
@ -2305,6 +2319,9 @@ visit_intrinsic(struct lp_build_nir_context *bld_base,
|
|||
get_src(bld_base, instr->src[0]),
|
||||
get_src(bld_base, instr->src[1]));
|
||||
break;
|
||||
case nir_intrinsic_load_param:
|
||||
visit_load_param(bld_base, instr, result);
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "Unsupported intrinsic: ");
|
||||
nir_print_instr(&instr->instr, stderr);
|
||||
|
|
@ -2729,6 +2746,29 @@ visit_deref(struct lp_build_nir_context *bld_base,
|
|||
assign_ssa(bld_base, instr->def.index, result);
|
||||
}
|
||||
|
||||
static void
|
||||
visit_call(struct lp_build_nir_context *bld_base,
|
||||
nir_call_instr *instr)
|
||||
{
|
||||
LLVMValueRef *args;
|
||||
struct hash_entry *entry = _mesa_hash_table_search(bld_base->fns, instr->callee);
|
||||
struct lp_build_fn *fn = entry->data;
|
||||
args = calloc(instr->num_params + LP_RESV_FUNC_ARGS, sizeof(LLVMValueRef));
|
||||
|
||||
assert(args);
|
||||
|
||||
args[0] = 0;
|
||||
for (unsigned i = 0; i < instr->num_params; i++) {
|
||||
LLVMValueRef arg = get_src(bld_base, instr->params[i]);
|
||||
|
||||
if (nir_src_bit_size(instr->params[i]) == 32 && LLVMTypeOf(arg) == bld_base->base.vec_type)
|
||||
arg = cast_type(bld_base, arg, nir_type_int, 32);
|
||||
args[i + LP_RESV_FUNC_ARGS] = arg;
|
||||
}
|
||||
|
||||
bld_base->call(bld_base, fn, instr->num_params + LP_RESV_FUNC_ARGS, args);
|
||||
free(args);
|
||||
}
|
||||
|
||||
static void
|
||||
visit_block(struct lp_build_nir_context *bld_base, nir_block *block)
|
||||
|
|
@ -2760,6 +2800,9 @@ visit_block(struct lp_build_nir_context *bld_base, nir_block *block)
|
|||
case nir_instr_type_deref:
|
||||
visit_deref(bld_base, nir_instr_as_deref(instr));
|
||||
break;
|
||||
case nir_instr_type_call:
|
||||
visit_call(bld_base, nir_instr_as_call(instr));
|
||||
break;
|
||||
default:
|
||||
fprintf(stderr, "Unknown NIR instr type: ");
|
||||
nir_print_instr(instr, stderr);
|
||||
|
|
|
|||
|
|
@ -35,6 +35,12 @@
|
|||
|
||||
struct nir_shader;
|
||||
|
||||
/*
 * Two function arguments are reserved at the start of every function call:
 * the exec mask and the call context.
 */
|
||||
#define LP_RESV_FUNC_ARGS 2
|
||||
|
||||
void lp_build_nir_soa(struct gallivm_state *gallivm,
|
||||
struct nir_shader *shader,
|
||||
const struct lp_build_tgsi_params *params,
|
||||
|
|
@ -55,6 +61,11 @@ void lp_build_nir_aos(struct gallivm_state *gallivm,
|
|||
LLVMValueRef *outputs,
|
||||
const struct lp_build_sampler_aos *sampler);
|
||||
|
||||
struct lp_build_fn {
|
||||
LLVMTypeRef fn_type;
|
||||
LLVMValueRef fn;
|
||||
};
|
||||
|
||||
struct lp_build_nir_context
|
||||
{
|
||||
struct lp_build_context base;
|
||||
|
|
@ -72,12 +83,14 @@ struct lp_build_nir_context
|
|||
LLVMValueRef *ssa_defs;
|
||||
struct hash_table *regs;
|
||||
struct hash_table *vars;
|
||||
struct hash_table *fns;
|
||||
|
||||
/** Value range analysis hash table used in code generation. */
|
||||
struct hash_table *range_ht;
|
||||
|
||||
LLVMValueRef aniso_filter_table;
|
||||
|
||||
LLVMValueRef func;
|
||||
nir_shader *shader;
|
||||
|
||||
void (*load_ubo)(struct lp_build_nir_context *bld_base,
|
||||
|
|
@ -243,6 +256,11 @@ struct lp_build_nir_context
|
|||
LLVMValueRef prim_count);
|
||||
void (*launch_mesh_workgroups)(struct lp_build_nir_context *bld_base,
|
||||
LLVMValueRef launch_grid);
|
||||
|
||||
void (*call)(struct lp_build_nir_context *bld_base,
|
||||
struct lp_build_fn *fn,
|
||||
int num_args,
|
||||
LLVMValueRef *args);
|
||||
// LLVMValueRef main_function
|
||||
};
|
||||
|
||||
|
|
@ -299,6 +317,9 @@ struct lp_build_nir_soa_context
|
|||
|
||||
LLVMValueRef kernel_args_ptr;
|
||||
unsigned gs_vertex_streams;
|
||||
|
||||
LLVMTypeRef call_context_type;
|
||||
LLVMValueRef call_context_ptr;
|
||||
};
|
||||
|
||||
void
|
||||
|
|
@ -389,5 +410,31 @@ lp_build_nir_sample_key(gl_shader_stage stage, nir_tex_instr *instr);
|
|||
|
||||
void lp_img_op_from_intrinsic(struct lp_img_params *params, nir_intrinsic_instr *instr);
|
||||
|
||||
/**
 * Field indices of the call context struct passed to called functions.
 *
 * The enumerators are consecutive starting at zero, and
 * LP_NIR_CALL_CONTEXT_MAX_ARGS is the number of fields.
 */
enum lp_nir_call_context_args {
   /* Pointers. */
   LP_NIR_CALL_CONTEXT_CONTEXT = 0,
   LP_NIR_CALL_CONTEXT_RESOURCES,
   LP_NIR_CALL_CONTEXT_SHARED,
   LP_NIR_CALL_CONTEXT_SCRATCH,

   /* System values. */
   LP_NIR_CALL_CONTEXT_WORK_DIM,

   LP_NIR_CALL_CONTEXT_THREAD_ID_0,
   LP_NIR_CALL_CONTEXT_THREAD_ID_1,
   LP_NIR_CALL_CONTEXT_THREAD_ID_2,

   LP_NIR_CALL_CONTEXT_BLOCK_ID_0,
   LP_NIR_CALL_CONTEXT_BLOCK_ID_1,
   LP_NIR_CALL_CONTEXT_BLOCK_ID_2,

   LP_NIR_CALL_CONTEXT_GRID_SIZE_0,
   LP_NIR_CALL_CONTEXT_GRID_SIZE_1,
   LP_NIR_CALL_CONTEXT_GRID_SIZE_2,

   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0,
   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1,
   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2,

   /* Field count; keep last. */
   LP_NIR_CALL_CONTEXT_MAX_ARGS,
};
|
||||
|
||||
LLVMTypeRef
|
||||
lp_build_cs_func_call_context(struct gallivm_state *gallivm, int length,
|
||||
LLVMTypeRef context_type, LLVMTypeRef resources_type);
|
||||
|
||||
|
||||
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -2694,6 +2694,19 @@ emit_launch_mesh_workgroups(struct lp_build_nir_context *bld_base,
|
|||
lp_build_endif(&ifthen);
|
||||
}
|
||||
|
||||
static void
|
||||
emit_call(struct lp_build_nir_context *bld_base,
|
||||
struct lp_build_fn *fn,
|
||||
int num_args,
|
||||
LLVMValueRef *args)
|
||||
{
|
||||
struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
|
||||
|
||||
args[0] = mask_vec(bld_base);
|
||||
args[1] = bld->call_context_ptr;
|
||||
LLVMBuildCall2(bld_base->base.gallivm->builder, fn->fn_type, fn->fn, args, num_args, "");
|
||||
}
|
||||
|
||||
static LLVMValueRef get_scratch_thread_offsets(struct gallivm_state *gallivm,
|
||||
struct lp_type type,
|
||||
unsigned scratch_size)
|
||||
|
|
@ -2800,6 +2813,90 @@ emit_clock(struct lp_build_nir_context *bld_base,
|
|||
dst[1] = lp_build_broadcast_scalar(uint_bld, hi);
|
||||
}
|
||||
|
||||
LLVMTypeRef
|
||||
lp_build_cs_func_call_context(struct gallivm_state *gallivm, int length,
|
||||
LLVMTypeRef context_type, LLVMTypeRef resources_type)
|
||||
{
|
||||
LLVMTypeRef args[LP_NIR_CALL_CONTEXT_MAX_ARGS];
|
||||
|
||||
args[LP_NIR_CALL_CONTEXT_CONTEXT] = LLVMPointerType(context_type, 0);
|
||||
args[LP_NIR_CALL_CONTEXT_RESOURCES] = LLVMPointerType(resources_type, 0);
|
||||
args[LP_NIR_CALL_CONTEXT_SHARED] = LLVMPointerType(LLVMInt32TypeInContext(gallivm->context), 0); /* shared_ptr */
|
||||
args[LP_NIR_CALL_CONTEXT_SCRATCH] = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); /* scratch ptr */
|
||||
args[LP_NIR_CALL_CONTEXT_WORK_DIM] = LLVMInt32TypeInContext(gallivm->context); /* work_dim */
|
||||
args[LP_NIR_CALL_CONTEXT_THREAD_ID_0] = LLVMVectorType(LLVMInt32TypeInContext(gallivm->context), length); /* system_values.thread_id[0] */
|
||||
args[LP_NIR_CALL_CONTEXT_THREAD_ID_1] = LLVMVectorType(LLVMInt32TypeInContext(gallivm->context), length); /* system_values.thread_id[1] */
|
||||
args[LP_NIR_CALL_CONTEXT_THREAD_ID_2] = LLVMVectorType(LLVMInt32TypeInContext(gallivm->context), length); /* system_values.thread_id[2] */
|
||||
args[LP_NIR_CALL_CONTEXT_BLOCK_ID_0] = LLVMInt32TypeInContext(gallivm->context); /* system_values.block_id[0] */
|
||||
args[LP_NIR_CALL_CONTEXT_BLOCK_ID_1] = LLVMInt32TypeInContext(gallivm->context); /* system_values.block_id[1] */
|
||||
args[LP_NIR_CALL_CONTEXT_BLOCK_ID_2] = LLVMInt32TypeInContext(gallivm->context); /* system_values.block_id[2] */
|
||||
|
||||
args[LP_NIR_CALL_CONTEXT_GRID_SIZE_0] = LLVMInt32TypeInContext(gallivm->context); /* system_values.grid_size[0] */
|
||||
args[LP_NIR_CALL_CONTEXT_GRID_SIZE_1] = LLVMInt32TypeInContext(gallivm->context); /* system_values.grid_size[1] */
|
||||
args[LP_NIR_CALL_CONTEXT_GRID_SIZE_2] = LLVMInt32TypeInContext(gallivm->context); /* system_values.grid_size[2] */
|
||||
args[LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0] = LLVMInt32TypeInContext(gallivm->context); /* system_values.block_size[0] */
|
||||
args[LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1] = LLVMInt32TypeInContext(gallivm->context); /* system_values.block_size[1] */
|
||||
args[LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2] = LLVMInt32TypeInContext(gallivm->context); /* system_values.block_size[2] */
|
||||
|
||||
LLVMTypeRef stype = LLVMStructTypeInContext(gallivm->context, args, LP_NIR_CALL_CONTEXT_MAX_ARGS, 0);
|
||||
return stype;
|
||||
}
|
||||
|
||||
static void
|
||||
build_call_context(struct lp_build_nir_soa_context *bld)
|
||||
{
|
||||
struct gallivm_state *gallivm = bld->bld_base.base.gallivm;
|
||||
bld->call_context_ptr = lp_build_alloca(gallivm, bld->call_context_type, "callcontext");
|
||||
LLVMValueRef call_context = LLVMGetUndef(bld->call_context_type);
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->context_ptr, LP_NIR_CALL_CONTEXT_CONTEXT, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->resources_ptr, LP_NIR_CALL_CONTEXT_RESOURCES, "");
|
||||
if (bld->shared_ptr) {
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->shared_ptr, LP_NIR_CALL_CONTEXT_SHARED, "");
|
||||
} else {
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
|
||||
LLVMConstNull(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0)),
|
||||
LP_NIR_CALL_CONTEXT_SHARED, "");
|
||||
}
|
||||
if (bld->scratch_ptr) {
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->scratch_ptr, LP_NIR_CALL_CONTEXT_SCRATCH, "");
|
||||
} else {
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder, call_context,
|
||||
LLVMConstNull(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0)),
|
||||
LP_NIR_CALL_CONTEXT_SCRATCH, "");
|
||||
}
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.work_dim, LP_NIR_CALL_CONTEXT_WORK_DIM, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.thread_id[0], LP_NIR_CALL_CONTEXT_THREAD_ID_0, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.thread_id[1], LP_NIR_CALL_CONTEXT_THREAD_ID_1, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.thread_id[2], LP_NIR_CALL_CONTEXT_THREAD_ID_2, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.block_id[0], LP_NIR_CALL_CONTEXT_BLOCK_ID_0, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.block_id[1], LP_NIR_CALL_CONTEXT_BLOCK_ID_1, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.block_id[2], LP_NIR_CALL_CONTEXT_BLOCK_ID_2, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.grid_size[0], LP_NIR_CALL_CONTEXT_GRID_SIZE_0, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.grid_size[1], LP_NIR_CALL_CONTEXT_GRID_SIZE_1, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.grid_size[2], LP_NIR_CALL_CONTEXT_GRID_SIZE_2, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.block_size[0], LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.block_size[1], LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1, "");
|
||||
call_context = LLVMBuildInsertValue(gallivm->builder,
|
||||
call_context, bld->system_values.block_size[2], LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2, "");
|
||||
LLVMBuildStore(gallivm->builder, call_context, bld->call_context_ptr);
|
||||
}
|
||||
|
||||
void lp_build_nir_soa_func(struct gallivm_state *gallivm,
|
||||
struct nir_shader *shader,
|
||||
nir_function_impl *impl,
|
||||
|
|
@ -2911,6 +3008,7 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
|
|||
bld.bld_base.read_invocation = emit_read_invocation;
|
||||
bld.bld_base.helper_invocation = emit_helper_invocation;
|
||||
bld.bld_base.interp_at = emit_interp_at;
|
||||
bld.bld_base.call = emit_call;
|
||||
bld.bld_base.load_scratch = emit_load_scratch;
|
||||
bld.bld_base.store_scratch = emit_store_scratch;
|
||||
bld.bld_base.load_const = emit_load_const;
|
||||
|
|
@ -2918,6 +3016,8 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
|
|||
bld.bld_base.set_vertex_and_primitive_count = emit_set_vertex_and_primitive_count;
|
||||
bld.bld_base.launch_mesh_workgroups = emit_launch_mesh_workgroups;
|
||||
|
||||
bld.bld_base.fns = params->fns;
|
||||
bld.bld_base.func = params->current_func;
|
||||
bld.mask = params->mask;
|
||||
bld.inputs = params->inputs;
|
||||
bld.outputs = outputs;
|
||||
|
|
@ -2925,6 +3025,8 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
|
|||
bld.ssbo_ptr = params->ssbo_ptr;
|
||||
bld.sampler = params->sampler;
|
||||
|
||||
bld.context_type = params->context_type;
|
||||
bld.context_ptr = params->context_ptr;
|
||||
bld.resources_type = params->resources_type;
|
||||
bld.resources_ptr = params->resources_ptr;
|
||||
bld.thread_data_type = params->thread_data_type;
|
||||
|
|
@ -2961,18 +3063,29 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
|
|||
}
|
||||
lp_exec_mask_init(&bld.exec_mask, &bld.bld_base.int_bld);
|
||||
|
||||
bld.system_values = *params->system_values;
|
||||
if (params->system_values)
|
||||
bld.system_values = *params->system_values;
|
||||
|
||||
bld.bld_base.shader = shader;
|
||||
|
||||
bld.scratch_size = ALIGN(shader->scratch_size, 8);
|
||||
if (shader->scratch_size) {
|
||||
if (params->scratch_ptr)
|
||||
bld.scratch_ptr = params->scratch_ptr;
|
||||
else if (shader->scratch_size) {
|
||||
bld.scratch_ptr = lp_build_array_alloca(gallivm,
|
||||
LLVMInt8TypeInContext(gallivm->context),
|
||||
lp_build_const_int32(gallivm, bld.scratch_size * type.length),
|
||||
"scratch");
|
||||
}
|
||||
|
||||
if (shader->info.stage == MESA_SHADER_KERNEL) {
|
||||
bld.call_context_type = lp_build_cs_func_call_context(gallivm, type.length, bld.context_type, bld.resources_type);
|
||||
if (!params->call_context_ptr) {
|
||||
build_call_context(&bld);
|
||||
} else
|
||||
bld.call_context_ptr = params->call_context_ptr;
|
||||
}
|
||||
|
||||
emit_prologue(&bld);
|
||||
lp_build_nir_llvm(&bld.bld_base, shader, impl);
|
||||
|
||||
|
|
|
|||
|
|
@ -289,6 +289,10 @@ struct lp_build_tgsi_params {
|
|||
const struct lp_build_fs_iface *fs_iface;
|
||||
unsigned gs_vertex_streams;
|
||||
LLVMValueRef aniso_filter_table;
|
||||
LLVMValueRef current_func;
|
||||
struct hash_table *fns;
|
||||
LLVMValueRef scratch_ptr;
|
||||
LLVMValueRef call_context_ptr;
|
||||
};
|
||||
|
||||
void
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue