llvmpipe/cs: add task/mesh shader support to compute shader builder.

This allows generating task and mesh variants of compute shaders.

It adds:
- vertex and primitive outputs support - aos writing.
- payload support
- mesh iface for the output and count callbacks.
- draw_id
- multiple iteration support to the exec fn to allow launches
in multiple passes to reduce memory usage

Reviewed-by: Roland Scheidegger <sroland@vmware.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/23066>
This commit is contained in:
Dave Airlie 2023-05-17 12:01:36 +10:00
parent 51eb3cc563
commit 3994fb1e19
3 changed files with 389 additions and 9 deletions

View file

@ -49,6 +49,7 @@ struct lp_fragment_shader_variant;
struct lp_compute_shader_variant;
struct lp_rast_state;
struct llvmpipe_screen;
struct vertex_header;
struct lp_jit_viewport
{
@ -374,6 +375,8 @@ typedef void
uint32_t grid_size_y,
uint32_t grid_size_z,
uint32_t work_dim,
uint32_t draw_id,
struct vertex_header *io, /* mesh shader only */
struct lp_jit_cs_thread_data *thread_data);
void

View file

@ -26,6 +26,7 @@
#include "util/u_memory.h"
#include "util/os_time.h"
#include "util/u_dump.h"
#include "util/u_prim.h"
#include "util/u_string.h"
#include "tgsi/tgsi_dump.h"
#include "tgsi/tgsi_parse.h"
@ -33,12 +34,14 @@
#include "gallivm/lp_bld_debug.h"
#include "gallivm/lp_bld_intr.h"
#include "gallivm/lp_bld_flow.h"
#include "gallivm/lp_bld_pack.h"
#include "gallivm/lp_bld_gather.h"
#include "gallivm/lp_bld_coro.h"
#include "gallivm/lp_bld_nir.h"
#include "gallivm/lp_bld_jit_sample.h"
#include "lp_state_cs.h"
#include "lp_context.h"
#include "lp_setup_context.h"
#include "lp_debug.h"
#include "lp_state.h"
#include "lp_perf.h"
@ -52,6 +55,8 @@
#include "nir_serialize.h"
#include "draw/draw_context.h"
#include "draw/draw_llvm.h"
#include "draw/draw_mesh_prim.h"
/** Fragment shader number (for debugging) */
static unsigned cs_no = 0;
@ -60,12 +65,19 @@ static unsigned mesh_no = 0;
struct lp_cs_job_info {
unsigned grid_size[3];
unsigned iter_size[3];
unsigned grid_base[3];
unsigned block_size[3];
unsigned req_local_mem;
unsigned work_dim;
unsigned draw_id;
bool zero_initialize_shared_memory;
bool use_iters;
struct lp_cs_exec *current;
struct vertex_header *io;
size_t io_stride;
void *payload;
size_t payload_stride;
};
enum {
@ -81,6 +93,8 @@ enum {
CS_ARG_GRID_SIZE_Y,
CS_ARG_GRID_SIZE_Z,
CS_ARG_WORK_DIM,
CS_ARG_DRAW_ID,
CS_ARG_VERTEX_DATA,
CS_ARG_PER_THREAD_DATA,
CS_ARG_OUTER_COUNT,
CS_ARG_CORO_X_LOOPS = CS_ARG_OUTER_COUNT,
@ -90,9 +104,214 @@ enum {
CS_ARG_CORO_BLOCK_Z_SIZE,
CS_ARG_CORO_IDX,
CS_ARG_CORO_MEM,
CS_ARG_CORO_OUTPUTS,
CS_ARG_MAX,
};
struct lp_mesh_llvm_iface {
struct lp_build_mesh_iface base;
LLVMValueRef vertex_count;
LLVMValueRef prim_count;
LLVMValueRef outputs;
};
static inline const struct lp_mesh_llvm_iface *
lp_mesh_llvm_iface(const struct lp_build_mesh_iface *iface)
{
return (const struct lp_mesh_llvm_iface *)iface;
}
static LLVMTypeRef
create_mesh_jit_output_type_deref(struct gallivm_state *gallivm)
{
LLVMTypeRef float_type = LLVMFloatTypeInContext(gallivm->context);
LLVMTypeRef output_array;
output_array = LLVMArrayType(float_type, TGSI_NUM_CHANNELS); /* num channels */
output_array = LLVMArrayType(output_array, PIPE_MAX_SHADER_OUTPUTS); /* num attrs per vertex */
return output_array;
}
static void
lp_mesh_llvm_emit_store_output(const struct lp_build_mesh_iface *mesh_iface,
struct lp_build_context *bld,
unsigned name,
boolean is_vindex_indirect,
LLVMValueRef vertex_index,
boolean is_aindex_indirect,
LLVMValueRef attrib_index,
boolean is_sindex_indirect,
LLVMValueRef swizzle_index,
LLVMValueRef value,
LLVMValueRef mask_vec)
{
const struct lp_mesh_llvm_iface *mesh = lp_mesh_llvm_iface(mesh_iface);
struct gallivm_state *gallivm = bld->gallivm;
LLVMBuilderRef builder = gallivm->builder;
LLVMValueRef indices[3];
LLVMValueRef res;
struct lp_type type = bld->type;
LLVMTypeRef output_type = create_mesh_jit_output_type_deref(gallivm);
if (is_vindex_indirect || is_aindex_indirect || is_sindex_indirect) {
for (int i = 0; i < type.length; ++i) {
LLVMValueRef idx = lp_build_const_int32(gallivm, i);
LLVMValueRef vert_chan_index = vertex_index ? vertex_index : lp_build_const_int32(gallivm, 0);
LLVMValueRef attr_chan_index = attrib_index;
LLVMValueRef swiz_chan_index = swizzle_index;
LLVMValueRef channel_vec;
if (is_vindex_indirect) {
vert_chan_index = LLVMBuildExtractElement(builder,
vertex_index, idx, "");
}
if (is_aindex_indirect) {
attr_chan_index = LLVMBuildExtractElement(builder,
attrib_index, idx, "");
}
if (is_sindex_indirect) {
swiz_chan_index = LLVMBuildExtractElement(builder,
swizzle_index, idx, "");
}
indices[0] = vert_chan_index;
indices[1] = attr_chan_index;
indices[2] = swiz_chan_index;
channel_vec = LLVMBuildGEP2(builder, output_type, mesh->outputs, indices, 3, "");
res = LLVMBuildExtractElement(builder, value, idx, "");
struct lp_build_if_state ifthen;
LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, mask_vec, lp_build_const_int_vec(gallivm, bld->type, 0), "");
cond = LLVMBuildExtractElement(gallivm->builder, cond, idx, "");
lp_build_if(&ifthen, gallivm, cond);
LLVMBuildStore(builder, res, channel_vec);
lp_build_endif(&ifthen);
}
} else {
indices[0] = vertex_index ? vertex_index : lp_build_const_int32(gallivm, 0);
indices[1] = attrib_index;
indices[2] = swizzle_index;
res = LLVMBuildGEP2(builder, output_type, mesh->outputs, indices, 3, "");
for (unsigned i = 0; i < type.length; ++i) {
LLVMValueRef idx = lp_build_const_int32(gallivm, i);
LLVMValueRef val = LLVMBuildExtractElement(builder, value, idx, "");
struct lp_build_if_state ifthen;
LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, mask_vec, lp_build_const_int_vec(gallivm, bld->type, 0), "");
cond = LLVMBuildExtractElement(gallivm->builder, cond, idx, "");
lp_build_if(&ifthen, gallivm, cond);
LLVMBuildStore(builder, val, res);
lp_build_endif(&ifthen);
}
}
}
static void
lp_mesh_emit_vertex_and_primitive_count(const struct lp_build_mesh_iface *mesh_iface,
struct lp_build_context *bld,
LLVMValueRef vertices_count,
LLVMValueRef primitives_count)
{
const struct lp_mesh_llvm_iface *mesh = lp_mesh_llvm_iface(mesh_iface);
struct gallivm_state *gallivm = bld->gallivm;
LLVMBuildStore(gallivm->builder, vertices_count, mesh->vertex_count);
LLVMBuildStore(gallivm->builder, primitives_count, mesh->prim_count);
}
static void
mesh_convert_to_aos(struct gallivm_state *gallivm,
nir_shader *nir,
bool vert_only,
LLVMTypeRef io_type,
LLVMValueRef io,
LLVMValueRef outputs,
LLVMValueRef clipmask,
LLVMValueRef vertex_index,
struct lp_type soa_type,
int primid_slot,
boolean need_edgeflag)
{
LLVMBuilderRef builder = gallivm->builder;
LLVMValueRef inds[3];
LLVMTypeRef output_type = create_mesh_jit_output_type_deref(gallivm);
#if DEBUG_STORE
lp_build_printf(gallivm, " # storing begin\n");
#endif
int first_per_prim_attrib = -1;
nir_foreach_shader_out_variable(var, nir) {
if (var->data.per_primitive) {
first_per_prim_attrib = var->data.driver_location;
break;
}
}
nir_foreach_shader_out_variable(var, nir) {
if (vert_only && var->data.per_primitive)
continue;
if (!vert_only && !var->data.per_primitive)
continue;
int attrib = var->data.driver_location;
int slots = glsl_count_attribute_slots(glsl_get_array_element(var->type), false);
for (unsigned s = 0; s < slots; s++) {
LLVMValueRef soa[TGSI_NUM_CHANNELS];
LLVMValueRef aos[LP_MAX_VECTOR_WIDTH / 32];
for (unsigned chan = 0; chan < TGSI_NUM_CHANNELS; ++chan) {
inds[0] = vertex_index;
inds[1] = lp_build_const_int32(gallivm, attrib);
inds[2] = lp_build_const_int32(gallivm, chan);
LLVMValueRef res = LLVMBuildGEP2(builder, output_type, outputs, inds, 3, "");
LLVMTypeRef single_type = (attrib == primid_slot) ? lp_build_int_elem_type(gallivm, soa_type) : lp_build_elem_type(gallivm, soa_type);
LLVMValueRef out = LLVMBuildLoad2(builder, single_type, res, "");
lp_build_name(out, "output%u.%c", attrib, "xyzw"[chan]);
#if DEBUG_STORE
lp_build_printf(gallivm, "output %d : %d ",
LLVMConstInt(LLVMInt32TypeInContext(gallivm->context),
attrib, 0),
LLVMConstInt(LLVMInt32TypeInContext(gallivm->context),
chan, 0));
lp_build_print_value(gallivm, "val = ", out);
{
LLVMValueRef iv =
LLVMBuildBitCast(builder, out, lp_build_int_elem_type(gallivm, soa_type), "");
lp_build_print_value(gallivm, " ival = ", iv);
}
#endif
soa[chan] = out;
}
LLVMTypeRef float_type = LLVMFloatTypeInContext(gallivm->context);
aos[0] = LLVMGetUndef(LLVMVectorType(float_type, 4));
for (unsigned i = 0; i < 4; i++)
aos[0] = LLVMBuildInsertElement(builder, aos[0], soa[i], lp_build_const_int32(gallivm, i), "");
int aos_attrib = attrib;
if (var->data.per_primitive)
aos_attrib -= first_per_prim_attrib;
draw_store_aos_array(gallivm,
soa_type,
io_type,
io,
NULL,
aos,
aos_attrib,
clipmask,
need_edgeflag, var->data.per_primitive);
attrib++;
}
}
#if DEBUG_STORE
lp_build_printf(gallivm, " # storing end\n");
#endif
}
static void
generate_compute(struct llvmpipe_context *lp,
struct lp_compute_shader *shader,
@ -108,15 +327,25 @@ generate_compute(struct llvmpipe_context *lp,
LLVMValueRef block_x_size_arg, block_y_size_arg, block_z_size_arg;
LLVMValueRef grid_x_arg, grid_y_arg, grid_z_arg;
LLVMValueRef grid_size_x_arg, grid_size_y_arg, grid_size_z_arg;
LLVMValueRef work_dim_arg, thread_data_ptr;
LLVMValueRef work_dim_arg, draw_id_arg, thread_data_ptr, io_ptr;
LLVMBasicBlockRef block;
LLVMBuilderRef builder;
struct lp_build_sampler_soa *sampler;
struct lp_build_image_soa *image;
LLVMValueRef function, coro;
struct lp_type cs_type;
struct lp_mesh_llvm_iface mesh_iface;
bool is_mesh = false;
unsigned i;
LLVMValueRef output_array = NULL;
if (shader->base.type == PIPE_SHADER_IR_NIR) {
struct nir_shader *nir = shader->base.ir.nir;
if (nir->info.stage == MESA_SHADER_MESH) {
is_mesh = true;
}
}
/*
* This function has two parts
* a) setup the coroutine execution environment loop.
@ -146,6 +375,11 @@ generate_compute(struct llvmpipe_context *lp,
arg_types[CS_ARG_GRID_SIZE_Y] = int32_type; /* grid_size_y */
arg_types[CS_ARG_GRID_SIZE_Z] = int32_type; /* grid_size_z */
arg_types[CS_ARG_WORK_DIM] = int32_type; /* work dim */
arg_types[CS_ARG_DRAW_ID] = int32_type; /* draw id */
if (variant->jit_vertex_header_ptr_type)
arg_types[CS_ARG_VERTEX_DATA] = variant->jit_vertex_header_ptr_type; /* mesh shaders only */
else
arg_types[CS_ARG_VERTEX_DATA] = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); /* mesh shaders only */
arg_types[CS_ARG_PER_THREAD_DATA] = variant->jit_cs_thread_data_ptr_type; /* per thread data */
arg_types[CS_ARG_CORO_X_LOOPS] = int32_type; /* coro only - num X loops */
arg_types[CS_ARG_CORO_PARTIALS] = int32_type; /* coro only - partials */
@ -154,11 +388,13 @@ generate_compute(struct llvmpipe_context *lp,
arg_types[CS_ARG_CORO_BLOCK_Z_SIZE] = int32_type; /* coro block_z_size */
arg_types[CS_ARG_CORO_IDX] = int32_type; /* coro idx */
arg_types[CS_ARG_CORO_MEM] = LLVMPointerType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0), 0);
arg_types[CS_ARG_CORO_OUTPUTS] = LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0); /* mesh shaders only */
func_type = LLVMFunctionType(LLVMVoidTypeInContext(gallivm->context),
arg_types, CS_ARG_OUTER_COUNT, 0);
coro_func_type = LLVMFunctionType(LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0),
arg_types, CS_ARG_MAX, 0);
arg_types, CS_ARG_MAX - (!is_mesh), 0);
function = LLVMAddFunction(gallivm->module, func_name, func_type);
LLVMSetFunctionCallConv(function, LLVMCCallConv);
@ -169,7 +405,7 @@ generate_compute(struct llvmpipe_context *lp,
variant->function = function;
for (i = 0; i < CS_ARG_MAX; ++i) {
for (i = 0; i < CS_ARG_MAX - !is_mesh; ++i) {
if (LLVMGetTypeKind(arg_types[i]) == LLVMPointerTypeKind) {
lp_add_function_attr(coro, i + 1, LP_FUNC_ATTR_NOALIAS);
if (i < CS_ARG_OUTER_COUNT)
@ -192,6 +428,8 @@ generate_compute(struct llvmpipe_context *lp,
grid_size_y_arg = LLVMGetParam(function, CS_ARG_GRID_SIZE_Y);
grid_size_z_arg = LLVMGetParam(function, CS_ARG_GRID_SIZE_Z);
work_dim_arg = LLVMGetParam(function, CS_ARG_WORK_DIM);
draw_id_arg = LLVMGetParam(function, CS_ARG_DRAW_ID);
io_ptr = LLVMGetParam(function, CS_ARG_VERTEX_DATA);
thread_data_ptr = LLVMGetParam(function, CS_ARG_PER_THREAD_DATA);
lp_build_name(context_ptr, "context");
@ -206,7 +444,9 @@ generate_compute(struct llvmpipe_context *lp,
lp_build_name(grid_size_y_arg, "grid_size_y");
lp_build_name(grid_size_z_arg, "grid_size_z");
lp_build_name(work_dim_arg, "work_dim");
lp_build_name(draw_id_arg, "draw_id");
lp_build_name(thread_data_ptr, "thread_data");
lp_build_name(io_ptr, "vertex_io");
block = LLVMAppendBasicBlockInContext(gallivm->context, function, "entry");
builder = gallivm->builder;
@ -217,6 +457,12 @@ generate_compute(struct llvmpipe_context *lp,
key->nr_sampler_views));
image = lp_bld_llvm_image_soa_create(lp_cs_variant_key_images(key), key->nr_images);
if (is_mesh) {
struct nir_shader *nir = shader->base.ir.nir;
LLVMTypeRef output_type = create_mesh_jit_output_type_deref(gallivm);
output_array = lp_build_array_alloca(gallivm, output_type, lp_build_const_int32(gallivm, align(MAX2(nir->info.mesh.max_primitives_out, nir->info.mesh.max_vertices_out), 8)), "outputs");
}
struct lp_build_loop_state loop_state[4];
LLVMValueRef num_x_loop;
LLVMValueRef vec_length = lp_build_const_int32(gallivm, cs_type.length);
@ -265,6 +511,8 @@ generate_compute(struct llvmpipe_context *lp,
args[CS_ARG_GRID_SIZE_Y] = grid_size_y_arg;
args[CS_ARG_GRID_SIZE_Z] = grid_size_z_arg;
args[CS_ARG_WORK_DIM] = work_dim_arg;
args[CS_ARG_DRAW_ID] = draw_id_arg;
args[CS_ARG_VERTEX_DATA] = io_ptr;
args[CS_ARG_PER_THREAD_DATA] = thread_data_ptr;
args[CS_ARG_CORO_X_LOOPS] = num_x_loop;
args[CS_ARG_CORO_PARTIALS] = partials;
@ -284,6 +532,10 @@ generate_compute(struct llvmpipe_context *lp,
args[CS_ARG_CORO_IDX] = coro_hdl_idx;
args[CS_ARG_CORO_MEM] = coro_mem;
if (is_mesh)
args[CS_ARG_CORO_OUTPUTS] = output_array;
LLVMValueRef coro_entry = LLVMBuildGEP2(gallivm->builder, hdl_ptr_type, coro_hdls, &coro_hdl_idx, 1, "");
LLVMValueRef coro_hdl = LLVMBuildLoad2(gallivm->builder, hdl_ptr_type, coro_entry, "coro_hdl");
@ -293,7 +545,7 @@ generate_compute(struct llvmpipe_context *lp,
lp_build_const_int32(gallivm, 0), "");
/* first time here - call the coroutine function entry point */
lp_build_if(&ifstate, gallivm, cmp);
LLVMValueRef coro_ret = LLVMBuildCall2(gallivm->builder, coro_func_type, coro, args, CS_ARG_MAX, "");
LLVMValueRef coro_ret = LLVMBuildCall2(gallivm->builder, coro_func_type, coro, args, CS_ARG_MAX - !is_mesh, "");
LLVMBuildStore(gallivm->builder, coro_ret, coro_entry);
lp_build_else(&ifstate);
/* subsequent calls for this invocation - check if done. */
@ -344,6 +596,8 @@ generate_compute(struct llvmpipe_context *lp,
grid_size_y_arg = LLVMGetParam(coro, CS_ARG_GRID_SIZE_Y);
grid_size_z_arg = LLVMGetParam(coro, CS_ARG_GRID_SIZE_Z);
work_dim_arg = LLVMGetParam(coro, CS_ARG_WORK_DIM);
draw_id_arg = LLVMGetParam(coro, CS_ARG_DRAW_ID);
io_ptr = LLVMGetParam(coro, CS_ARG_VERTEX_DATA);
thread_data_ptr = LLVMGetParam(coro, CS_ARG_PER_THREAD_DATA);
num_x_loop = LLVMGetParam(coro, CS_ARG_CORO_X_LOOPS);
partials = LLVMGetParam(coro, CS_ARG_CORO_PARTIALS);
@ -352,12 +606,15 @@ generate_compute(struct llvmpipe_context *lp,
block_z_size_arg = LLVMGetParam(coro, CS_ARG_CORO_BLOCK_Z_SIZE);
LLVMValueRef coro_idx = LLVMGetParam(coro, CS_ARG_CORO_IDX);
coro_mem = LLVMGetParam(coro, CS_ARG_CORO_MEM);
if (is_mesh)
output_array = LLVMGetParam(coro, CS_ARG_CORO_OUTPUTS);
block = LLVMAppendBasicBlockInContext(gallivm->context, coro, "entry");
LLVMPositionBuilderAtEnd(builder, block);
{
LLVMValueRef consts_ptr;
LLVMValueRef ssbo_ptr;
LLVMValueRef shared_ptr;
LLVMValueRef payload_ptr;
LLVMValueRef kernel_args_ptr;
struct lp_build_mask_context mask;
struct lp_bld_tgsi_system_values system_values;
@ -372,6 +629,9 @@ generate_compute(struct llvmpipe_context *lp,
shared_ptr = lp_jit_cs_thread_data_shared(gallivm,
variant->jit_cs_thread_data_type,
thread_data_ptr);
payload_ptr = lp_jit_cs_thread_data_payload(gallivm,
variant->jit_cs_thread_data_type,
thread_data_ptr);
LLVMValueRef coro_num_hdls = LLVMBuildMul(gallivm->builder, num_x_loop, block_y_size_arg, "");
coro_num_hdls = LLVMBuildMul(gallivm->builder, coro_num_hdls, block_z_size_arg, "");
@ -410,6 +670,7 @@ generate_compute(struct llvmpipe_context *lp,
system_values.grid_size = LLVMBuildInsertElement(builder, system_values.grid_size, gstids[i], lp_build_const_int32(gallivm, i), "");
system_values.work_dim = work_dim_arg;
system_values.draw_id = draw_id_arg;
/* subgroup_id = ((z * block_size_x * block_size_y) + (y * block_size_x) + x) / subgroup_size
*
@ -470,6 +731,16 @@ generate_compute(struct llvmpipe_context *lp,
coro_info.suspend = sus_block;
coro_info.cleanup = clean_block;
if (is_mesh) {
LLVMValueRef vertex_count = lp_build_alloca(gallivm, LLVMInt32TypeInContext(gallivm->context), "vertex_count");
LLVMValueRef primitive_count = lp_build_alloca(gallivm, LLVMInt32TypeInContext(gallivm->context), "prim_count");
mesh_iface.base.emit_store_output = lp_mesh_llvm_emit_store_output;
mesh_iface.base.emit_vertex_and_primitive_count = lp_mesh_emit_vertex_and_primitive_count;
mesh_iface.vertex_count = vertex_count;
mesh_iface.prim_count = primitive_count;
mesh_iface.outputs = output_array;
}
struct lp_build_tgsi_params params;
memset(&params, 0, sizeof(params));
@ -486,11 +757,13 @@ generate_compute(struct llvmpipe_context *lp,
params.ssbo_ptr = ssbo_ptr;
params.image = image;
params.shared_ptr = shared_ptr;
params.payload_ptr = payload_ptr;
params.coro = &coro_info;
params.kernel_args = kernel_args_ptr;
params.aniso_filter_table = lp_jit_resources_aniso_filter_table(gallivm,
variant->jit_resources_type,
resources_ptr);
params.mesh_iface = &mesh_iface.base;
if (shader->base.type == PIPE_SHADER_IR_TGSI)
lp_build_tgsi_soa(gallivm, shader->base.tokens, &params, NULL);
@ -498,6 +771,73 @@ generate_compute(struct llvmpipe_context *lp,
lp_build_nir_soa(gallivm, shader->base.ir.nir, &params,
NULL);
if (is_mesh) {
LLVMTypeRef i32t = LLVMInt32TypeInContext(gallivm->context);
LLVMValueRef clipmask = lp_build_const_int_vec(gallivm,
lp_int_type(cs_type), 0);
struct lp_build_if_state iter0state;
LLVMValueRef is_iter0 = LLVMBuildICmp(gallivm->builder, LLVMIntEQ, coro_idx,
lp_build_const_int32(gallivm, 0), "");
LLVMValueRef vertex_count = LLVMBuildLoad2(gallivm->builder, i32t, mesh_iface.vertex_count, "");
LLVMValueRef prim_count = LLVMBuildLoad2(gallivm->builder, i32t, mesh_iface.prim_count, "");
LLVMValueRef vert_count_ptr, prim_count_ptr;
LLVMValueRef indices = lp_build_const_int32(gallivm, 1);
vert_count_ptr = LLVMBuildGEP2(gallivm->builder, i32t, io_ptr, &indices, 1, "");
indices = lp_build_const_int32(gallivm, 2);
prim_count_ptr = LLVMBuildGEP2(gallivm->builder, i32t, io_ptr, &indices, 1, "");
lp_build_if(&iter0state, gallivm, is_iter0);
LLVMBuildStore(gallivm->builder, vertex_count, vert_count_ptr);
LLVMBuildStore(gallivm->builder, prim_count, prim_count_ptr);
lp_build_endif(&iter0state);
LLVMBasicBlockRef resume = lp_build_insert_new_block(gallivm, "resume");
lp_build_coro_suspend_switch(gallivm, params.coro, resume, false);
LLVMPositionBuilderAtEnd(gallivm->builder, resume);
vertex_count = LLVMBuildLoad2(gallivm->builder, i32t, vert_count_ptr, "");
prim_count = LLVMBuildLoad2(gallivm->builder, i32t, prim_count_ptr, "");
nir_shader *nir = shader->base.ir.nir;
int per_prim_count = util_bitcount64(nir->info.per_primitive_outputs);
int out_count = util_bitcount64(nir->info.outputs_written);
int per_vert_count = out_count - per_prim_count;
int vsize = (sizeof(struct vertex_header) + per_vert_count * 4 * sizeof(float)) * 8;
int psize = (per_prim_count * 4 * sizeof(float)) * 8;
struct lp_build_loop_state vertex_loop_state;
lp_build_loop_begin(&vertex_loop_state, gallivm,
lp_build_const_int32(gallivm, 0));
LLVMValueRef io;
io = LLVMBuildPtrToInt(gallivm->builder, io_ptr, LLVMInt64TypeInContext(gallivm->context), "");
io = LLVMBuildAdd(builder, io, LLVMBuildZExt(builder, LLVMBuildMul(builder, vertex_loop_state.counter, lp_build_const_int32(gallivm, vsize), ""), LLVMInt64TypeInContext(gallivm->context), ""), "");
io = LLVMBuildIntToPtr(gallivm->builder, io, LLVMPointerType(LLVMVoidTypeInContext(gallivm->context), 0), "");
mesh_convert_to_aos(gallivm, shader->base.ir.nir, true, variant->jit_vertex_header_type,
io, output_array, clipmask,
vertex_loop_state.counter, lp_elem_type(cs_type), -1, FALSE);
lp_build_loop_end_cond(&vertex_loop_state,
vertex_count,
NULL, LLVMIntUGE);
struct lp_build_loop_state prim_loop_state;
lp_build_loop_begin(&prim_loop_state, gallivm,
lp_build_const_int32(gallivm, 0));
io = LLVMBuildPtrToInt(gallivm->builder, io_ptr, LLVMInt64TypeInContext(gallivm->context), "");
LLVMValueRef prim_offset = LLVMBuildMul(builder, prim_loop_state.counter, lp_build_const_int32(gallivm, psize), "");
prim_offset = LLVMBuildAdd(builder, prim_offset, lp_build_const_int32(gallivm, vsize * (nir->info.mesh.max_vertices_out + 8)), "");
io = LLVMBuildAdd(builder, io, LLVMBuildZExt(builder, prim_offset, LLVMInt64TypeInContext(gallivm->context), ""), "");
io = LLVMBuildIntToPtr(gallivm->builder, io, LLVMPointerType(LLVMVoidTypeInContext(gallivm->context), 0), "");
mesh_convert_to_aos(gallivm, shader->base.ir.nir, false, variant->jit_prim_type,
io, output_array, clipmask,
prim_loop_state.counter, lp_elem_type(cs_type), -1, FALSE);
lp_build_loop_end_cond(&prim_loop_state,
prim_count,
NULL, LLVMIntUGE);
}
mask_val = lp_build_mask_end(&mask);
lp_build_coro_suspend_switch(gallivm, &coro_info, NULL, true);
@ -837,7 +1177,8 @@ generate_variant(struct llvmpipe_context *lp,
memset(variant, 0, sizeof(*variant));
char module_name[64];
const char *shname = "cs";
const char *shname = sh_type == PIPE_SHADER_MESH ? "ms" :
(sh_type == PIPE_SHADER_TASK ? "ts" : "cs");
snprintf(module_name, sizeof(module_name), "%s%u_variant%u",
shname, shader->no, shader->variants_created);
@ -871,6 +1212,16 @@ generate_variant(struct llvmpipe_context *lp,
lp_jit_init_cs_types(variant);
if (sh_type == PIPE_SHADER_MESH) {
struct nir_shader *nir = shader->base.ir.nir;
int per_prim_count = util_bitcount64(nir->info.per_primitive_outputs);
int out_count = util_bitcount64(nir->info.outputs_written);
int per_vert_count = out_count - per_prim_count;
variant->jit_vertex_header_type = lp_build_create_jit_vertex_header_type(variant->gallivm, per_vert_count);
variant->jit_vertex_header_ptr_type = LLVMPointerType(variant->jit_vertex_header_type, 0);
variant->jit_prim_type = LLVMArrayType(LLVMArrayType(LLVMFloatTypeInContext(variant->gallivm->context), 4), per_prim_count);
}
generate_compute(lp, shader, variant);
gallivm_compile_module(variant->gallivm);
@ -1443,19 +1794,41 @@ cs_exec_fn(void *init_data, int iter_idx, struct lp_cs_local_mem *lmem)
memset(lmem->local_mem_ptr, 0, job_info->req_local_mem);
thread_data.shared = lmem->local_mem_ptr;
unsigned grid_z = iter_idx / (job_info->grid_size[0] * job_info->grid_size[1]);
unsigned grid_y = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1]))) / job_info->grid_size[0];
unsigned grid_x = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1])) - (grid_y * job_info->grid_size[0]));
thread_data.payload = job_info->payload;
unsigned grid_z, grid_y, grid_x;
if (job_info->use_iters) {
grid_z = iter_idx / (job_info->iter_size[0] * job_info->iter_size[1]);
grid_y = (iter_idx - (grid_z * (job_info->iter_size[0] * job_info->iter_size[1]))) / job_info->iter_size[0];
grid_x = (iter_idx - (grid_z * (job_info->iter_size[0] * job_info->iter_size[1])) - (grid_y * job_info->iter_size[0]));
} else {
grid_z = iter_idx / (job_info->grid_size[0] * job_info->grid_size[1]);
grid_y = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1]))) / job_info->grid_size[0];
grid_x = (iter_idx - (grid_z * (job_info->grid_size[0] * job_info->grid_size[1])) - (grid_y * job_info->grid_size[0]));
}
grid_z += job_info->grid_base[2];
grid_y += job_info->grid_base[1];
grid_x += job_info->grid_base[0];
struct lp_compute_shader_variant *variant = job_info->current->variant;
void *io_ptr = NULL;
if (job_info->io) {
size_t io_offset = job_info->io_stride * iter_idx;
io_ptr = (char *)job_info->io + io_offset;
}
if (thread_data.payload) {
size_t payload_offset = job_info->payload_stride * iter_idx;
thread_data.payload = (char *)thread_data.payload + payload_offset;
}
variant->jit_function(&job_info->current->jit_context,
&job_info->current->jit_resources,
job_info->block_size[0], job_info->block_size[1], job_info->block_size[2],
grid_x, grid_y, grid_z,
job_info->grid_size[0], job_info->grid_size[1], job_info->grid_size[2], job_info->work_dim,
job_info->grid_size[0], job_info->grid_size[1], job_info->grid_size[2],
job_info->work_dim, job_info->draw_id,
io_ptr,
&thread_data);
}

View file

@ -86,6 +86,10 @@ struct lp_compute_shader_variant
LLVMTypeRef jit_resources_ptr_type;
LLVMTypeRef jit_cs_thread_data_ptr_type;
/* for mesh shaders */
LLVMTypeRef jit_vertex_header_type;
LLVMTypeRef jit_vertex_header_ptr_type;
LLVMTypeRef jit_prim_type;
LLVMValueRef function;
lp_jit_cs_func jit_function;