diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 2d2ff03ed72..c238de29945 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -6144,10 +6144,13 @@ bool nir_lower_atomics(nir_shader *shader, nir_instr_filter_cb filter); typedef struct nir_lower_subgroups_options { /* In addition to the boolean lowering options below, this optional callback * will filter instructions for lowering if non-NULL. The data passed will be - * this options struct itself. + * filter_data. */ nir_instr_filter_cb filter; + /* Extra data passed to the filter. */ + const void *filter_data; + /* In case the exact subgroup size is not known, subgroup_size should be * set to 0. In that case, the maximum subgroup size will be calculated by * ballot_components * ballot_bit_size. diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 896cf11da36..718d018c704 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -893,6 +893,12 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin, static bool lower_subgroups_filter(const nir_instr *instr, const void *_options) { + const nir_lower_subgroups_options *options = _options; + + if (options->filter) { + return options->filter(instr, options->filter_data); + } + return instr->type == nir_instr_type_intrinsic; } @@ -1343,8 +1349,7 @@ bool nir_lower_subgroups(nir_shader *shader, const nir_lower_subgroups_options *options) { - void *filter = options->filter ? options->filter : lower_subgroups_filter; - return nir_shader_lower_instructions(shader, filter, + return nir_shader_lower_instructions(shader, lower_subgroups_filter, lower_subgroups_instr, (void *)options); }