intel/compiler/mesh: use U888X packed index format

Reviewed-by: Caio Oliveira <caio.oliveira@intel.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/20910>
This commit is contained in:
Marcin Ślusarz 2023-01-25 15:06:23 +01:00 committed by Marge Bot
parent 2d20564a6a
commit 465c241266
3 changed files with 212 additions and 11 deletions

View file

@ -1600,6 +1600,7 @@ struct brw_tue_map {
struct brw_mue_map {
int32_t start_dw[VARYING_SLOT_MAX];
uint32_t per_primitive_indices_dw;
uint32_t size_dw;
@ -1624,6 +1625,7 @@ struct brw_task_prog_data {
/* Layout of the per-primitive vertex indices written by a mesh shader. */
enum brw_mesh_index_format {
/* Every vertex index of a primitive takes a dword of its own, so a
 * primitive uses as many index dwords as it has vertices.
 */
BRW_INDEX_FORMAT_U32,
/* All vertex indices of a primitive share a single dword, one byte per
 * index starting from the lowest byte.
 */
BRW_INDEX_FORMAT_U888X,
};
struct brw_mesh_prog_data {

View file

@ -434,7 +434,8 @@ brw_nir_lower_tue_inputs(nir_shader *nir, const brw_tue_map *map)
* the pitch.
*/
static void
brw_compute_mue_map(struct nir_shader *nir, struct brw_mue_map *map)
brw_compute_mue_map(struct nir_shader *nir, struct brw_mue_map *map,
enum brw_mesh_index_format index_format)
{
memset(map, 0, sizeof(*map));
@ -459,10 +460,20 @@ brw_compute_mue_map(struct nir_shader *nir, struct brw_mue_map *map)
outputs_written &= ~BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_INDICES);
}
/* One dword for primitives count then K extra dwords for each
* primitive. Note this should change when we implement other index types.
*/
const unsigned primitive_list_size_dw = 1 + vertices_per_primitive * map->max_primitives;
/* One dword for primitives count then K extra dwords for each primitive. */
switch (index_format) {
case BRW_INDEX_FORMAT_U32:
map->per_primitive_indices_dw = vertices_per_primitive;
break;
case BRW_INDEX_FORMAT_U888X:
map->per_primitive_indices_dw = 1;
break;
default:
unreachable("invalid index format");
}
map->per_primitive_start_dw = ALIGN(map->per_primitive_indices_dw *
map->max_primitives + 1, 8);
/* TODO(mesh): Multiview. */
map->per_primitive_header_size_dw =
@ -471,8 +482,6 @@ brw_compute_mue_map(struct nir_shader *nir, struct brw_mue_map *map)
BITFIELD64_BIT(VARYING_SLOT_PRIMITIVE_SHADING_RATE) |
BITFIELD64_BIT(VARYING_SLOT_LAYER))) ? 8 : 0;
map->per_primitive_start_dw = ALIGN(primitive_list_size_dw, 8);
map->per_primitive_data_size_dw = 0;
u_foreach_bit64(location, outputs_written & nir->info.per_primitive_outputs) {
assert(map->start_dw[location] == -1);
@ -747,7 +756,7 @@ brw_nir_adjust_offset_for_arrayed_indices_instr(nir_builder *b, nir_instr *instr
struct nir_io_semantics sem = nir_intrinsic_io_semantics(intrin);
uint32_t pitch;
if (sem.location == VARYING_SLOT_PRIMITIVE_INDICES)
pitch = num_mesh_vertices_per_primitive(b->shader->info.mesh.primitive_type);
pitch = map->per_primitive_indices_dw;
else
pitch = map->per_primitive_pitch_dw;
@ -771,6 +780,187 @@ brw_nir_adjust_offset_for_arrayed_indices(nir_shader *nir, const struct brw_mue_
(void *)map);
}
/* State shared between the check that decides whether primitive indices can
 * be packed into one dword per primitive and the pass that rewrites the
 * stores.
 */
struct index_packing_state {
/* Number of vertex indices in one primitive of the shader's primitive type. */
unsigned vertices_per_primitive;
/* Output variable at VARYING_SLOT_PRIMITIVE_INDICES as found in the shader;
 * NULL when the shader has no such variable.
 */
nir_variable *original_prim_indices;
/* Replacement output variable (array of uint, one element per primitive)
 * created by the packing pass.
 */
nir_variable *packed_prim_indices;
};
static bool
brw_can_pack_primitive_indices(nir_shader *nir, struct index_packing_state *state)
{
/* NV_mesh_shader primitive indices are stored as a flat array instead
* of an array of primitives. Don't bother with this for now.
*/
if (nir->info.mesh.nv)
return false;
/* can single index fit into one byte of U888X format? */
if (nir->info.mesh.max_vertices_out > 255)
return false;
state->vertices_per_primitive =
num_mesh_vertices_per_primitive(nir->info.mesh.primitive_type);
/* packing point indices doesn't help */
if (state->vertices_per_primitive == 1)
return false;
state->original_prim_indices =
nir_find_variable_with_location(nir,
nir_var_shader_out,
VARYING_SLOT_PRIMITIVE_INDICES);
/* no indices = no changes to the shader, but it's still worth it,
* because less URB space will be used
*/
if (!state->original_prim_indices)
return true;
ASSERTED const struct glsl_type *type = state->original_prim_indices->type;
assert(type->is_array());
assert(type->without_array()->is_vector());
assert(type->without_array()->vector_elements == state->vertices_per_primitive);
nir_foreach_function(function, nir) {
if (!function->impl)
continue;
nir_foreach_block(block, function->impl) {
nir_foreach_instr(instr, block) {
if (instr->type != nir_instr_type_intrinsic)
continue;
nir_intrinsic_instr *intrin = nir_instr_as_intrinsic(instr);
if (intrin->intrinsic != nir_intrinsic_store_deref) {
/* any unknown deref operation on primitive indices -> don't pack */
unsigned num_srcs = nir_intrinsic_infos[intrin->intrinsic].num_srcs;
for (unsigned i = 0; i < num_srcs; i++) {
nir_deref_instr *deref = nir_src_as_deref(intrin->src[i]);
if (!deref)
continue;
nir_variable *var = nir_deref_instr_get_variable(deref);
if (var == state->original_prim_indices)
return false;
}
continue;
}
nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
if (!deref)
continue;
nir_variable *var = nir_deref_instr_get_variable(deref);
if (var != state->original_prim_indices)
continue;
if (deref->deref_type != nir_deref_type_array)
return false; /* unknown chain of derefs */
nir_deref_instr *var_deref = nir_src_as_deref(deref->parent);
if (!var_deref || var_deref->deref_type != nir_deref_type_var)
return false; /* unknown chain of derefs */
assert (var_deref->var == state->original_prim_indices);
unsigned write_mask = nir_intrinsic_write_mask(intrin);
/* If only some components are written, then we can't easily pack.
* In theory we could, by loading current dword value, bitmasking
* one byte and storing back the whole dword, but it would be slow
* and could actually decrease performance. TODO: reevaluate this
* once there will be something hitting this.
*/
if (write_mask != BITFIELD_MASK(state->vertices_per_primitive))
return false;
}
}
}
return true;
}
/* Per-instruction callback of the index packing pass.
 *
 * Replaces a store of a whole vector of vertex indices to one element of the
 * original primitive indices variable with a store of a single dword to the
 * same element of the packed variable. Index 0 lands in bits 0-7, index 1 in
 * bits 8-15 and, for primitives with at least three vertices, index 2 in
 * bits 16-23.
 *
 * `data` is a struct index_packing_state. Returns true when the instruction
 * was rewritten.
 */
static bool
brw_pack_primitive_indices_instr(nir_builder *b, nir_instr *instr, void *data)
{
   struct index_packing_state *state = (struct index_packing_state *)data;

   if (instr->type != nir_instr_type_intrinsic)
      return false;

   nir_intrinsic_instr *store = nir_instr_as_intrinsic(instr);
   if (store->intrinsic != nir_intrinsic_store_deref)
      return false;

   /* Only stores of the form original_prim_indices[i] = ... are handled. */
   nir_deref_instr *elem = nir_src_as_deref(store->src[0]);
   if (!elem || elem->deref_type != nir_deref_type_array)
      return false;

   nir_deref_instr *base = nir_src_as_deref(elem->parent);
   if (!base || base->deref_type != nir_deref_type_var)
      return false;

   if (base->var != state->original_prim_indices)
      return false;

   const unsigned num_indices = state->vertices_per_primitive;

   b->cursor = nir_before_instr(&store->instr);

   /* Address the same array element, but in the packed variable. */
   nir_deref_instr *packed_base =
      nir_build_deref_var(b, state->packed_prim_indices);
   nir_deref_instr *packed_elem =
      nir_build_deref_array(b, packed_base, elem->arr.index.ssa);

   nir_ssa_def *indices = nir_ssa_for_src(b, store->src[1], num_indices);

   nir_ssa_def *byte0 = nir_ishl_imm(b, nir_channel(b, indices, 0), 0);
   nir_ssa_def *byte1 = nir_ishl_imm(b, nir_channel(b, indices, 1), 8);
   nir_ssa_def *packed = nir_ior(b, byte0, byte1);

   if (num_indices >= 3) {
      nir_ssa_def *byte2 = nir_ishl_imm(b, nir_channel(b, indices, 2), 16);
      packed = nir_ior(b, packed, byte2);
   }

   nir_build_store_deref(b, &packed_elem->dest.ssa, packed);

   /* The original vector store is no longer needed. */
   nir_instr_remove(instr);

   return true;
}
/* Creates the packed primitive indices output variable and rewrites all
 * stores to the original one so they target it instead.
 *
 * `data` is a struct index_packing_state whose original_prim_indices and
 * vertices_per_primitive are read by the per-instruction callback;
 * packed_prim_indices is filled in here. Returns the result of the
 * instruction pass.
 */
static bool
brw_pack_primitive_indices(nir_shader *nir, void *data)
{
   struct index_packing_state *state = (struct index_packing_state *)data;

   /* One uint per primitive. */
   const struct glsl_type *packed_type =
      glsl_array_type(glsl_uint_type(), nir->info.mesh.max_primitives_out, 0);

   nir_variable *packed =
      nir_variable_create(nir, nir_var_shader_out, packed_type,
                          "gl_PrimitiveIndicesPacked");
   packed->data.location = VARYING_SLOT_PRIMITIVE_INDICES;
   packed->data.interpolation = INTERP_MODE_NONE;
   packed->data.per_primitive = 1;

   /* Must be published before the pass runs: the callback reads it. */
   state->packed_prim_indices = packed;

   return nir_shader_instructions_pass(nir,
                                       brw_pack_primitive_indices_instr,
                                       nir_metadata_block_index |
                                       nir_metadata_dominance,
                                       data);
}
const unsigned *
brw_compile_mesh(const struct brw_compiler *compiler,
void *mem_ctx,
@ -795,15 +985,21 @@ brw_compile_mesh(const struct brw_compiler *compiler,
nir->info.clip_distance_array_size;
prog_data->primitive_type = nir->info.mesh.primitive_type;
/* TODO(mesh): Use other index formats (that are more compact) for optimization. */
prog_data->index_format = BRW_INDEX_FORMAT_U32;
struct index_packing_state index_packing_state = {};
if (brw_can_pack_primitive_indices(nir, &index_packing_state)) {
if (index_packing_state.original_prim_indices)
NIR_PASS(_, nir, brw_pack_primitive_indices, &index_packing_state);
prog_data->index_format = BRW_INDEX_FORMAT_U888X;
} else {
prog_data->index_format = BRW_INDEX_FORMAT_U32;
}
prog_data->uses_drawid =
BITSET_TEST(nir->info.system_values_read, SYSTEM_VALUE_DRAW_ID);
brw_nir_lower_tue_inputs(nir, params->tue_map);
brw_compute_mue_map(nir, &prog_data->map);
brw_compute_mue_map(nir, &prog_data->map, prog_data->index_format);
brw_nir_lower_mue_outputs(nir, &prog_data->map);
brw_simd_selection_state simd_state{

View file

@ -1797,6 +1797,9 @@ emit_mesh_state(struct anv_graphics_pipeline *pipeline)
case BRW_INDEX_FORMAT_U32:
index_format = INDEX_U32;
break;
case BRW_INDEX_FORMAT_U888X:
index_format = INDEX_U888X;
break;
default:
unreachable("invalid index format");
}