diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.c b/src/gallium/auxiliary/gallivm/lp_bld_nir.c
index aab7c807a4e..ce49783da8b 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.c
@@ -2100,6 +2100,31 @@ visit_payload_atomic(struct lp_build_nir_context *bld_base,
                             offset, val, val2, &result[0]);
 }
 
+/**
+ * Fetch a NIR function parameter from the LLVM function being built.
+ *
+ * The first LP_RESV_FUNC_ARGS LLVM arguments are reserved, so the NIR
+ * parameter index is offset by that amount.  Multi-component parameters
+ * are unpacked one component at a time with extractvalue.
+ */
+static void
+visit_load_param(struct lp_build_nir_context *bld_base,
+                 nir_intrinsic_instr *instr,
+                 LLVMValueRef result[NIR_MAX_VEC_COMPONENTS])
+{
+   struct gallivm_state *gallivm = bld_base->base.gallivm;
+   LLVMValueRef param = LLVMGetParam(bld_base->func,
+                                     nir_intrinsic_param_idx(instr) + LP_RESV_FUNC_ARGS);
+
+   if (instr->num_components == 1) {
+      result[0] = param;
+      return;
+   }
+
+   for (unsigned i = 0; i < instr->num_components; i++)
+      result[i] = LLVMBuildExtractValue(gallivm->builder, param, i, "");
+}
+
 static void
 visit_intrinsic(struct lp_build_nir_context *bld_base,
                 nir_intrinsic_instr *instr)
@@ -2305,6 +2330,9 @@ visit_intrinsic(struct lp_build_nir_context *bld_base,
                                        get_src(bld_base, instr->src[0]),
                                        get_src(bld_base, instr->src[1]));
       break;
+   case nir_intrinsic_load_param:
+      visit_load_param(bld_base, instr, result);
+      break;
    default:
       fprintf(stderr, "Unsupported intrinsic: ");
       nir_print_instr(&instr->instr, stderr);
@@ -2729,6 +2757,51 @@ visit_deref(struct lp_build_nir_context *bld_base,
    assign_ssa(bld_base, instr->def.index, result);
 }
 
+/**
+ * Emit a call to another NIR function.
+ *
+ * The callee is looked up in bld_base->fns.  The argument array reserves
+ * the first LP_RESV_FUNC_ARGS slots; they are left NULL here and are
+ * expected to be filled in by the bld_base->call hook.
+ */
+static void
+visit_call(struct lp_build_nir_context *bld_base,
+           nir_call_instr *instr)
+{
+   const unsigned num_args = instr->num_params + LP_RESV_FUNC_ARGS;
+   struct hash_entry *entry = _mesa_hash_table_search(bld_base->fns, instr->callee);
+
+   /* Don't dereference a failed lookup; report it like other unsupported input. */
+   if (!entry) {
+      fprintf(stderr, "Call to unknown NIR function: ");
+      nir_print_instr(&instr->instr, stderr);
+      fprintf(stderr, "\n");
+      assert(0);
+      return;
+   }
+   struct lp_build_fn *fn = entry->data;
+
+   /* calloc zero-fills, so the reserved slots start out NULL. */
+   LLVMValueRef *args = calloc(num_args, sizeof(*args));
+   if (!args) {
+      fprintf(stderr, "Out of memory while building call arguments\n");
+      assert(0);
+      return;
+   }
+
+   for (unsigned i = 0; i < instr->num_params; i++) {
+      LLVMValueRef arg = get_src(bld_base, instr->params[i]);
+
+      /* 32-bit values held in the base vector type are passed as integers. */
+      if (nir_src_bit_size(instr->params[i]) == 32 &&
+          LLVMTypeOf(arg) == bld_base->base.vec_type)
+         arg = cast_type(bld_base, arg, nir_type_int, 32);
+      args[i + LP_RESV_FUNC_ARGS] = arg;
+   }
+
+   bld_base->call(bld_base, fn, num_args, args);
+   free(args);
+}
 
 static void
 visit_block(struct lp_build_nir_context *bld_base, nir_block *block)
@@ -2760,6 +2833,9 @@ visit_block(struct lp_build_nir_context *bld_base, nir_block *block)
       case nir_instr_type_deref:
          visit_deref(bld_base, nir_instr_as_deref(instr));
          break;
+      case nir_instr_type_call:
+         visit_call(bld_base, nir_instr_as_call(instr));
+         break;
       default:
          fprintf(stderr, "Unknown NIR instr type: ");
          nir_print_instr(instr, stderr);
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir.h b/src/gallium/auxiliary/gallivm/lp_bld_nir.h
index 14e77d72f01..3ef67ae8a64 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir.h
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir.h
@@ -35,6 +35,12 @@
 
 struct nir_shader;
 
+/*
+ * Every function call carries 2 reserved leading arguments:
+ * the exec mask and the call context pointer.
+ */
+#define LP_RESV_FUNC_ARGS 2
+
 void lp_build_nir_soa(struct gallivm_state *gallivm,
                       struct nir_shader *shader,
                       const struct lp_build_tgsi_params *params,
@@ -55,6 +61,12 @@ void lp_build_nir_aos(struct gallivm_state *gallivm,
                       LLVMValueRef *outputs,
                       const struct lp_build_sampler_aos *sampler);
 
+/** LLVM function value paired with the function type passed to LLVMBuildCall2. */
+struct lp_build_fn {
+   LLVMTypeRef fn_type;
+   LLVMValueRef fn;
+};
+
 struct lp_build_nir_context
 {
    struct lp_build_context base;
@@ -72,12 +84,16 @@ struct lp_build_nir_context
    LLVMValueRef *ssa_defs;
    struct hash_table *regs;
    struct hash_table *vars;
+   /** Callee table searched by call instructions; entries hold struct lp_build_fn. */
+   struct hash_table *fns;
 
    /** Value range analysis hash table used in code generation. */
    struct hash_table *range_ht;
 
    LLVMValueRef aniso_filter_table;
 
+   /** LLVM function being built; load_param reads its parameters. */
+   LLVMValueRef func;
    nir_shader *shader;
 
    void (*load_ubo)(struct lp_build_nir_context *bld_base,
@@ -243,6 +259,15 @@ struct lp_build_nir_context
                                          LLVMValueRef prim_count);
    void (*launch_mesh_workgroups)(struct lp_build_nir_context *bld_base,
                                   LLVMValueRef launch_grid);
+
+   /**
+    * Emit a call to fn.  args holds num_args entries, the first
+    * LP_RESV_FUNC_ARGS of which are reserved for the implementation.
+    */
+   void (*call)(struct lp_build_nir_context *bld_base,
+                struct lp_build_fn *fn,
+                int num_args,
+                LLVMValueRef *args);
 //   LLVMValueRef main_function
 };
 
@@ -299,6 +324,10 @@ struct lp_build_nir_soa_context
 
    LLVMValueRef kernel_args_ptr;
    unsigned gs_vertex_streams;
+
+   /* Struct type of the call context and the pointer passed to callees. */
+   LLVMTypeRef call_context_type;
+   LLVMValueRef call_context_ptr;
 };
 
 void
@@ -389,5 +418,40 @@ lp_build_nir_sample_key(gl_shader_stage stage, nir_tex_instr *instr);
 
 
 void lp_img_op_from_intrinsic(struct lp_img_params *params, nir_intrinsic_instr *instr);
 
+/**
+ * Field indices of the call context struct built by
+ * lp_build_cs_func_call_context().
+ */
+enum lp_nir_call_context_args {
+   LP_NIR_CALL_CONTEXT_CONTEXT,
+   LP_NIR_CALL_CONTEXT_RESOURCES,
+   LP_NIR_CALL_CONTEXT_SHARED,
+   LP_NIR_CALL_CONTEXT_SCRATCH,
+   LP_NIR_CALL_CONTEXT_WORK_DIM,
+   LP_NIR_CALL_CONTEXT_THREAD_ID_0,
+   LP_NIR_CALL_CONTEXT_THREAD_ID_1,
+   LP_NIR_CALL_CONTEXT_THREAD_ID_2,
+   LP_NIR_CALL_CONTEXT_BLOCK_ID_0,
+   LP_NIR_CALL_CONTEXT_BLOCK_ID_1,
+   LP_NIR_CALL_CONTEXT_BLOCK_ID_2,
+   LP_NIR_CALL_CONTEXT_GRID_SIZE_0,
+   LP_NIR_CALL_CONTEXT_GRID_SIZE_1,
+   LP_NIR_CALL_CONTEXT_GRID_SIZE_2,
+   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0,
+   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_1,
+   LP_NIR_CALL_CONTEXT_BLOCK_SIZE_2,
+   LP_NIR_CALL_CONTEXT_MAX_ARGS,
+};
+
+/**
+ * Build the LLVM struct type of the call context.
+ *
+ * length is the vector width used for the thread_id fields.
+ */
+LLVMTypeRef
+lp_build_cs_func_call_context(struct gallivm_state *gallivm, int length,
+                              LLVMTypeRef context_type, LLVMTypeRef resources_type);
+
+
 #endif
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
index baa533e334c..06dde5c5803 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
+++ b/src/gallium/auxiliary/gallivm/lp_bld_nir_soa.c
@@ -2694,6 +2694,26 @@ emit_launch_mesh_workgroups(struct lp_build_nir_context *bld_base,
    lp_build_endif(&ifthen);
 }
 
+/**
+ * SoA implementation of the call hook.
+ *
+ * Fills the two reserved leading argument slots with the exec mask
+ * (mask_vec) and the call context pointer, then emits the LLVM call.
+ */
+static void
+emit_call(struct lp_build_nir_context *bld_base,
+          struct lp_build_fn *fn,
+          int num_args,
+          LLVMValueRef *args)
+{
+   struct lp_build_nir_soa_context *bld = (struct lp_build_nir_soa_context *)bld_base;
+   LLVMBuilderRef builder = bld_base->base.gallivm->builder;
+
+   args[0] = mask_vec(bld_base);
+   args[1] = bld->call_context_ptr;
+   LLVMBuildCall2(builder, fn->fn_type, fn->fn, args, num_args, "");
+}
+
 static LLVMValueRef get_scratch_thread_offsets(struct gallivm_state *gallivm,
                                                struct lp_type type,
                                                unsigned scratch_size)
@@ -2800,6 +2820,93 @@ emit_clock(struct lp_build_nir_context *bld_base,
    dst[1] = lp_build_broadcast_scalar(uint_bld, hi);
 }
 
+/**
+ * Build the LLVM struct type describing the call context.
+ *
+ * Field order follows enum lp_nir_call_context_args.  length is the
+ * vector width of the thread_id fields; every other system value is a
+ * scalar i32.
+ */
+LLVMTypeRef
+lp_build_cs_func_call_context(struct gallivm_state *gallivm, int length,
+                              LLVMTypeRef context_type, LLVMTypeRef resources_type)
+{
+   LLVMTypeRef int8_type = LLVMInt8TypeInContext(gallivm->context);
+   LLVMTypeRef int32_type = LLVMInt32TypeInContext(gallivm->context);
+   LLVMTypeRef int32_vec_type = LLVMVectorType(int32_type, length);
+   LLVMTypeRef args[LP_NIR_CALL_CONTEXT_MAX_ARGS];
+
+   args[LP_NIR_CALL_CONTEXT_CONTEXT] = LLVMPointerType(context_type, 0);
+   args[LP_NIR_CALL_CONTEXT_RESOURCES] = LLVMPointerType(resources_type, 0);
+   args[LP_NIR_CALL_CONTEXT_SHARED] = LLVMPointerType(int32_type, 0);
+   args[LP_NIR_CALL_CONTEXT_SCRATCH] = LLVMPointerType(int8_type, 0);
+   args[LP_NIR_CALL_CONTEXT_WORK_DIM] = int32_type;
+
+   /* The _0/_1/_2 enum entries of each system value are consecutive. */
+   for (unsigned i = 0; i < 3; i++) {
+      args[LP_NIR_CALL_CONTEXT_THREAD_ID_0 + i] = int32_vec_type;
+      args[LP_NIR_CALL_CONTEXT_BLOCK_ID_0 + i] = int32_type;
+      args[LP_NIR_CALL_CONTEXT_GRID_SIZE_0 + i] = int32_type;
+      args[LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0 + i] = int32_type;
+   }
+
+   return LLVMStructTypeInContext(gallivm->context, args, LP_NIR_CALL_CONTEXT_MAX_ARGS, 0);
+}
+
+/**
+ * Allocate the call context and fill it from the builder state.
+ *
+ * The struct value is assembled with insertvalue and stored once into
+ * bld->call_context_ptr.  Missing shared/scratch pointers are stored as
+ * null pointers of the type declared for the corresponding field.
+ */
+static void
+build_call_context(struct lp_build_nir_soa_context *bld)
+{
+   struct gallivm_state *gallivm = bld->bld_base.base.gallivm;
+   LLVMBuilderRef builder = gallivm->builder;
+   LLVMTypeRef int8_ptr_type = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0);
+   LLVMTypeRef int32_ptr_type = LLVMPointerType(LLVMInt32TypeInContext(gallivm->context), 0);
+
+   /* The shared field is declared as i32*, so its null must be i32* as well. */
+   LLVMValueRef shared_ptr = bld->shared_ptr ? bld->shared_ptr : LLVMConstNull(int32_ptr_type);
+   LLVMValueRef scratch_ptr = bld->scratch_ptr ? bld->scratch_ptr : LLVMConstNull(int8_ptr_type);
+
+   bld->call_context_ptr = lp_build_alloca(gallivm, bld->call_context_type, "callcontext");
+
+   LLVMValueRef call_context = LLVMGetUndef(bld->call_context_type);
+   call_context = LLVMBuildInsertValue(builder, call_context, bld->context_ptr,
+                                       LP_NIR_CALL_CONTEXT_CONTEXT, "");
+   call_context = LLVMBuildInsertValue(builder, call_context, bld->resources_ptr,
+                                       LP_NIR_CALL_CONTEXT_RESOURCES, "");
+   call_context = LLVMBuildInsertValue(builder, call_context, shared_ptr,
+                                       LP_NIR_CALL_CONTEXT_SHARED, "");
+   call_context = LLVMBuildInsertValue(builder, call_context, scratch_ptr,
+                                       LP_NIR_CALL_CONTEXT_SCRATCH, "");
+   call_context = LLVMBuildInsertValue(builder, call_context, bld->system_values.work_dim,
+                                       LP_NIR_CALL_CONTEXT_WORK_DIM, "");
+
+   /* The _0/_1/_2 enum entries of each system value are consecutive. */
+   for (unsigned i = 0; i < 3; i++) {
+      call_context = LLVMBuildInsertValue(builder, call_context, bld->system_values.thread_id[i],
+                                          LP_NIR_CALL_CONTEXT_THREAD_ID_0 + i, "");
+   }
+   for (unsigned i = 0; i < 3; i++) {
+      call_context = LLVMBuildInsertValue(builder, call_context, bld->system_values.block_id[i],
+                                          LP_NIR_CALL_CONTEXT_BLOCK_ID_0 + i, "");
+   }
+   for (unsigned i = 0; i < 3; i++) {
+      call_context = LLVMBuildInsertValue(builder, call_context, bld->system_values.grid_size[i],
+                                          LP_NIR_CALL_CONTEXT_GRID_SIZE_0 + i, "");
+   }
+   for (unsigned i = 0; i < 3; i++) {
+      call_context = LLVMBuildInsertValue(builder, call_context, bld->system_values.block_size[i],
+                                          LP_NIR_CALL_CONTEXT_BLOCK_SIZE_0 + i, "");
+   }
+
+   LLVMBuildStore(builder, call_context, bld->call_context_ptr);
+}
+
 void lp_build_nir_soa_func(struct gallivm_state *gallivm,
                            struct nir_shader *shader,
                            nir_function_impl *impl,
@@ -2911,6 +3018,7 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    bld.bld_base.read_invocation = emit_read_invocation;
    bld.bld_base.helper_invocation = emit_helper_invocation;
    bld.bld_base.interp_at = emit_interp_at;
+   bld.bld_base.call = emit_call;
    bld.bld_base.load_scratch = emit_load_scratch;
    bld.bld_base.store_scratch = emit_store_scratch;
    bld.bld_base.load_const = emit_load_const;
@@ -2918,6 +3026,8 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    bld.bld_base.set_vertex_and_primitive_count = emit_set_vertex_and_primitive_count;
    bld.bld_base.launch_mesh_workgroups = emit_launch_mesh_workgroups;
 
+   bld.bld_base.fns = params->fns;
+   bld.bld_base.func = params->current_func;
    bld.mask = params->mask;
    bld.inputs = params->inputs;
    bld.outputs = outputs;
@@ -2925,6 +3035,8 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    bld.ssbo_ptr = params->ssbo_ptr;
    bld.sampler = params->sampler;
 
+   bld.context_type = params->context_type;
+   bld.context_ptr = params->context_ptr;
    bld.resources_type = params->resources_type;
    bld.resources_ptr = params->resources_ptr;
    bld.thread_data_type = params->thread_data_type;
@@ -2961,18 +3073,31 @@ void lp_build_nir_soa_func(struct gallivm_state *gallivm,
    }
    lp_exec_mask_init(&bld.exec_mask, &bld.bld_base.int_bld);
 
-   bld.system_values = *params->system_values;
+   if (params->system_values)
+      bld.system_values = *params->system_values;
 
    bld.bld_base.shader = shader;
 
    bld.scratch_size = ALIGN(shader->scratch_size, 8);
-   if (shader->scratch_size) {
+   if (params->scratch_ptr) {
+      bld.scratch_ptr = params->scratch_ptr;
+   } else if (shader->scratch_size) {
       bld.scratch_ptr = lp_build_array_alloca(gallivm,
                                               LLVMInt8TypeInContext(gallivm->context),
                                               lp_build_const_int32(gallivm, bld.scratch_size * type.length),
                                               "scratch");
    }
 
+   /* Kernels only: use params->call_context_ptr when set, else build one. */
+   if (shader->info.stage == MESA_SHADER_KERNEL) {
+      bld.call_context_type = lp_build_cs_func_call_context(gallivm, type.length,
+                                                            bld.context_type, bld.resources_type);
+      if (params->call_context_ptr)
+         bld.call_context_ptr = params->call_context_ptr;
+      else
+         build_call_context(&bld);
+   }
+
    emit_prologue(&bld);
    lp_build_nir_llvm(&bld.bld_base, shader, impl);
 
diff --git a/src/gallium/auxiliary/gallivm/lp_bld_tgsi.h b/src/gallium/auxiliary/gallivm/lp_bld_tgsi.h
index a5a049f7238..a4435e282c8 100644
--- a/src/gallium/auxiliary/gallivm/lp_bld_tgsi.h
+++ b/src/gallium/auxiliary/gallivm/lp_bld_tgsi.h
@@ -289,6 +289,11 @@ struct lp_build_tgsi_params {
    const struct lp_build_fs_iface *fs_iface;
    unsigned gs_vertex_streams;
    LLVMValueRef aniso_filter_table;
+   /* Function-call support; scratch_ptr and call_context_ptr may be NULL. */
+   LLVMValueRef current_func;
+   struct hash_table *fns;
+   LLVMValueRef scratch_ptr;
+   LLVMValueRef call_context_ptr;
 };
 
 void