nir: add read_invocation_cond_ir3 and a lowering option for read_invocation

Adds the read_invocation_cond_ir3 intrinsic, which returns the value of its
first source from the lane whose second source is true (the second source
must be true for exactly one lane), and a new
nir_lower_subgroups_options::lower_read_invocation_to_cond flag.

When the flag is set, nir_lower_subgroups rewrites read_invocation into the
new intrinsic, using (src[1] == load_subgroup_invocation) as the per-lane
condition. A multi-component read_invocation is still handed to
lower_subgroup_op_to_scalar first when lower_to_scalar is set; only when
neither lowering applies is the instruction left untouched.

diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index bf2c9ff138b..ad1a5034db3 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -4738,6 +4738,7 @@ typedef struct nir_lower_subgroups_options {
    bool lower_quad_broadcast_dynamic:1;
    bool lower_quad_broadcast_dynamic_to_const:1;
    bool lower_elect:1;
+   bool lower_read_invocation_to_cond:1;
 } nir_lower_subgroups_options;
 
 bool nir_lower_subgroups(nir_shader *shader,
diff --git a/src/compiler/nir/nir_intrinsics.py b/src/compiler/nir/nir_intrinsics.py
index 57e6d868da0..c6cfe4cd237 100644
--- a/src/compiler/nir/nir_intrinsics.py
+++ b/src/compiler/nir/nir_intrinsics.py
@@ -362,6 +362,10 @@ intrinsic("ballot", src_comp=[1], dest_comp=0, flags=[CAN_ELIMINATE])
 intrinsic("read_invocation", src_comp=[0, 1], dest_comp=0, bit_sizes=src0, flags=[CAN_ELIMINATE])
 intrinsic("read_first_invocation", src_comp=[0], dest_comp=0, bit_sizes=src0, flags=[CAN_ELIMINATE])
 
+# Returns the value of the first source for the lane where the second source is
+# true. The second source must be true for exactly one lane.
+intrinsic("read_invocation_cond_ir3", src_comp=[0, 1], dest_comp=0, flags=[CAN_ELIMINATE])
+
 # Additional SPIR-V ballot intrinsics
 #
 # These correspond to the SPIR-V opcodes
diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index cfe74758ff1..33ba7821752 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -493,6 +493,15 @@ lower_dynamic_quad_broadcast(nir_builder *b, nir_intrinsic_instr *intrin,
    return dst;
 }
 
+static nir_ssa_def *
+lower_read_invocation_to_cond(nir_builder *b, nir_intrinsic_instr *intrin)
+{
+   return nir_read_invocation_cond_ir3(b, intrin->dest.ssa.bit_size,
+                                       intrin->src[0].ssa,
+                                       nir_ieq(b, intrin->src[1].ssa,
+                                               nir_load_subgroup_invocation(b)));
+}
+
 static nir_ssa_def *
 lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
 {
@@ -524,6 +533,14 @@ lower_subgroups_instr(nir_builder *b, nir_instr *instr, void *_options)
       break;
 
    case nir_intrinsic_read_invocation:
+      if (options->lower_to_scalar && intrin->num_components > 1)
+         return lower_subgroup_op_to_scalar(b, intrin, false);
+
+      if (options->lower_read_invocation_to_cond)
+         return lower_read_invocation_to_cond(b, intrin);
+
+      break;
+
    case nir_intrinsic_read_first_invocation:
       if (options->lower_to_scalar && intrin->num_components > 1)
          return lower_subgroup_op_to_scalar(b, intrin, false);