diff --git a/src/compiler/nir/nir.h b/src/compiler/nir/nir.h
index 97813ec8159..7e4352049c3 100644
--- a/src/compiler/nir/nir.h
+++ b/src/compiler/nir/nir.h
@@ -3069,6 +3069,8 @@ typedef enum {
    nir_lower_extract64 = (1 << 13),
    nir_lower_ufind_msb64 = (1 << 14),
    nir_lower_bit_count64 = (1 << 15),
+   nir_lower_subgroup_shuffle64 = (1 << 16),
+   nir_lower_scan_reduce_bitwise64 = (1 << 17),
 } nir_lower_int64_options;
 
 typedef enum {
diff --git a/src/compiler/nir/nir_lower_int64.c b/src/compiler/nir/nir_lower_int64.c
index 28a80012559..6c5a2aea0a7 100644
--- a/src/compiler/nir/nir_lower_int64.c
+++ b/src/compiler/nir/nir_lower_int64.c
@@ -1066,12 +1066,134 @@ should_lower_int64_alu_instr(const nir_alu_instr *alu,
    return (options->lower_int64_options & mask) != 0;
 }
 
+static nir_ssa_def *
+split_64bit_subgroup_op(nir_builder *b, const nir_intrinsic_instr *intrin)
+{
+   const nir_intrinsic_info *info = &nir_intrinsic_infos[intrin->intrinsic];
+
+   /* This works on subgroup ops with a single 64-bit source which can be
+    * trivially lowered by doing the exact same op on both halves.
+    */
+   assert(intrin->src[0].is_ssa && intrin->src[0].ssa->bit_size == 64);
+   nir_ssa_def *split_src0[2] = {
+      nir_unpack_64_2x32_split_x(b, intrin->src[0].ssa),
+      nir_unpack_64_2x32_split_y(b, intrin->src[0].ssa),
+   };
+
+   assert(info->has_dest && intrin->dest.is_ssa &&
+          intrin->dest.ssa.bit_size == 64);
+
+   nir_ssa_def *res[2];
+   for (unsigned i = 0; i < 2; i++) {
+      nir_intrinsic_instr *split =
+         nir_intrinsic_instr_create(b->shader, intrin->intrinsic);
+      split->num_components = intrin->num_components;
+      split->src[0] = nir_src_for_ssa(split_src0[i]);
+
+      /* Other sources must be less than 64 bits and get copied directly */
+      for (unsigned j = 1; j < info->num_srcs; j++) {
+         assert(intrin->src[j].is_ssa && intrin->src[j].ssa->bit_size < 64);
+         split->src[j] = nir_src_for_ssa(intrin->src[j].ssa);
+      }
+
+      /* Copy const indices, if any */
+      memcpy(split->const_index, intrin->const_index,
+             sizeof(intrin->const_index));
+
+      nir_ssa_dest_init(&split->instr, &split->dest,
+                        intrin->dest.ssa.num_components, 32, NULL);
+      nir_builder_instr_insert(b, &split->instr);
+
+      res[i] = &split->dest.ssa;
+   }
+
+   return nir_pack_64_2x32_split(b, res[0], res[1]);
+}
+
+static bool
+should_lower_int64_intrinsic(const nir_intrinsic_instr *intrin,
+                             const nir_shader_compiler_options *options)
+{
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_read_invocation:
+   case nir_intrinsic_read_first_invocation:
+   case nir_intrinsic_shuffle:
+   case nir_intrinsic_shuffle_xor:
+   case nir_intrinsic_shuffle_up:
+   case nir_intrinsic_shuffle_down:
+   case nir_intrinsic_quad_broadcast:
+   case nir_intrinsic_quad_swap_horizontal:
+   case nir_intrinsic_quad_swap_vertical:
+   case nir_intrinsic_quad_swap_diagonal:
+      assert(intrin->dest.is_ssa);
+      return intrin->dest.ssa.bit_size == 64 &&
+             (options->lower_int64_options & nir_lower_subgroup_shuffle64);
+
+   case nir_intrinsic_reduce:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_exclusive_scan:
+      assert(intrin->dest.is_ssa);
+      if (intrin->dest.ssa.bit_size != 64)
+         return false;
+
+      switch (nir_intrinsic_reduction_op(intrin)) {
+      case nir_op_iand:
+      case nir_op_ior:
+      case nir_op_ixor:
+         return options->lower_int64_options & nir_lower_scan_reduce_bitwise64;
+      default:
+         return false;
+      }
+      break;
+
+   default:
+      return false;
+   }
+}
+
+static nir_ssa_def *
+lower_int64_intrinsic(nir_builder *b, nir_intrinsic_instr *intrin)
+{
+   switch (intrin->intrinsic) {
+   case nir_intrinsic_read_invocation:
+   case nir_intrinsic_read_first_invocation:
+   case nir_intrinsic_shuffle:
+   case nir_intrinsic_shuffle_xor:
+   case nir_intrinsic_shuffle_up:
+   case nir_intrinsic_shuffle_down:
+   case nir_intrinsic_quad_broadcast:
+   case nir_intrinsic_quad_swap_horizontal:
+   case nir_intrinsic_quad_swap_vertical:
+   case nir_intrinsic_quad_swap_diagonal:
+      return split_64bit_subgroup_op(b, intrin);
+
+   case nir_intrinsic_reduce:
+   case nir_intrinsic_inclusive_scan:
+   case nir_intrinsic_exclusive_scan:
+      switch (nir_intrinsic_reduction_op(intrin)) {
+      case nir_op_iand:
+      case nir_op_ior:
+      case nir_op_ixor:
+         return split_64bit_subgroup_op(b, intrin);
+      default:
+         unreachable("Unsupported subgroup scan/reduce op");
+      }
+      break;
+
+   default:
+      unreachable("Unsupported intrinsic");
+   }
+}
+
 static bool
 should_lower_int64_instr(const nir_instr *instr, const void *_options)
 {
    switch (instr->type) {
    case nir_instr_type_alu:
       return should_lower_int64_alu_instr(nir_instr_as_alu(instr), _options);
+   case nir_instr_type_intrinsic:
+      return should_lower_int64_intrinsic(nir_instr_as_intrinsic(instr),
+                                          _options);
    default:
       return false;
    }
@@ -1083,6 +1205,8 @@ lower_int64_instr(nir_builder *b, nir_instr *instr, void *_options)
    switch (instr->type) {
    case nir_instr_type_alu:
       return lower_int64_alu_instr(b, nir_instr_as_alu(instr));
+   case nir_instr_type_intrinsic:
+      return lower_int64_intrinsic(b, nir_instr_as_intrinsic(instr));
    default:
       return NULL;
    }
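
For context, a minimal sketch of how a backend would opt in to the new lowering. This is illustrative and not part of the patch: only the two option names and the pass behavior come from the diff above; the backend name and helper are hypothetical, and it assumes the in-tree signature nir_lower_int64(nir_shader *), which reads lower_int64_options from shader->options.

#include "nir.h"

/* Hypothetical backend compiler options -- not part of this patch.  With
 * these bits set, nir_lower_int64 rewrites 64-bit subgroup shuffles/quad
 * ops and iand/ior/ixor scans/reductions into a pair of 32-bit intrinsics
 * on the unpacked halves, repacked with nir_pack_64_2x32_split (see
 * split_64bit_subgroup_op above).
 */
static const nir_shader_compiler_options example_backend_options = {
   .lower_int64_options = nir_lower_subgroup_shuffle64 |
                          nir_lower_scan_reduce_bitwise64,
};

static void
example_backend_lower(nir_shader *shader)
{
   /* Assumes shader was created with example_backend_options. */
   NIR_PASS_V(shader, nir_lower_int64);
}

Note that splitting into halves is only correct for ops that act on each 32-bit half independently, which is why the scan/reduce path is restricted to the bitwise reduction ops: arithmetic ops like iadd would need a carry between the halves and are not covered by this lowering.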