diff --git a/src/gallium/drivers/radeonsi/si_shader_nir.c b/src/gallium/drivers/radeonsi/si_shader_nir.c index 393a0c01d60..a04a724e263 100644 --- a/src/gallium/drivers/radeonsi/si_shader_nir.c +++ b/src/gallium/drivers/radeonsi/si_shader_nir.c @@ -614,6 +614,50 @@ void si_nir_late_opts(nir_shader *nir) } } +static void si_late_optimize_16bit_samplers(struct si_screen *sscreen, nir_shader *nir) +{ + /* Optimize and fix types of image_sample sources and destinations. + * + * The image_sample constraints are: + * nir_tex_src_coord: has_a16 ? select 16 or 32 : 32 + * nir_tex_src_comparator: 32 + * nir_tex_src_offset: 32 + * nir_tex_src_bias: 32 + * nir_tex_src_lod: match coord + * nir_tex_src_min_lod: match coord + * nir_tex_src_ms_index: match coord + * nir_tex_src_ddx: has_g16 && coord == 32 ? select 16 or 32 : match coord + * nir_tex_src_ddy: match ddx + * + * coord and ddx are selected optimally. The types of the rest are legalized + * based on those two. + */ + /* TODO: The constraints can't represent the ddx constraint, so has_g16 is forced to false for now. */ + /*bool has_g16 = sscreen->info.chip_class >= GFX10 && LLVM_VERSION_MAJOR >= 12;*/ + bool has_g16 = false; + nir_tex_src_type_constraints tex_constraints = { + [nir_tex_src_comparator] = {true, 32}, + [nir_tex_src_offset] = {true, 32}, + [nir_tex_src_bias] = {true, 32}, + [nir_tex_src_lod] = {true, 0, nir_tex_src_coord}, + [nir_tex_src_min_lod] = {true, 0, nir_tex_src_coord}, + [nir_tex_src_ms_index] = {true, 0, nir_tex_src_coord}, + [nir_tex_src_ddx] = {!has_g16, 0, nir_tex_src_coord}, + [nir_tex_src_ddy] = {true, 0, has_g16 ? nir_tex_src_ddx : nir_tex_src_coord}, + }; + bool changed = false; + + NIR_PASS(changed, nir, nir_fold_16bit_sampler_conversions, + (1 << nir_tex_src_coord) | + (has_g16 ? 
1 << nir_tex_src_ddx : 0)); + NIR_PASS(changed, nir, nir_legalize_16bit_sampler_srcs, tex_constraints); + + if (changed) { + si_nir_opts(sscreen, nir, false); + si_nir_late_opts(nir); + } +} + static int type_size_vec4(const struct glsl_type *type, bool bindless) { return glsl_count_attribute_slots(type, false); @@ -833,9 +877,12 @@ static void si_lower_nir(struct si_screen *sscreen, struct nir_shader *nir) if (changed) si_nir_opts(sscreen, nir, false); - /* Run late optimizations to fuse ffma. */ + /* Run late optimizations to fuse ffma and eliminate 16-bit conversions. */ si_nir_late_opts(nir); + if (sscreen->b.get_shader_param(&sscreen->b, PIPE_SHADER_FRAGMENT, PIPE_SHADER_CAP_FP16)) + si_late_optimize_16bit_samplers(sscreen, nir); + NIR_PASS_V(nir, nir_remove_dead_variables, nir_var_function_temp, NULL); NIR_PASS_V(nir, nir_lower_discard_or_demote,