diff --git a/src/gallium/drivers/radeonsi/si_get.c b/src/gallium/drivers/radeonsi/si_get.c index 5458048b9b4..100335d7adf 100644 --- a/src/gallium/drivers/radeonsi/si_get.c +++ b/src/gallium/drivers/radeonsi/si_get.c @@ -1286,6 +1286,8 @@ void si_init_screen_get_functions(struct si_screen *sscreen) .lower_fisnormal = true, .lower_rotate = true, .lower_to_scalar = true, + .lower_to_scalar_filter = sscreen->info.has_packed_math_16bit ? + si_alu_to_scalar_packed_math_filter : NULL, .lower_int64_options = nir_lower_imul_2x32_64 | nir_lower_imul_high64, .has_sdot_4x8 = sscreen->info.has_accelerated_dot_product, .has_sudot_4x8 = sscreen->info.has_accelerated_dot_product && sscreen->info.gfx_level >= GFX11, diff --git a/src/gallium/drivers/radeonsi/si_shader.h b/src/gallium/drivers/radeonsi/si_shader.h index cd7c5832df4..1785a04fa39 100644 --- a/src/gallium/drivers/radeonsi/si_shader.h +++ b/src/gallium/drivers/radeonsi/si_shader.h @@ -1062,6 +1062,7 @@ void si_nir_scan_shader(struct si_screen *sscreen, const struct nir_shader *nir /* si_shader_nir.c */ extern const nir_lower_subgroups_options si_nir_subgroups_options; +bool si_alu_to_scalar_packed_math_filter(const nir_instr *instr, const void *data); void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first); void si_nir_late_opts(nir_shader *nir); char *si_finalize_nir(struct pipe_screen *screen, void *nirptr); diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c index 0086a54e41e..198b9046028 100644 --- a/src/gallium/drivers/radeonsi/si_shader_nir.c +++ b/src/gallium/drivers/radeonsi/si_shader_nir.c @@ -10,11 +10,9 @@ #include "ac_nir.h" -static bool si_alu_to_scalar_filter(const nir_instr *instr, const void *data) +bool si_alu_to_scalar_packed_math_filter(const nir_instr *instr, const void *data) { - struct si_screen *sscreen = (struct si_screen *)data; - - if (sscreen->info.has_packed_math_16bit && instr->type == nir_instr_type_alu) { + if (instr->type == nir_instr_type_alu) { nir_alu_instr *alu = nir_instr_as_alu(instr); if (alu->dest.dest.is_ssa && @@ -75,7 +73,8 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first) bool lower_phis_to_scalar = false; NIR_PASS(progress, nir, nir_lower_vars_to_ssa); - NIR_PASS(progress, nir, nir_lower_alu_to_scalar, si_alu_to_scalar_filter, sscreen); + NIR_PASS(progress, nir, nir_lower_alu_to_scalar, + nir->options->lower_to_scalar_filter, NULL); NIR_PASS(progress, nir, nir_lower_phis_to_scalar, false); if (first) { @@ -97,8 +96,10 @@ void si_nir_opts(struct si_screen *sscreen, struct nir_shader *nir, bool first) (LLVM_VERSION_MAJOR == 14 ? 0 : nir_opt_if_optimize_phi_true_false)); NIR_PASS(progress, nir, nir_opt_dead_cf); - if (lower_alu_to_scalar) - NIR_PASS_V(nir, nir_lower_alu_to_scalar, si_alu_to_scalar_filter, sscreen); + if (lower_alu_to_scalar) { + NIR_PASS_V(nir, nir_lower_alu_to_scalar, + nir->options->lower_to_scalar_filter, NULL); + } if (lower_phis_to_scalar) NIR_PASS_V(nir, nir_lower_phis_to_scalar, false); progress |= lower_alu_to_scalar | lower_phis_to_scalar;