llvmpipe/cs: add support for function calls.

This adds (disabled) support for function calls to the compute shader.

Reviewed-by: Mike Blumenkrantz <michael.blumenkrantz@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24687>
This commit is contained in:
Dave Airlie 2023-08-15 15:55:21 +10:00 committed by Marge Bot
parent 3704f158a2
commit 4ecd471ee3

View file

@ -442,6 +442,111 @@ generate_compute(struct llvmpipe_context *lp,
lp_build_name(thread_data_ptr, "thread_data");
lp_build_name(io_ptr, "vertex_io");
lp_build_nir_prepasses(nir);
struct hash_table *fns = _mesa_pointer_hash_table_create(NULL);
if (exec_list_length(&nir->functions) > 1) {
LLVMTypeRef call_context_type = lp_build_cs_func_call_context(gallivm, cs_type.length,
variant->jit_cs_context_type,
variant->jit_resources_type);
nir_foreach_function(func, nir) {
if (func->is_entrypoint)
continue;
LLVMTypeRef args[32];
int num_args;
num_args = func->num_params + LP_RESV_FUNC_ARGS;
args[0] = LLVMVectorType(LLVMInt32TypeInContext(gallivm->context), cs_type.length); /* mask */
args[1] = LLVMPointerType(call_context_type, 0);
for (int i = 0; i < func->num_params; i++) {
args[i + LP_RESV_FUNC_ARGS] = LLVMVectorType(LLVMIntTypeInContext(gallivm->context, func->params[i].bit_size), cs_type.length);
if (func->params[i].num_components > 1)
args[i + LP_RESV_FUNC_ARGS] = LLVMArrayType(args[i + LP_RESV_FUNC_ARGS], func->params[i].num_components);
}
LLVMTypeRef func_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context),
args, num_args, 0);
LLVMValueRef lfunc = LLVMAddFunction(gallivm->module, func->name, func_type);
LLVMSetFunctionCallConv(lfunc, LLVMCCallConv);
struct lp_build_fn *new_fn = ralloc(fns, struct lp_build_fn);
new_fn->fn_type = func_type;
new_fn->fn = lfunc;
_mesa_hash_table_insert(fns, func, new_fn);
}
nir_foreach_function(func, nir) {
if (func->is_entrypoint)
continue;
struct hash_entry *entry = _mesa_hash_table_search(fns, func);
assert(entry);
struct lp_build_fn *new_fn = entry->data;
LLVMValueRef lfunc = new_fn->fn;
block = LLVMAppendBasicBlockInContext(gallivm->context, lfunc, "entry");
builder = gallivm->builder;
LLVMPositionBuilderAtEnd(builder, block);
LLVMValueRef mask_param = LLVMGetParam(lfunc, 0);
LLVMValueRef call_context_ptr = LLVMGetParam(lfunc, 1);
LLVMValueRef call_context = LLVMBuildLoad2(builder, call_context_type, call_context_ptr, "");
struct lp_build_mask_context mask;
struct lp_bld_tgsi_system_values system_values;
memset(&system_values, 0, sizeof(system_values));
lp_build_mask_begin(&mask, gallivm, cs_type, mask_param);
lp_build_mask_check(&mask);
struct lp_build_tgsi_params params;
memset(&params, 0, sizeof(params));
params.type = cs_type;
params.mask = &mask;
params.fns = fns;
params.current_func = lfunc;
params.context_type = variant->jit_cs_context_type;
params.resources_type = variant->jit_resources_type;
params.call_context_ptr = call_context_ptr;
params.context_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_CONTEXT, "");
params.resources_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_RESOURCES, "");
params.shared_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_SHARED, "");
params.scratch_ptr = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_SCRATCH, "");
system_values.work_dim = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_WORK_DIM, "");
system_values.thread_id[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_THREAD_ID_0, "");
system_values.thread_id[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_THREAD_ID_1, "");
system_values.thread_id[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_THREAD_ID_2, "");
system_values.block_id[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_ID_0, "");
system_values.block_id[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_ID_1, "");
system_values.block_id[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_ID_2, "");
system_values.grid_size[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_GRID_SIZE_0, "");
system_values.grid_size[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_GRID_SIZE_1, "");
system_values.grid_size[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_GRID_SIZE_2, "");
system_values.block_size[0] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0, "");
system_values.block_size[1] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1, "");
system_values.block_size[2] = LLVMBuildExtractValue(builder, call_context, LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2, "");
params.system_values = &system_values;
params.consts_ptr = lp_jit_resources_constants(gallivm,
variant->jit_resources_type,
params.resources_ptr);
params.ssbo_ptr = lp_jit_resources_ssbos(gallivm,
variant->jit_resources_type,
params.resources_ptr);
lp_build_nir_soa_func(gallivm, shader->base.ir.nir,
func->impl,
&params,
NULL);
lp_build_mask_end(&mask);
LLVMBuildRetVoid(builder);
gallivm_verify_function(gallivm, lfunc);
}
}
block = LLVMAppendBasicBlockInContext(gallivm->context, function, "entry");
builder = gallivm->builder;
assert(builder);
@ -750,8 +855,11 @@ generate_compute(struct llvmpipe_context *lp,
resources_ptr);
params.mesh_iface = &mesh_iface.base;
lp_build_nir_soa(gallivm, shader->base.ir.nir, &params,
NULL);
params.current_func = NULL;
params.fns = fns;
lp_build_nir_soa_func(gallivm, nir,
nir_shader_get_entrypoint(nir),
&params, NULL);
if (is_mesh) {
LLVMTypeRef i32t = LLVMInt32TypeInContext(gallivm->context);
@ -833,6 +941,7 @@ generate_compute(struct llvmpipe_context *lp,
lp_bld_llvm_sampler_soa_destroy(sampler);
lp_bld_llvm_image_soa_destroy(image);
_mesa_hash_table_destroy(fns, NULL);
gallivm_verify_function(gallivm, coro);
gallivm_verify_function(gallivm, function);