diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h index 349c8a379ba..1c6e3e3c517 100644 --- a/src/compiler/nir/nir.h +++ b/src/compiler/nir/nir.h @@ -5496,7 +5496,7 @@ bool nir_lower_ssa_defs_to_regs_block(nir_block *block); bool nir_rematerialize_derefs_in_use_blocks_impl(nir_function_impl *impl); bool nir_lower_samplers(nir_shader *shader); -bool nir_lower_cl_images(nir_shader *shader); +bool nir_lower_cl_images(nir_shader *shader, bool lower_image_derefs, bool lower_sampler_derefs); bool nir_dedup_inline_samplers(nir_shader *shader); bool nir_lower_ssbo(nir_shader *shader); diff --git a/src/compiler/nir/nir_lower_cl_images.c b/src/compiler/nir/nir_lower_cl_images.c index feaa61d7008..738a9c9eb83 100644 --- a/src/compiler/nir/nir_lower_cl_images.c +++ b/src/compiler/nir/nir_lower_cl_images.c @@ -108,7 +108,7 @@ nir_dedup_inline_samplers(nir_shader *nir) } bool -nir_lower_cl_images(nir_shader *shader) +nir_lower_cl_images(nir_shader *shader, bool lower_image_derefs, bool lower_sampler_derefs) { nir_function_impl *impl = nir_shader_get_entrypoint(shader); @@ -159,6 +159,12 @@ nir_lower_cl_images(nir_shader *shader) nir_builder b; nir_builder_init(&b, impl); + /* don't need any lowering if we can keep the derefs */ + if (!lower_image_derefs && !lower_sampler_derefs) { + nir_metadata_preserve(impl, nir_metadata_all); + return false; + } + bool progress = false; nir_foreach_block_reverse(block, impl) { nir_foreach_instr_reverse_safe(instr, block) { @@ -173,6 +179,13 @@ nir_lower_cl_images(nir_shader *shader) !glsl_type_is_sampler(deref->type)) break; + if (!lower_image_derefs && glsl_type_is_image(deref->type)) + break; + + if (!lower_sampler_derefs && + (glsl_type_is_sampler(deref->type) || glsl_type_is_texture(deref->type))) + break; + b.cursor = nir_instr_remove(&deref->instr); nir_ssa_def *loc = nir_imm_intN_t(&b, deref->var->data.driver_location, @@ -183,6 +196,9 @@ nir_lower_cl_images(nir_shader *shader) } case nir_instr_type_tex: { + if (!lower_sampler_derefs) + break; + nir_tex_instr *tex = nir_instr_as_tex(instr); unsigned count = 0; for (unsigned i = 0; i < tex->num_srcs; i++) { @@ -247,6 +263,9 @@ nir_lower_cl_images(nir_shader *shader) case nir_intrinsic_image_deref_atomic_dec_wrap: case nir_intrinsic_image_deref_size: case nir_intrinsic_image_deref_samples: { + if (!lower_image_derefs) + break; + assert(intrin->src[0].is_ssa); b.cursor = nir_before_instr(&intrin->instr); /* Back-ends expect a 32-bit thing, not 64-bit */ diff --git a/src/gallium/frontends/clover/nir/invocation.cpp b/src/gallium/frontends/clover/nir/invocation.cpp index 907f79bce08..d75d79a8c72 100644 --- a/src/gallium/frontends/clover/nir/invocation.cpp +++ b/src/gallium/frontends/clover/nir/invocation.cpp @@ -376,7 +376,7 @@ binary clover::nir::spirv_to_nir(const binary &mod, const device &dev, NIR_PASS_V(nir, nir_opt_deref); NIR_PASS_V(nir, nir_lower_readonly_images_to_tex, false); - NIR_PASS_V(nir, nir_lower_cl_images); + NIR_PASS_V(nir, nir_lower_cl_images, true, true); NIR_PASS_V(nir, nir_lower_memcpy); /* use offsets for kernel inputs (uniform) */ diff --git a/src/gallium/frontends/rusticl/core/kernel.rs b/src/gallium/frontends/rusticl/core/kernel.rs index 20e32e60b3f..5023336c7ff 100644 --- a/src/gallium/frontends/rusticl/core/kernel.rs +++ b/src/gallium/frontends/rusticl/core/kernel.rs @@ -534,7 +534,7 @@ fn lower_and_optimize_nir_late( } nir.pass1(nir_lower_readonly_images_to_tex, true); - nir.pass0(nir_lower_cl_images); + nir.pass2(nir_lower_cl_images, true, true); nir.reset_scratch_size(); nir.pass2(