diff --git a/src/intel/compiler/brw_eu_defines.h b/src/intel/compiler/brw_eu_defines.h index b9976891683..1256b77eaf3 100644 --- a/src/intel/compiler/brw_eu_defines.h +++ b/src/intel/compiler/brw_eu_defines.h @@ -403,6 +403,15 @@ enum opcode { */ SHADER_OPCODE_SHUFFLE, + /* Combine all values in each subset (cluster) of channels using an operation, + * and broadcast the result to all channels in the subset. + * + * Source 0: Value. + * Source 1: Immediate with brw_reduce_op. + * Source 2: Immediate with cluster size. + */ + SHADER_OPCODE_REDUCE, + /* Select between src0 and src1 based on channel enables. * * This instruction copies src0 into the enabled channels of the @@ -672,6 +681,16 @@ enum interpolator_logical_srcs { INTERP_NUM_SRCS }; +enum brw_reduce_op { + BRW_REDUCE_OP_ADD, + BRW_REDUCE_OP_MUL, + BRW_REDUCE_OP_MIN, + BRW_REDUCE_OP_MAX, + BRW_REDUCE_OP_AND, + BRW_REDUCE_OP_OR, + BRW_REDUCE_OP_XOR, +}; + enum ENUM_PACKED brw_predicate { BRW_PREDICATE_NONE = 0, BRW_PREDICATE_NORMAL = 1, diff --git a/src/intel/compiler/brw_fs.cpp b/src/intel/compiler/brw_fs.cpp index 5c0cc0bbb2b..f4128b8cde3 100644 --- a/src/intel/compiler/brw_fs.cpp +++ b/src/intel/compiler/brw_fs.cpp @@ -318,6 +318,7 @@ fs_inst::can_do_source_mods(const struct intel_device_info *devinfo) const case SHADER_OPCODE_SHUFFLE: case SHADER_OPCODE_INT_QUOTIENT: case SHADER_OPCODE_INT_REMAINDER: + case SHADER_OPCODE_REDUCE: return false; default: return true; diff --git a/src/intel/compiler/brw_fs.h b/src/intel/compiler/brw_fs.h index c026ca73592..98ab4136262 100644 --- a/src/intel/compiler/brw_fs.h +++ b/src/intel/compiler/brw_fs.h @@ -649,6 +649,7 @@ bool brw_fs_lower_sends_overlapping_payload(fs_visitor &s); bool brw_fs_lower_simd_width(fs_visitor &s); bool brw_fs_lower_csel(fs_visitor &s); bool brw_fs_lower_sub_sat(fs_visitor &s); +bool brw_fs_lower_subgroup_ops(fs_visitor &s); bool brw_fs_lower_uniform_pull_constant_loads(fs_visitor &s); void brw_fs_lower_vgrfs_to_fixed_grfs(fs_visitor &s); diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp index c880997a4dc..e901cbd819f 100644 --- a/src/intel/compiler/brw_fs_nir.cpp +++ b/src/intel/compiler/brw_fs_nir.cpp @@ -4709,6 +4709,28 @@ brw_nir_reduction_op_identity(const fs_builder &bld, } } +static brw_reduce_op +brw_reduce_op_for_nir_reduction_op(nir_op op) +{ + switch (op) { + case nir_op_iadd: return BRW_REDUCE_OP_ADD; + case nir_op_fadd: return BRW_REDUCE_OP_ADD; + case nir_op_imul: return BRW_REDUCE_OP_MUL; + case nir_op_fmul: return BRW_REDUCE_OP_MUL; + case nir_op_imin: return BRW_REDUCE_OP_MIN; + case nir_op_umin: return BRW_REDUCE_OP_MIN; + case nir_op_fmin: return BRW_REDUCE_OP_MIN; + case nir_op_imax: return BRW_REDUCE_OP_MAX; + case nir_op_umax: return BRW_REDUCE_OP_MAX; + case nir_op_fmax: return BRW_REDUCE_OP_MAX; + case nir_op_iand: return BRW_REDUCE_OP_AND; + case nir_op_ior: return BRW_REDUCE_OP_OR; + case nir_op_ixor: return BRW_REDUCE_OP_XOR; + default: + unreachable("Invalid reduction operation"); + } +} + static opcode brw_op_for_nir_reduction_op(nir_op op) { @@ -7033,49 +7055,19 @@ fs_nir_emit_intrinsic(nir_to_brw_state &ntb, case nir_intrinsic_reduce: { brw_reg src = get_nir_src(ntb, instr->src[0]); - nir_op redop = (nir_op)nir_intrinsic_reduction_op(instr); + nir_op op = (nir_op)nir_intrinsic_reduction_op(instr); + enum brw_reduce_op brw_op = brw_reduce_op_for_nir_reduction_op(op); unsigned cluster_size = nir_intrinsic_cluster_size(instr); if (cluster_size == 0 || cluster_size > s.dispatch_width) cluster_size = s.dispatch_width; /* Figure out the source type */ src.type = brw_type_for_nir_type(devinfo, - (nir_alu_type)(nir_op_infos[redop].input_types[0] | + (nir_alu_type)(nir_op_infos[op].input_types[0] | nir_src_bit_size(instr->src[0]))); - brw_reg identity = brw_nir_reduction_op_identity(bld, redop, src.type); - opcode brw_op = brw_op_for_nir_reduction_op(redop); - brw_conditional_mod cond_mod = brw_cond_mod_for_nir_reduction_op(redop); - - /* Set up a register for all of our scratching around and initialize it - * to reduction operation's identity value. - */ - brw_reg scan = bld.vgrf(src.type); - bld.exec_all().emit(SHADER_OPCODE_SEL_EXEC, scan, src, identity); - - bld.emit_scan(brw_op, scan, cluster_size, cond_mod); - - dest.type = src.type; - if (cluster_size * brw_type_size_bytes(src.type) >= REG_SIZE * 2) { - /* In this case, CLUSTER_BROADCAST instruction isn't needed because - * the distance between clusters is at least 2 GRFs. In this case, - * we don't need the weird striding of the CLUSTER_BROADCAST - * instruction and can just do regular MOVs. - */ - assert((cluster_size * brw_type_size_bytes(src.type)) % (REG_SIZE * 2) == 0); - const unsigned groups = - (s.dispatch_width * brw_type_size_bytes(src.type)) / (REG_SIZE * 2); - const unsigned group_size = s.dispatch_width / groups; - for (unsigned i = 0; i < groups; i++) { - const unsigned cluster = (i * group_size) / cluster_size; - const unsigned comp = cluster * cluster_size + (cluster_size - 1); - bld.group(group_size, i).MOV(horiz_offset(dest, i * group_size), - component(scan, comp)); - } - } else { - bld.emit(SHADER_OPCODE_CLUSTER_BROADCAST, dest, scan, - brw_imm_ud(cluster_size - 1), brw_imm_ud(cluster_size)); - } + bld.emit(SHADER_OPCODE_REDUCE, retype(dest, src.type), src, + brw_imm_ud(brw_op), brw_imm_ud(cluster_size)); break; } diff --git a/src/intel/compiler/brw_fs_opt.cpp b/src/intel/compiler/brw_fs_opt.cpp index 9e969be1b72..5987ccd03a0 100644 --- a/src/intel/compiler/brw_fs_opt.cpp +++ b/src/intel/compiler/brw_fs_opt.cpp @@ -90,6 +90,7 @@ brw_fs_optimize(fs_visitor &s) OPT(brw_fs_opt_dead_code_eliminate); } + OPT(brw_fs_lower_subgroup_ops); OPT(brw_fs_lower_csel); OPT(brw_fs_lower_simd_width); OPT(brw_fs_lower_barycentrics); diff --git a/src/intel/compiler/brw_fs_validate.cpp b/src/intel/compiler/brw_fs_validate.cpp index cb05ea79d38..0fe0dfdc73d 100644 --- a/src/intel/compiler/brw_fs_validate.cpp +++ b/src/intel/compiler/brw_fs_validate.cpp @@ -231,6 +231,7 @@ brw_validate_instruction_phase(const fs_visitor &s, fs_inst *inst) case RT_OPCODE_TRACE_RAY_LOGICAL: case SHADER_OPCODE_URB_READ_LOGICAL: case SHADER_OPCODE_URB_WRITE_LOGICAL: + case SHADER_OPCODE_REDUCE: invalid_from = BRW_SHADER_PHASE_AFTER_EARLY_LOWERING; break; diff --git a/src/intel/compiler/brw_lower_subgroup_ops.cpp b/src/intel/compiler/brw_lower_subgroup_ops.cpp new file mode 100644 index 00000000000..4d427b5ede0 --- /dev/null +++ b/src/intel/compiler/brw_lower_subgroup_ops.cpp @@ -0,0 +1,197 @@ +/* + * Copyright 2024 Intel Corporation + * SPDX-License-Identifier: MIT + */ + +#include +#include "util/half_float.h" + +#include "brw_fs.h" +#include "brw_fs_builder.h" + +using namespace brw; + +struct brw_reduction_info { + brw_reg identity; + enum opcode op; + brw_conditional_mod cond_mod; +}; + +static brw_reduction_info +brw_get_reduction_info(brw_reduce_op red_op, brw_reg_type type) +{ + struct brw_reduction_info info; + + info.op = BRW_OPCODE_SEL; + info.cond_mod = BRW_CONDITIONAL_NONE; + + switch (red_op) { + case BRW_REDUCE_OP_ADD: info.op = BRW_OPCODE_ADD; break; + case BRW_REDUCE_OP_MUL: info.op = BRW_OPCODE_MUL; break; + case BRW_REDUCE_OP_AND: info.op = BRW_OPCODE_AND; break; + case BRW_REDUCE_OP_OR: info.op = BRW_OPCODE_OR; break; + case BRW_REDUCE_OP_XOR: info.op = BRW_OPCODE_XOR; break; + case BRW_REDUCE_OP_MIN: info.cond_mod = BRW_CONDITIONAL_L; break; + case BRW_REDUCE_OP_MAX: info.cond_mod = BRW_CONDITIONAL_GE; break; + default: + unreachable("invalid reduce op"); + } + + switch (red_op) { + case BRW_REDUCE_OP_ADD: + case BRW_REDUCE_OP_XOR: + case BRW_REDUCE_OP_OR: + info.identity = retype(brw_imm_u64(0), type); + return info; + case BRW_REDUCE_OP_AND: + info.identity = retype(brw_imm_u64(~0ull), type); + return info; + default: + /* Continue below. */ + break; + } + + brw_reg id; + const unsigned size = brw_type_size_bytes(type); + + switch (red_op) { + case BRW_REDUCE_OP_MUL: { + if (brw_type_is_int(type)) { + id = size < 4 ? brw_imm_uw(1) : + size == 4 ? brw_imm_ud(1) : + brw_imm_u64(1); + } else { + assert(brw_type_is_float(type)); + id = size == 2 ? brw_imm_uw(_mesa_float_to_half(1.0)) : + size == 4 ? brw_imm_f(1.0) : + brw_imm_df(1.0); + } + break; + } + + case BRW_REDUCE_OP_MIN: { + if (brw_type_is_uint(type)) { + id = brw_imm_u64(~0ull); + } else if (brw_type_is_sint(type)) { + id = size == 1 ? brw_imm_w(INT8_MAX) : + size == 2 ? brw_imm_w(INT16_MAX) : + size == 4 ? brw_imm_d(INT32_MAX) : + brw_imm_q(INT64_MAX); + } else { + assert(brw_type_is_float(type)); + id = size == 2 ? brw_imm_uw(_mesa_float_to_half(INFINITY)) : + size == 4 ? brw_imm_f(INFINITY) : + brw_imm_df(INFINITY); + } + break; + } + + case BRW_REDUCE_OP_MAX: { + if (brw_type_is_uint(type)) { + id = brw_imm_u64(0); + } else if (brw_type_is_sint(type)) { + id = size == 1 ? brw_imm_w(INT8_MIN) : + size == 2 ? brw_imm_w(INT16_MIN) : + size == 4 ? brw_imm_d(INT32_MIN) : + brw_imm_q(INT64_MIN); + } else { + assert(brw_type_is_float(type)); + id = size == 2 ? brw_imm_uw(_mesa_float_to_half(-INFINITY)) : + size == 4 ? brw_imm_f(-INFINITY) : + brw_imm_df(-INFINITY); + } + break; + } + + default: + unreachable("invalid reduce op"); + } + + /* For some cases above (e.g. all bits zeros, all bits ones, first bit one) + * either the size or the signedness was ignored, so adjust the final type + * now. + * + * B/UB types can't have immediates, so used W/UW above and here. + */ + if (type == BRW_TYPE_UB) type = BRW_TYPE_UW; + else if (type == BRW_TYPE_B) type = BRW_TYPE_W; + + info.identity = retype(id, type); + + return info; +} + +static bool +brw_lower_reduce(fs_visitor &s, bblock_t *block, fs_inst *inst) +{ + const fs_builder bld(&s, block, inst); + + assert(inst->dst.type == inst->src[0].type); + brw_reg dst = inst->dst; + brw_reg src = inst->src[0]; + + assert(inst->src[1].file == IMM); + enum brw_reduce_op op = (enum brw_reduce_op)inst->src[1].ud; + + assert(inst->src[2].file == IMM); + unsigned cluster_size = inst->src[2].ud; + + assert(cluster_size > 0); + assert(cluster_size <= s.dispatch_width); + + struct brw_reduction_info info = brw_get_reduction_info(op, src.type); + + /* Set up a register for all of our scratching around and initialize it + * to reduction operation's identity value. + */ + brw_reg scan = bld.vgrf(src.type); + bld.exec_all().emit(SHADER_OPCODE_SEL_EXEC, scan, src, info.identity); + + bld.emit_scan(info.op, scan, cluster_size, info.cond_mod); + + if (cluster_size * brw_type_size_bytes(src.type) >= REG_SIZE * 2) { + /* In this case, CLUSTER_BROADCAST instruction isn't needed because + * the distance between clusters is at least 2 GRFs. In this case, + * we don't need the weird striding of the CLUSTER_BROADCAST + * instruction and can just do regular MOVs. + */ + assert((cluster_size * brw_type_size_bytes(src.type)) % (REG_SIZE * 2) == 0); + const unsigned groups = + (s.dispatch_width * brw_type_size_bytes(src.type)) / (REG_SIZE * 2); + const unsigned group_size = s.dispatch_width / groups; + for (unsigned i = 0; i < groups; i++) { + const unsigned cluster = (i * group_size) / cluster_size; + const unsigned comp = cluster * cluster_size + (cluster_size - 1); + bld.group(group_size, i).MOV(horiz_offset(dst, i * group_size), + component(scan, comp)); + } + } else { + bld.emit(SHADER_OPCODE_CLUSTER_BROADCAST, dst, scan, + brw_imm_ud(cluster_size - 1), brw_imm_ud(cluster_size)); + } + inst->remove(block); + return true; +} + +bool +brw_fs_lower_subgroup_ops(fs_visitor &s) +{ + bool progress = false; + + foreach_block_and_inst_safe(block, fs_inst, inst, s.cfg) { + switch (inst->opcode) { + case SHADER_OPCODE_REDUCE: + progress |= brw_lower_reduce(s, block, inst); + break; + + default: + /* Nothing to do. */ + break; + } + } + + if (progress) + s.invalidate_analysis(DEPENDENCY_INSTRUCTIONS | DEPENDENCY_VARIABLES); + + return progress; +} diff --git a/src/intel/compiler/brw_print.cpp b/src/intel/compiler/brw_print.cpp index 3c3ff26ebc0..2b445d9991a 100644 --- a/src/intel/compiler/brw_print.cpp +++ b/src/intel/compiler/brw_print.cpp @@ -291,6 +291,8 @@ brw_instruction_name(const struct brw_isa_info *isa, enum opcode op) return "memory_store"; case SHADER_OPCODE_MEMORY_ATOMIC_LOGICAL: return "memory_atomic"; + case SHADER_OPCODE_REDUCE: + return "reduce"; } unreachable("not reached"); diff --git a/src/intel/compiler/meson.build b/src/intel/compiler/meson.build index 7a4e30c8827..718b080a38b 100644 --- a/src/intel/compiler/meson.build +++ b/src/intel/compiler/meson.build @@ -83,6 +83,7 @@ libintel_compiler_brw_files = files( 'brw_ir_performance.cpp', 'brw_isa_info.h', 'brw_lower_logical_sends.cpp', + 'brw_lower_subgroup_ops.cpp', 'brw_nir.h', 'brw_nir.c', 'brw_nir_analyze_ubo_ranges.c',