diff --git a/src/intel/compiler/brw_eu_defines.h b/src/intel/compiler/brw_eu_defines.h index fc72f511f1d..29c781562f4 100644 --- a/src/intel/compiler/brw_eu_defines.h +++ b/src/intel/compiler/brw_eu_defines.h @@ -421,6 +421,22 @@ enum opcode { SHADER_OPCODE_INCLUSIVE_SCAN, SHADER_OPCODE_EXCLUSIVE_SCAN, + /* Check if any or all values in each subset (cluster) of channels are set, + * and broadcast the result to all channels in the subset. + * + * Source 0: Boolean value. + * Source 1: Immediate with cluster size. + */ + SHADER_OPCODE_VOTE_ANY, + SHADER_OPCODE_VOTE_ALL, + + /* Check if the values of all channels are equal, and broadcast the result + * to all channels. + * + * Source 0: Value. + */ + SHADER_OPCODE_VOTE_EQUAL, + /* Select between src0 and src1 based on channel enables. * * This instruction copies src0 into the enabled channels of the diff --git a/src/intel/compiler/brw_fs.cpp b/src/intel/compiler/brw_fs.cpp index 7cf0331927f..6d6daade4d2 100644 --- a/src/intel/compiler/brw_fs.cpp +++ b/src/intel/compiler/brw_fs.cpp @@ -321,6 +321,9 @@ fs_inst::can_do_source_mods(const struct intel_device_info *devinfo) const case SHADER_OPCODE_REDUCE: case SHADER_OPCODE_INCLUSIVE_SCAN: case SHADER_OPCODE_EXCLUSIVE_SCAN: + case SHADER_OPCODE_VOTE_ANY: + case SHADER_OPCODE_VOTE_ALL: + case SHADER_OPCODE_VOTE_EQUAL: return false; default: return true; diff --git a/src/intel/compiler/brw_fs_nir.cpp b/src/intel/compiler/brw_fs_nir.cpp index 264d6a71967..af0b4f6ca65 100644 --- a/src/intel/compiler/brw_fs_nir.cpp +++ b/src/intel/compiler/brw_fs_nir.cpp @@ -6590,145 +6590,24 @@ fs_nir_emit_intrinsic(nir_to_brw_state &ntb, retype(get_nir_src(ntb, instr->src[0]), BRW_TYPE_F)); break; + case nir_intrinsic_vote_any: + case nir_intrinsic_vote_all: case nir_intrinsic_quad_vote_any: case nir_intrinsic_quad_vote_all: { - struct brw_reg flag = brw_flag_reg(0, 0); - if (s.dispatch_width == 32) - flag.type = BRW_TYPE_UD; + const bool any = instr->intrinsic == nir_intrinsic_vote_any || + instr->intrinsic == nir_intrinsic_quad_vote_any; + const bool quad = instr->intrinsic == nir_intrinsic_quad_vote_any || + instr->intrinsic == nir_intrinsic_quad_vote_all; brw_reg cond = get_nir_src(ntb, instr->src[0]); + const unsigned cluster_size = quad ? 4 : s.dispatch_width; - /* Before Xe2, we can use specialized predicates. */ - if (devinfo->ver < 20) { - const bool any = instr->intrinsic == nir_intrinsic_quad_vote_any; - - /* The any/all predicates do not consider channel enables. To prevent - * dead channels from affecting the result, we initialize the flag with - * with the identity value for the logical operation. - */ - const unsigned identity = any ? 0 : 0xFFFFFFFF; - bld.exec_all().group(1, 0).MOV(flag, retype(brw_imm_ud(identity), flag.type)); - - bld.CMP(bld.null_reg_ud(), cond, brw_imm_ud(0u), BRW_CONDITIONAL_NZ); - bld.exec_all().MOV(retype(dest, BRW_TYPE_UD), brw_imm_ud(0)); - - const enum brw_predicate pred = any ? BRW_PREDICATE_ALIGN1_ANY4H - : BRW_PREDICATE_ALIGN1_ALL4H; - - fs_inst *mov = bld.MOV(retype(dest, BRW_TYPE_D), brw_imm_d(-1)); - set_predicate(pred, mov); - break; - } - - /* This code is going to manipulate the results of flag mask, so clear it to - * avoid any residual value from disabled channels. - */ - bld.exec_all().group(1, 0).MOV(flag, retype(brw_imm_ud(0), flag.type)); - - /* Mask of invocations where condition is true, note that mask is - * replicated to each invocation. - */ - bld.CMP(bld.null_reg_ud(), cond, brw_imm_ud(0u), BRW_CONDITIONAL_NZ); - brw_reg cond_mask = bld.vgrf(BRW_TYPE_UD); - bld.MOV(cond_mask, flag); - - /* Mask of invocations in the quad, each invocation will get - * all the bits set for their quad, i.e. invocations 0-3 will have - * 0b...1111, invocations 4-7 will have 0b...11110000 and so on. - */ - brw_reg invoc_ud = bld.vgrf(BRW_TYPE_UD); - bld.MOV(invoc_ud, bld.LOAD_SUBGROUP_INVOCATION()); - brw_reg quad_mask = - bld.SHL(brw_imm_ud(0xF), bld.AND(invoc_ud, brw_imm_ud(0xFFFFFFFC))); - - /* An invocation will have bits set for each quad that passes the - * condition. This is uniform among each quad. - */ - brw_reg tmp = bld.AND(cond_mask, quad_mask); - - if (instr->intrinsic == nir_intrinsic_quad_vote_any) { - bld.CMP(retype(dest, BRW_TYPE_UD), tmp, brw_imm_ud(0), BRW_CONDITIONAL_NZ); - } else { - assert(instr->intrinsic == nir_intrinsic_quad_vote_all); - - /* Filter out quad_mask to include only active channels. */ - brw_reg active = bld.vgrf(BRW_TYPE_UD); - bld.exec_all().emit(SHADER_OPCODE_LOAD_LIVE_CHANNELS, active); - bld.MOV(active, brw_reg(component(active, 0))); - bld.AND(quad_mask, quad_mask, active); - - bld.CMP(retype(dest, BRW_TYPE_UD), tmp, quad_mask, BRW_CONDITIONAL_Z); - } + bld.emit(any ? SHADER_OPCODE_VOTE_ANY : SHADER_OPCODE_VOTE_ALL, + retype(dest, BRW_TYPE_UD), cond, brw_imm_ud(cluster_size)); break; } - case nir_intrinsic_vote_any: { - const fs_builder ubld1 = bld.exec_all().group(1, 0); - - /* The any/all predicates do not consider channel enables. To prevent - * dead channels from affecting the result, we initialize the flag with - * with the identity value for the logical operation. - */ - if (s.dispatch_width == 32) { - /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */ - ubld1.MOV(retype(brw_flag_reg(0, 0), BRW_TYPE_UD), - brw_imm_ud(0)); - } else { - ubld1.MOV(brw_flag_reg(0, 0), brw_imm_uw(0)); - } - bld.CMP(bld.null_reg_d(), get_nir_src(ntb, instr->src[0]), brw_imm_d(0), BRW_CONDITIONAL_NZ); - - /* For some reason, the any/all predicates don't work properly with - * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H - * doesn't read the correct subset of the flag register and you end up - * getting garbage in the second half. Work around this by using a pair - * of 1-wide MOVs and scattering the result. - */ - const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() : ubld1; - brw_reg res1 = ubld.MOV(brw_imm_d(0)); - set_predicate(devinfo->ver >= 20 ? XE2_PREDICATE_ANY : - s.dispatch_width == 8 ? BRW_PREDICATE_ALIGN1_ANY8H : - s.dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ANY16H : - BRW_PREDICATE_ALIGN1_ANY32H, - ubld.MOV(res1, brw_imm_d(-1))); - - bld.MOV(retype(dest, BRW_TYPE_D), component(res1, 0)); - break; - } - case nir_intrinsic_vote_all: { - const fs_builder ubld1 = bld.exec_all().group(1, 0); - - /* The any/all predicates do not consider channel enables. To prevent - * dead channels from affecting the result, we initialize the flag with - * with the identity value for the logical operation. - */ - if (s.dispatch_width == 32) { - /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */ - ubld1.MOV(retype(brw_flag_reg(0, 0), BRW_TYPE_UD), - brw_imm_ud(0xffffffff)); - } else { - ubld1.MOV(brw_flag_reg(0, 0), brw_imm_uw(0xffff)); - } - bld.CMP(bld.null_reg_d(), get_nir_src(ntb, instr->src[0]), brw_imm_d(0), BRW_CONDITIONAL_NZ); - - /* For some reason, the any/all predicates don't work properly with - * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H - * doesn't read the correct subset of the flag register and you end up - * getting garbage in the second half. Work around this by using a pair - * of 1-wide MOVs and scattering the result. - */ - const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() : ubld1; - brw_reg res1 = ubld.MOV(brw_imm_d(0)); - set_predicate(devinfo->ver >= 20 ? XE2_PREDICATE_ALL : - s.dispatch_width == 8 ? BRW_PREDICATE_ALIGN1_ALL8H : - s.dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ALL16H : - BRW_PREDICATE_ALIGN1_ALL32H, - ubld.MOV(res1, brw_imm_d(-1))); - - bld.MOV(retype(dest, BRW_TYPE_D), component(res1, 0)); - break; - } case nir_intrinsic_vote_feq: case nir_intrinsic_vote_ieq: { brw_reg value = get_nir_src(ntb, instr->src[0]); @@ -6737,38 +6616,7 @@ fs_nir_emit_intrinsic(nir_to_brw_state &ntb, value.type = bit_size == 8 ? BRW_TYPE_B : brw_type_with_size(BRW_TYPE_F, bit_size); } - - brw_reg uniformized = bld.emit_uniformize(value); - const fs_builder ubld1 = bld.exec_all().group(1, 0); - - /* The any/all predicates do not consider channel enables. To prevent - * dead channels from affecting the result, we initialize the flag with - * with the identity value for the logical operation. - */ - if (s.dispatch_width == 32) { - /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */ - ubld1.MOV(retype(brw_flag_reg(0, 0), BRW_TYPE_UD), - brw_imm_ud(0xffffffff)); - } else { - ubld1.MOV(brw_flag_reg(0, 0), brw_imm_uw(0xffff)); - } - bld.CMP(bld.null_reg_d(), value, uniformized, BRW_CONDITIONAL_Z); - - /* For some reason, the any/all predicates don't work properly with - * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H - * doesn't read the correct subset of the flag register and you end up - * getting garbage in the second half. Work around this by using a pair - * of 1-wide MOVs and scattering the result. - */ - const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() : ubld1; - brw_reg res1 = ubld.MOV(brw_imm_d(0)); - set_predicate(devinfo->ver >= 20 ? XE2_PREDICATE_ALL : - s.dispatch_width == 8 ? BRW_PREDICATE_ALIGN1_ALL8H : - s.dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ALL16H : - BRW_PREDICATE_ALIGN1_ALL32H, - ubld.MOV(res1, brw_imm_d(-1))); - - bld.MOV(retype(dest, BRW_TYPE_D), component(res1, 0)); + bld.emit(SHADER_OPCODE_VOTE_EQUAL, retype(dest, BRW_TYPE_D), value); break; } diff --git a/src/intel/compiler/brw_fs_validate.cpp b/src/intel/compiler/brw_fs_validate.cpp index 966377c41b8..2d60c50b504 100644 --- a/src/intel/compiler/brw_fs_validate.cpp +++ b/src/intel/compiler/brw_fs_validate.cpp @@ -234,6 +234,9 @@ brw_validate_instruction_phase(const fs_visitor &s, fs_inst *inst) case SHADER_OPCODE_REDUCE: case SHADER_OPCODE_INCLUSIVE_SCAN: case SHADER_OPCODE_EXCLUSIVE_SCAN: + case SHADER_OPCODE_VOTE_ANY: + case SHADER_OPCODE_VOTE_ALL: + case SHADER_OPCODE_VOTE_EQUAL: invalid_from = BRW_SHADER_PHASE_AFTER_EARLY_LOWERING; break; diff --git a/src/intel/compiler/brw_lower_subgroup_ops.cpp b/src/intel/compiler/brw_lower_subgroup_ops.cpp index d3d99ef312c..680ff51c177 100644 --- a/src/intel/compiler/brw_lower_subgroup_ops.cpp +++ b/src/intel/compiler/brw_lower_subgroup_ops.cpp @@ -345,6 +345,172 @@ brw_lower_scan(fs_visitor &s, bblock_t *block, fs_inst *inst) return true; } +static brw_reg +brw_fill_flag(const fs_builder &bld, unsigned v) +{ + const fs_builder ubld1 = bld.exec_all().group(1, 0); + brw_reg flag = brw_flag_reg(0, 0); + + if (bld.shader->dispatch_width == 32) { + /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */ + flag = retype(flag, BRW_TYPE_UD); + ubld1.MOV(flag, brw_imm_ud(v)); + } else { + ubld1.MOV(flag, brw_imm_uw(v & 0xFFFF)); + } + + return flag; +} + +static void +brw_lower_dispatch_width_vote(const fs_builder &bld, enum opcode opcode, brw_reg dst, brw_reg src) +{ + const intel_device_info *devinfo = bld.shader->devinfo; + const unsigned dispatch_width = bld.shader->dispatch_width; + + assert(opcode == SHADER_OPCODE_VOTE_ANY || + opcode == SHADER_OPCODE_VOTE_ALL || + opcode == SHADER_OPCODE_VOTE_EQUAL); + + const bool any = opcode == SHADER_OPCODE_VOTE_ANY; + const bool equal = opcode == SHADER_OPCODE_VOTE_EQUAL; + + const brw_reg ref = equal ? bld.emit_uniformize(src) : brw_imm_d(0); + + /* The any/all predicates do not consider channel enables. To prevent + * dead channels from affecting the result, we initialize the flag with + * with the identity value for the logical operation. + */ + brw_fill_flag(bld, any ? 0 : 0xFFFFFFFF); + bld.CMP(bld.null_reg_d(), src, ref, equal ? BRW_CONDITIONAL_Z + : BRW_CONDITIONAL_NZ); + + /* For some reason, the any/all predicates don't work properly with + * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H + * doesn't read the correct subset of the flag register and you end up + * getting garbage in the second half. Work around this by using a pair + * of 1-wide MOVs and scattering the result. + * + * TODO: Check if we still need this for newer platforms. + */ + const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() + : bld.exec_all().group(1, 0); + brw_reg res1 = ubld.MOV(brw_imm_d(0)); + + enum brw_predicate pred; + if (any) { + pred = devinfo->ver >= 20 ? XE2_PREDICATE_ANY : + dispatch_width == 8 ? BRW_PREDICATE_ALIGN1_ANY8H : + dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ANY16H : + BRW_PREDICATE_ALIGN1_ANY32H; + } else { + pred = devinfo->ver >= 20 ? XE2_PREDICATE_ALL : + dispatch_width == 8 ? BRW_PREDICATE_ALIGN1_ALL8H : + dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ALL16H : + BRW_PREDICATE_ALIGN1_ALL32H; + } + set_predicate(pred, ubld.MOV(res1, brw_imm_d(-1))); + + bld.MOV(retype(dst, BRW_TYPE_D), component(res1, 0)); +} + +static void +brw_lower_quad_vote_gfx9(const fs_builder &bld, enum opcode opcode, brw_reg dst, brw_reg src) +{ + assert(opcode == SHADER_OPCODE_VOTE_ANY || opcode == SHADER_OPCODE_VOTE_ALL); + const bool any = opcode == SHADER_OPCODE_VOTE_ANY; + + /* The any/all predicates do not consider channel enables. To prevent + * dead channels from affecting the result, we initialize the flag with + * with the identity value for the logical operation. + */ + brw_fill_flag(bld, any ? 0 : 0xFFFFFFFF); + bld.CMP(bld.null_reg_ud(), src, brw_imm_ud(0u), BRW_CONDITIONAL_NZ); + bld.exec_all().MOV(retype(dst, BRW_TYPE_UD), brw_imm_ud(0)); + + /* Before Xe2, we can use specialized predicates. */ + const enum brw_predicate pred = any ? BRW_PREDICATE_ALIGN1_ANY4H + : BRW_PREDICATE_ALIGN1_ALL4H; + + fs_inst *mov = bld.MOV(retype(dst, BRW_TYPE_D), brw_imm_d(-1)); + set_predicate(pred, mov); +} + +static void +brw_lower_quad_vote_gfx20(const fs_builder &bld, enum opcode opcode, brw_reg dst, brw_reg src) +{ + assert(opcode == SHADER_OPCODE_VOTE_ANY || opcode == SHADER_OPCODE_VOTE_ALL); + const bool any = opcode == SHADER_OPCODE_VOTE_ANY; + + /* This code is going to manipulate the results of flag mask, so clear it to + * avoid any residual value from disabled channels. + */ + brw_reg flag = brw_fill_flag(bld, 0); + + /* Mask of invocations where condition is true, note that mask is + * replicated to each invocation. + */ + bld.CMP(bld.null_reg_ud(), src, brw_imm_ud(0u), BRW_CONDITIONAL_NZ); + brw_reg cond_mask = bld.vgrf(BRW_TYPE_UD); + bld.MOV(cond_mask, flag); + + /* Mask of invocations in the quad, each invocation will get + * all the bits set for their quad, i.e. invocations 0-3 will have + * 0b...1111, invocations 4-7 will have 0b...11110000 and so on. + */ + brw_reg invoc_ud = bld.vgrf(BRW_TYPE_UD); + bld.MOV(invoc_ud, bld.LOAD_SUBGROUP_INVOCATION()); + brw_reg quad_mask = + bld.SHL(brw_imm_ud(0xF), bld.AND(invoc_ud, brw_imm_ud(0xFFFFFFFC))); + + /* An invocation will have bits set for each quad that passes the + * condition. This is uniform among each quad. + */ + brw_reg tmp = bld.AND(cond_mask, quad_mask); + + if (any) { + bld.CMP(retype(dst, BRW_TYPE_UD), tmp, brw_imm_ud(0), BRW_CONDITIONAL_NZ); + } else { + /* Filter out quad_mask to include only active channels. */ + brw_reg active = bld.vgrf(BRW_TYPE_UD); + bld.exec_all().emit(SHADER_OPCODE_LOAD_LIVE_CHANNELS, active); + bld.MOV(active, brw_reg(component(active, 0))); + bld.AND(quad_mask, quad_mask, active); + + bld.CMP(retype(dst, BRW_TYPE_UD), tmp, quad_mask, BRW_CONDITIONAL_Z); + } +} + +static bool +brw_lower_vote(fs_visitor &s, bblock_t *block, fs_inst *inst) +{ + const fs_builder bld(&s, block, inst); + + brw_reg dst = inst->dst; + brw_reg src = inst->src[0]; + + unsigned cluster_size; + if (inst->sources > 1) { + assert(inst->src[1].file == IMM); + cluster_size = inst->src[1].ud; + } else { + cluster_size = s.dispatch_width; + } + + if (cluster_size == s.dispatch_width) { + brw_lower_dispatch_width_vote(bld, inst->opcode, dst, src); + } else { + assert(cluster_size == 4); + if (s.devinfo->ver < 20) + brw_lower_quad_vote_gfx9(bld, inst->opcode, dst, src); + else + brw_lower_quad_vote_gfx20(bld, inst->opcode, dst, src); + } + + inst->remove(block); + return true; +} + bool brw_fs_lower_subgroup_ops(fs_visitor &s) { @@ -361,6 +527,12 @@ brw_fs_lower_subgroup_ops(fs_visitor &s) progress |= brw_lower_scan(s, block, inst); break; + case SHADER_OPCODE_VOTE_ANY: + case SHADER_OPCODE_VOTE_ALL: + case SHADER_OPCODE_VOTE_EQUAL: + progress |= brw_lower_vote(s, block, inst); + break; + default: /* Nothing to do. */ break; diff --git a/src/intel/compiler/brw_print.cpp b/src/intel/compiler/brw_print.cpp index 1cfdee24a5f..84291b85879 100644 --- a/src/intel/compiler/brw_print.cpp +++ b/src/intel/compiler/brw_print.cpp @@ -297,6 +297,12 @@ brw_instruction_name(const struct brw_isa_info *isa, enum opcode op) return "inclusive_scan"; case SHADER_OPCODE_EXCLUSIVE_SCAN: return "exclusive_scan"; + case SHADER_OPCODE_VOTE_ANY: + return "vote_any"; + case SHADER_OPCODE_VOTE_ALL: + return "vote_all"; + case SHADER_OPCODE_VOTE_EQUAL: + return "vote_equal"; } unreachable("not reached");