diff --git a/src/intel/compiler/brw_mesh.cpp b/src/intel/compiler/brw_mesh.cpp
index ce652abd715..6690317f955 100644
--- a/src/intel/compiler/brw_mesh.cpp
+++ b/src/intel/compiler/brw_mesh.cpp
@@ -458,6 +458,84 @@ brw_nir_lower_mue_outputs(nir_shader *nir, const struct brw_mue_map *map)
               nir_lower_io_lower_64bit_to_32);
 }
 
+static void
+brw_nir_initialize_mue(nir_shader *nir,
+                       const struct brw_mue_map *map,
+                       unsigned dispatch_width)
+{
+   assert(map->per_primitive_header_size_dw > 0);
+
+   nir_builder b;
+   nir_function_impl *entrypoint = nir_shader_get_entrypoint(nir);
+   nir_builder_init(&b, entrypoint);
+   b.cursor = nir_before_block(nir_start_block(entrypoint));
+
+   nir_ssa_def *dw_off = nir_imm_int(&b, 0);
+   nir_ssa_def *zerovec = nir_imm_vec4(&b, 0, 0, 0, 0);
+
+   /* TODO(mesh): can we write in bigger batches, generating fewer SENDs? */
+
+   assert(!nir->info.workgroup_size_variable);
+   const unsigned workgroup_size = nir->info.workgroup_size[0] *
+                                   nir->info.workgroup_size[1] *
+                                   nir->info.workgroup_size[2];
+
+   /* Invocations from a single workgroup will cooperate in zeroing MUE. */
+
+   /* How many prims does each invocation need to cover without checking its index? */
+   unsigned prims_per_inv = map->max_primitives / workgroup_size;
+
+   /* Zero the first 4 dwords of the MUE Primitive Header:
+    * Reserved, RTAIndex, ViewportIndex, CullPrimitiveMask.
+    */
+
+   nir_ssa_def *local_invocation_index = nir_load_local_invocation_index(&b);
+
+   /* Zero primitive headers spaced workgroup_size apart, starting from the
+    * invocation index.
+    */
+   for (unsigned prim_in_inv = 0; prim_in_inv < prims_per_inv; ++prim_in_inv) {
+      nir_ssa_def *prim = nir_iadd_imm(&b, local_invocation_index,
+                                       prim_in_inv * workgroup_size);
+
+      nir_store_per_primitive_output(&b, zerovec, prim, dw_off,
+                                     .base = (int)map->per_primitive_start_dw,
+                                     .write_mask = WRITEMASK_XYZW,
+                                     .src_type = nir_type_uint32);
+   }
+
+   /* How many prims are left? */
+   unsigned remaining = map->max_primitives % workgroup_size;
+
+   if (remaining) {
+      /* Zero the "remaining" primitive headers, each one workgroup_size past
+       * the last header covered by the loop above.
+       */
+      nir_ssa_def *cmp = nir_ilt(&b, local_invocation_index,
+                                 nir_imm_int(&b, remaining));
+      nir_if *if_stmt = nir_push_if(&b, cmp);
+      {
+         nir_ssa_def *prim = nir_iadd_imm(&b, local_invocation_index,
+                                          prims_per_inv * workgroup_size);
+
+         nir_store_per_primitive_output(&b, zerovec, prim, dw_off,
+                                        .base = (int)map->per_primitive_start_dw,
+                                        .write_mask = WRITEMASK_XYZW,
+                                        .src_type = nir_type_uint32);
+      }
+      nir_pop_if(&b, if_stmt);
+   }
+
+   /* If there's more than one subgroup, then we need to wait for all of them
+    * to finish initialization before we can proceed. Otherwise some subgroups
+    * may start filling the MUE before others have finished initializing it.
+    */
+   if (workgroup_size > dispatch_width) {
+      nir_scoped_barrier(&b, NIR_SCOPE_WORKGROUP, NIR_SCOPE_WORKGROUP,
+                         NIR_MEMORY_ACQ_REL, nir_var_shader_out);
+   }
+}
+
 static bool
 brw_nir_adjust_offset_for_arrayed_indices_instr(nir_builder *b, nir_instr *instr, void *data)
 {
@@ -554,7 +632,6 @@ brw_compile_mesh(const struct brw_compiler *compiler,
 
    brw_compute_mue_map(nir, &prog_data->map);
    NIR_PASS_V(nir, brw_nir_lower_mue_outputs, &prog_data->map);
-   NIR_PASS_V(nir, brw_nir_adjust_offset_for_arrayed_indices, &prog_data->map);
 
    const unsigned required_dispatch_width =
       brw_required_dispatch_width(&nir->info, key->base.subgroup_size_type);
@@ -570,8 +647,17 @@ brw_compile_mesh(const struct brw_compiler *compiler,
       const unsigned dispatch_width = 8 << simd;
 
       nir_shader *shader = nir_shader_clone(mem_ctx, nir);
+
+      /*
+       * When the Primitive Header is enabled, we may not generate writes to
+       * all of its fields, so let's initialize everything.
+       */
+      if (prog_data->map.per_primitive_header_size_dw > 0)
+         NIR_PASS_V(shader, brw_nir_initialize_mue, &prog_data->map, dispatch_width);
+
       brw_nir_apply_key(shader, compiler, &key->base, dispatch_width, true /* is_scalar */);
+      NIR_PASS_V(shader, brw_nir_adjust_offset_for_arrayed_indices, &prog_data->map);
 
       /* Load uniforms can do a better job for constants, so fold before it. */
       NIR_PASS_V(shader, nir_opt_constant_folding);
       NIR_PASS_V(shader, brw_nir_lower_load_uniforms);
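
Note on the zeroing scheme in brw_nir_initialize_mue: each of the
workgroup_size invocations zeroes prims_per_inv headers strided
workgroup_size apart, and the first "remaining" invocations each pick up
one leftover header. The standalone C sketch below (not part of the
patch; the workgroup_size and max_primitives values are hypothetical,
chosen only for illustration) replays that indexing on the host so it
can be checked:

#include <stdio.h>

int main(void)
{
   const unsigned workgroup_size = 16;  /* hypothetical */
   const unsigned max_primitives = 37;  /* hypothetical */

   const unsigned prims_per_inv = max_primitives / workgroup_size; /* 2 */
   const unsigned remaining = max_primitives % workgroup_size;     /* 5 */

   for (unsigned inv = 0; inv < workgroup_size; inv++) {
      /* Full rounds: one header per invocation per round, strided by
       * workgroup_size, so no bounds check is needed.
       */
      for (unsigned i = 0; i < prims_per_inv; i++)
         printf("invocation %2u zeroes header %2u\n",
                inv, inv + i * workgroup_size);

      /* Tail: only the first "remaining" invocations have a header left. */
      if (inv < remaining)
         printf("invocation %2u zeroes header %2u\n",
                inv, inv + prims_per_inv * workgroup_size);
   }
   return 0;
}

With these numbers every header index 0..36 is printed exactly once,
which is the coverage the NIR loop plus the if (remaining) block are
meant to produce across the workgroup.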
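On the final barrier: each HW thread executes dispatch_width invocations,
so a workgroup contains more than one subgroup exactly when
workgroup_size > dispatch_width, and only then can one subgroup start
writing real outputs while another is still zeroing headers. The
NIR_PASS_V(brw_nir_adjust_offset_for_arrayed_indices, ...) call is moved
after the new pass, presumably so the stores emitted by
brw_nir_initialize_mue get their offsets adjusted as well. A small
host-side sketch of the barrier condition (the workgroup_size value is
hypothetical):

#include <stdio.h>

int main(void)
{
   const unsigned workgroup_size = 32; /* hypothetical */

   /* SIMD8/16/32 compiles, as in the 8 << simd loop of brw_compile_mesh. */
   for (unsigned dispatch_width = 8; dispatch_width <= 32; dispatch_width *= 2) {
      unsigned subgroups =
         (workgroup_size + dispatch_width - 1) / dispatch_width;
      printf("SIMD%-2u -> %u subgroup(s), barrier %s\n",
             dispatch_width, subgroups,
             workgroup_size > dispatch_width ? "emitted" : "elided");
   }
   return 0;
}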