diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c
index e46bf0d00fa..1f2ac17dda0 100644
--- a/src/compiler/nir/nir_lower_subgroups.c
+++ b/src/compiler/nir/nir_lower_subgroups.c
@@ -701,18 +701,32 @@ build_identity(nir_builder *b, unsigned bit_size, nir_op op)
 /* Implementation of scan/reduce that assumes a full subgroup */
 static nir_def *
 build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
-                nir_def *data, unsigned cluster_size)
+                nir_def *data, unsigned cluster_size,
+                const nir_lower_subgroups_options *options)
 {
+   bool unknown_size = !options->subgroup_size;
+   nir_def *subgroup_size = unknown_size ? nir_load_subgroup_size(b) : NULL;
+
    switch (op) {
    case nir_intrinsic_exclusive_scan:
    case nir_intrinsic_inclusive_scan: {
       for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *old_data = data;
+
+         if (unknown_size)
+            nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));
+
          nir_def *idx = nir_load_subgroup_invocation(b);
          nir_def *has_buddy = nir_ige_imm(b, idx, i);
 
          nir_def *buddy_data = nir_shuffle_up(b, data, nir_imm_int(b, i));
          nir_def *accum = nir_build_alu2(b, red_op, data, buddy_data);
          data = nir_bcsel(b, has_buddy, accum, data);
+
+         if (unknown_size) {
+            nir_pop_if(b, NULL);
+            data = nir_if_phi(b, data, old_data);
+         }
       }
 
       if (op == nir_intrinsic_exclusive_scan) {
@@ -732,8 +746,18 @@ build_scan_full(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
 
    case nir_intrinsic_reduce: {
       for (unsigned i = 1; i < cluster_size; i *= 2) {
+         nir_def *old_data = data;
+
+         if (unknown_size)
+            nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));
+
          nir_def *buddy_data = nir_shuffle_xor(b, data, nir_imm_int(b, i));
          data = nir_build_alu2(b, red_op, data, buddy_data);
+
+         if (unknown_size) {
+            nir_pop_if(b, NULL);
+            data = nir_if_phi(b, data, old_data);
+         }
       }
       return data;
    }
@@ -749,6 +773,9 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
                   nir_def *data, nir_def *mask, unsigned max_mask_bits,
                   const nir_lower_subgroups_options *options)
 {
+   bool unknown_size = !options->subgroup_size;
+   nir_def *subgroup_size = unknown_size ? nir_load_subgroup_size(b) : NULL;
+
    nir_def *lt_mask = nir_load_subgroup_lt_mask(b, options->ballot_components,
                                                 options->ballot_bit_size);
 
@@ -759,6 +786,12 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
    nir_def *remaining = nir_iand(b, mask, lt_mask);
 
    for (unsigned i = 1; i < max_mask_bits; i *= 2) {
+      nir_def *old_data = data;
+      nir_def *old_remaining = remaining;
+
+      if (unknown_size)
+         nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));
+
       /* At each step, our buddy channel is the first channel we have yet to
        * take into account in the accumulator.
        */
@@ -776,6 +809,12 @@ build_scan_reduce(nir_builder *b, nir_intrinsic_op op, nir_op red_op,
        */
       nir_def *buddy_remaining = nir_shuffle(b, remaining, buddy);
       remaining = nir_bcsel(b, has_buddy, buddy_remaining, nir_imm_int(b, 0));
+
+      if (unknown_size) {
+         nir_pop_if(b, NULL);
+         data = nir_if_phi(b, data, old_data);
+         remaining = nir_if_phi(b, remaining, old_remaining);
+      }
    }
 
    switch (op) {
@@ -868,7 +907,7 @@ lower_scan_reduce(nir_builder *b, nir_intrinsic_instr *intrin,
    nir_push_if(b, nir_ball_iequal(b, mask, build_subgroup_mask(b, options)));
    {
       full = build_scan_full(b, intrin->intrinsic, red_op,
-                             intrin->src[0].ssa, cluster_size);
+                             intrin->src[0].ssa, cluster_size, options);
    }
    nir_push_else(b, NULL);
    {
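
Note (not part of the patch): below is a minimal sketch of the guard pattern the diff repeats around each scan/reduce step when options->subgroup_size is zero, i.e. the subgroup size is only known at run time. The helper name guarded_step and its step_fn callback are hypothetical and exist only for illustration; only nir_push_if, nir_ugt_imm, nir_pop_if and nir_if_phi are taken from the patch itself.

/* Hypothetical helper, for illustration only: wrap one scan/reduce step in
 * "if (subgroup_size > i)" and merge the result back with a phi, so steps
 * that would read from lanes beyond the actual subgroup size are skipped
 * when the subgroup turns out to be smaller than the compile-time cluster
 * size.
 */
static nir_def *
guarded_step(nir_builder *b, nir_def *subgroup_size, unsigned i, nir_def *data,
             nir_def *(*step_fn)(nir_builder *b, nir_def *data, unsigned i))
{
   nir_def *old_data = data;

   /* subgroup_size is NULL when the size is known at compile time and the
    * guard is unnecessary.
    */
   if (subgroup_size != NULL)
      nir_push_if(b, nir_ugt_imm(b, subgroup_size, i));

   data = step_fn(b, data, i);

   if (subgroup_size != NULL) {
      nir_pop_if(b, NULL);
      /* Keep the old value on the path where the step was skipped. */
      data = nir_if_phi(b, data, old_data);
   }

   return data;
}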