nir: add a filter cb to lower_io_to_scalar

This is useful for drivers that want to do selective scalarization of I/O.

Reviewed-by: Timur Kristóf <timur.kristof@gmail.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/24565>
Author:    Mike Blumenkrantz
Date:      2023-07-21 12:53:49 -04:00
Committer: Marge Bot
Commit:    e9a5da2f4b
Parent:    550f3dc437

13 changed files with 51 additions and 35 deletions
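The two new trailing arguments are the filter callback and its user-data pointer; every caller updated below passes NULL for both, which keeps the previous scalarize-everything behaviour. As a sketch of the selective scalarization the commit message refers to — the filter name and the vec3/vec4 heuristic are hypothetical, and the nir_instr_filter_cb signature is assumed to match the existing typedef in nir.h:

   /* Hypothetical driver-side filter: split only vec3/vec4 accesses and
    * leave scalars and vec2s vectorized. The pass only invokes the filter
    * on I/O intrinsics that already match the mode mask, so the intrinsic
    * cast is safe. */
   static bool
   scalarize_wide_only(const nir_instr *instr, const void *data)
   {
      const nir_intrinsic_instr *intr =
         nir_instr_as_intrinsic((nir_instr *)instr);
      return intr->num_components > 2;
   }

   /* Later, in the driver's lowering pipeline: */
   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ssbo,
              scalarize_wide_only, NULL);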


@@ -697,7 +697,7 @@ radv_postprocess_nir(struct radv_device *device, const struct radv_pipeline_key
    });

    if (radv_use_llvm_for_stage(device, stage->stage))
-      NIR_PASS_V(stage->nir, nir_lower_io_to_scalar, nir_var_mem_global);
+      NIR_PASS_V(stage->nir, nir_lower_io_to_scalar, nir_var_mem_global, NULL, NULL);

    NIR_PASS(_, stage->nir, ac_nir_lower_global_access);
    NIR_PASS_V(stage->nir, ac_nir_lower_intrinsics_to_args, gfx_level, radv_select_hw_stage(&stage->info, gfx_level),


@@ -2685,7 +2685,7 @@ agx_compile_shader_nir(nir_shader *nir, struct agx_shader_key *key,
     * transform feedback programs will use vector output.
     */
    if (nir->info.stage == MESA_SHADER_VERTEX)
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);

    out->push_count = key->reserved_preamble;
    agx_optimize_nir(nir, &out->push_count);


@@ -1040,7 +1040,7 @@ v3d_nir_lower_gs_late(struct v3d_compile *c)
    }

    /* Note: GS output scalarizing must happen after nir_lower_clip_gs. */
-   NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }

 static void
@@ -1050,11 +1050,11 @@ v3d_nir_lower_vs_late(struct v3d_compile *c)
       NIR_PASS(_, c->s, nir_lower_clip_vs, c->key->ucp_enables,
                false, false, NULL);
       NIR_PASS_V(c->s, nir_lower_io_to_scalar,
-                 nir_var_shader_out);
+                 nir_var_shader_out, NULL, NULL);
    }

    /* Note: VS output scalarizing must happen after nir_lower_clip_vs. */
-   NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }

 static void
@@ -1070,7 +1070,7 @@ v3d_nir_lower_fs_late(struct v3d_compile *c)
    if (c->key->ucp_enables)
       NIR_PASS(_, c->s, nir_lower_clip_fs, c->key->ucp_enables, true);

-   NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in);
+   NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
 }

 static uint32_t


@@ -5242,7 +5242,7 @@ bool nir_lower_phis_to_scalar(nir_shader *shader, bool lower_all);
 void nir_lower_io_arrays_to_elements(nir_shader *producer, nir_shader *consumer);
 void nir_lower_io_arrays_to_elements_no_indirects(nir_shader *shader,
                                                   bool outputs_only);
-bool nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask);
+bool nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask, nir_instr_filter_cb filter, void *filter_data);
 bool nir_lower_io_to_scalar_early(nir_shader *shader, nir_variable_mode mask);
 bool nir_lower_io_to_vector(nir_shader *shader, nir_variable_mode mask);
 bool nir_vectorize_tess_levels(nir_shader *shader);
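nir_instr_filter_cb itself is not introduced by this commit; it is NIR's pre-existing generic per-instruction predicate. As a reminder, its declaration in nir.h looks like this (quoted from memory, so treat the exact spelling as an assumption):

   /* Returns true if the pass should process the instruction. */
   typedef bool (*nir_instr_filter_cb)(const nir_instr *instr, const void *data);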


@@ -231,10 +231,16 @@ lower_store_to_scalar(nir_builder *b, nir_intrinsic_instr *intr)
    nir_instr_remove(&intr->instr);
 }

+struct scalarize_state {
+   nir_variable_mode mask;
+   nir_instr_filter_cb filter;
+   void *filter_data;
+};
+
 static bool
 nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
 {
-   nir_variable_mode mask = *(nir_variable_mode *)data;
+   struct scalarize_state *state = data;

    if (instr->type != nir_instr_type_intrinsic)
       return false;
@@ -247,36 +253,41 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
    if ((intr->intrinsic == nir_intrinsic_load_input ||
         intr->intrinsic == nir_intrinsic_load_per_vertex_input ||
         intr->intrinsic == nir_intrinsic_load_interpolated_input) &&
-       (mask & nir_var_shader_in)) {
+       (state->mask & nir_var_shader_in) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_input_to_scalar(b, intr);
       return true;
    }

    if ((intr->intrinsic == nir_intrinsic_load_output ||
         intr->intrinsic == nir_intrinsic_load_per_vertex_output) &&
-       (mask & nir_var_shader_out)) {
+       (state->mask & nir_var_shader_out) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_input_to_scalar(b, intr);
       return true;
    }

-   if ((intr->intrinsic == nir_intrinsic_load_ubo && (mask & nir_var_mem_ubo)) ||
-       (intr->intrinsic == nir_intrinsic_load_ssbo && (mask & nir_var_mem_ssbo)) ||
-       (intr->intrinsic == nir_intrinsic_load_global && (mask & nir_var_mem_global)) ||
-       (intr->intrinsic == nir_intrinsic_load_shared && (mask & nir_var_mem_shared))) {
+   if (((intr->intrinsic == nir_intrinsic_load_ubo && (state->mask & nir_var_mem_ubo)) ||
+        (intr->intrinsic == nir_intrinsic_load_ssbo && (state->mask & nir_var_mem_ssbo)) ||
+        (intr->intrinsic == nir_intrinsic_load_global && (state->mask & nir_var_mem_global)) ||
+        (intr->intrinsic == nir_intrinsic_load_shared && (state->mask & nir_var_mem_shared))) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_load_to_scalar(b, intr);
       return true;
    }

    if ((intr->intrinsic == nir_intrinsic_store_output ||
         intr->intrinsic == nir_intrinsic_store_per_vertex_output) &&
-       mask & nir_var_shader_out) {
+       state->mask & nir_var_shader_out &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_store_output_to_scalar(b, intr);
       return true;
    }

-   if ((intr->intrinsic == nir_intrinsic_store_ssbo && (mask & nir_var_mem_ssbo)) ||
-       (intr->intrinsic == nir_intrinsic_store_global && (mask & nir_var_mem_global)) ||
-       (intr->intrinsic == nir_intrinsic_store_shared && (mask & nir_var_mem_shared))) {
+   if (((intr->intrinsic == nir_intrinsic_store_ssbo && (state->mask & nir_var_mem_ssbo)) ||
+        (intr->intrinsic == nir_intrinsic_store_global && (state->mask & nir_var_mem_global)) ||
+        (intr->intrinsic == nir_intrinsic_store_shared && (state->mask & nir_var_mem_shared))) &&
+       (!state->filter || state->filter(instr, state->filter_data))) {
       lower_store_to_scalar(b, intr);
       return true;
    }
@@ -285,13 +296,18 @@ nir_lower_io_to_scalar_instr(nir_builder *b, nir_instr *instr, void *data)
 }

 bool
-nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask)
+nir_lower_io_to_scalar(nir_shader *shader, nir_variable_mode mask, nir_instr_filter_cb filter, void *filter_data)
 {
+   struct scalarize_state state = {
+      mask,
+      filter,
+      filter_data
+   };
    return nir_shader_instructions_pass(shader,
                                        nir_lower_io_to_scalar_instr,
                                        nir_metadata_block_index |
                                        nir_metadata_dominance,
-                                       &mask);
+                                       &state);
 }

 static nir_variable **
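Note the design here: nir_shader_instructions_pass forwards only a single void *data pointer to the per-instruction callback, so the entry point bundles mask, filter, and filter_data into a stack-allocated struct scalarize_state where the data pointer previously referenced the bare nir_variable_mode.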


@@ -647,7 +647,7 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so, nir_shader *s)
    bool progress = false;

-   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_mem_ssbo);
+   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_mem_ssbo, NULL, NULL);

    if (so->key.has_gs || so->key.tessellation) {
       switch (so->type) {
@@ -658,7 +658,7 @@ ir3_nir_lower_variant(struct ir3_shader_variant *so, nir_shader *s)
          break;
       case MESA_SHADER_TESS_CTRL:
          NIR_PASS_V(s, nir_lower_io_to_scalar,
-                    nir_var_shader_in | nir_var_shader_out);
+                    nir_var_shader_in | nir_var_shader_out, NULL, NULL);
          NIR_PASS_V(s, ir3_nir_lower_tess_ctrl, so, so->key.tessellation);
          NIR_PASS_V(s, ir3_nir_lower_to_explicit_input, so);
          progress = true;


@@ -122,7 +122,7 @@ lima_program_optimize_vs_nir(struct nir_shader *s)
    NIR_PASS_V(s, nir_lower_load_const_to_scalar);
    NIR_PASS_V(s, lima_nir_lower_uniform_to_scalar);
    NIR_PASS_V(s, nir_lower_io_to_scalar,
-              nir_var_shader_in|nir_var_shader_out);
+              nir_var_shader_in|nir_var_shader_out, NULL, NULL);

    do {
       progress = false;


@@ -1774,7 +1774,7 @@ static void si_lower_ngg(struct si_shader *shader, nir_shader *nir)
    NIR_PASS_V(nir, nir_lower_subgroups, &si_nir_subgroups_options);

    /* may generate some vector output store */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);
 }

 struct nir_shader *si_deserialize_shader(struct si_shader_selector *sel)


@@ -305,7 +305,7 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir)
    if (nir->info.stage == MESA_SHADER_VERTEX ||
        nir->info.stage == MESA_SHADER_TESS_EVAL ||
        nir->info.stage == MESA_SHADER_GEOMETRY)
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);

    if (nir->info.stage == MESA_SHADER_GEOMETRY) {
       unsigned flags = nir_lower_gs_intrinsics_per_stream;


@@ -2293,7 +2293,7 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
          NIR_PASS_V(c->s, nir_lower_clip_vs,
                     c->key->ucp_enables, false, false, NULL);
          NIR_PASS_V(c->s, nir_lower_io_to_scalar,
-                    nir_var_shader_out);
+                    nir_var_shader_out, NULL, NULL);
       }
    }
@@ -2302,9 +2302,9 @@ vc4_shader_ntq(struct vc4_context *vc4, enum qstage stage,
     * scalarizing must happen after nir_lower_clip_vs.
     */
    if (c->stage == QSTAGE_FRAG)
-      NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in);
+      NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);
    else
-      NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out);
+      NIR_PASS_V(c->s, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);

    NIR_PASS_V(c->s, vc4_nir_lower_io, c);
    NIR_PASS_V(c->s, vc4_nir_lower_txf_ms, c);


@@ -3706,7 +3706,7 @@ zink_shader_compile(struct zink_screen *screen, bool can_shobj, struct zink_shad
          }
       }
       if (screen->driconf.inline_uniforms) {
-         NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+         NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
          NIR_PASS_V(nir, rewrite_bo_access, screen);
          NIR_PASS_V(nir, remove_bo_access, zs);
          need_optimize = true;
@@ -3761,7 +3761,7 @@ zink_shader_compile_separate(struct zink_screen *screen, struct zink_shader *zs)
          }
       }
       if (screen->driconf.inline_uniforms) {
-         NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+         NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
          NIR_PASS_V(nir, rewrite_bo_access, screen);
          NIR_PASS_V(nir, remove_bo_access, zs);
       }
@@ -4913,7 +4913,7 @@ zink_shader_create(struct zink_screen *screen, struct nir_shader *nir,
    NIR_PASS_V(nir, unbreak_bos, ret, needs_size);
    /* run in compile if there could be inlined uniforms */
    if (!screen->driconf.inline_uniforms && !nir->info.num_inlinable_uniforms) {
-      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared);
+      NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_global | nir_var_mem_ubo | nir_var_mem_ssbo | nir_var_mem_shared, NULL, NULL);
       NIR_PASS_V(nir, rewrite_bo_access, screen);
       NIR_PASS_V(nir, remove_bo_access, ret);
    }


@@ -89,7 +89,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,

    /* Load inputs to scalars (single registers later). */
    /* TODO: Fitrp can process multiple frag inputs at once, scalarise I/O. */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_in, NULL, NULL);

    /* Optimize GL access qualifiers. */
    const nir_opt_access_options opt_access_options = {
@@ -102,7 +102,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
    NIR_PASS_V(nir, rogue_nir_pfo);

    /* Load outputs to scalars (single registers later). */
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_shader_out, NULL, NULL);

    /* Lower ALU operations to scalars. */
    NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
@@ -115,7 +115,7 @@ static void rogue_nir_passes(struct rogue_build_ctx *ctx,
               nir_lower_explicit_io,
               nir_var_mem_ubo,
               spirv_options.ubo_addr_format);
-   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo);
+   NIR_PASS_V(nir, nir_lower_io_to_scalar, nir_var_mem_ubo, NULL, NULL);
    NIR_PASS_V(nir, rogue_nir_lower_io);

    /* Algebraic opts. */


@@ -6518,7 +6518,7 @@ nir_to_dxil(struct nir_shader *s, const struct nir_to_dxil_options *opts,
    NIR_PASS_V(s, nir_lower_io, nir_var_shader_in | nir_var_shader_out, type_size_vec4, nir_lower_io_lower_64bit_to_32);
    NIR_PASS_V(s, dxil_nir_ensure_position_writes);
    NIR_PASS_V(s, dxil_nir_lower_system_values);
-   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_system_value | nir_var_shader_out);
+   NIR_PASS_V(s, nir_lower_io_to_scalar, nir_var_shader_in | nir_var_system_value | nir_var_shader_out, NULL, NULL);

    /* Do a round of optimization to try to vectorize loads/stores. Otherwise the addresses used for loads
     * might be too opaque for the pass to see that they're next to each other. */