intel/brw: Add SHADER_OPCODE_VOTE_*

Add opcodes for VOTE_ALL, VOTE_ANY and VOTE_EQUAL.  The first two
are also used for the quad variants.  Move their lowering from
NIR conversion to brw_lower_subgroup_ops.

Reviewed-by: Kenneth Graunke <kenneth@whitecape.org>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/31029>
Author:  Caio Oliveira, 2024-09-04 10:07:52 -07:00 (committed by Marge Bot)
Parent:  f20df2984d
Commit:  019770f026
Changes: 6 changed files with 210 additions and 162 deletions


@@ -421,6 +421,22 @@ enum opcode {
    SHADER_OPCODE_INCLUSIVE_SCAN,
    SHADER_OPCODE_EXCLUSIVE_SCAN,
 
+   /* Check if any or all values in each subset (cluster) of channels are set,
+    * and broadcast the result to all channels in the subset.
+    *
+    * Source 0: Boolean value.
+    * Source 1: Immediate with cluster size.
+    */
+   SHADER_OPCODE_VOTE_ANY,
+   SHADER_OPCODE_VOTE_ALL,
+
+   /* Check if the values of all channels are equal, and broadcast the result
+    * to all channels.
+    *
+    * Source 0: Value.
+    */
+   SHADER_OPCODE_VOTE_EQUAL,
+
    /* Select between src0 and src1 based on channel enables.
     *
     * This instruction copies src0 into the enabled channels of the

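For readers unfamiliar with the source layout documented above, the following host-side sketch models the any/all cluster semantics: each cluster of cluster_size channels reduces its boolean inputs and broadcasts the result back to every channel of the cluster. This is an illustration only; the helper name and bitmask representation are made up for the example, and the real opcodes operate on flag registers in the EU rather than plain integers.

#include <stdbool.h>
#include <stdint.h>

/* Model of SHADER_OPCODE_VOTE_ANY / SHADER_OPCODE_VOTE_ALL on a bitmask of
 * per-channel booleans: bit i of cond_mask is the Source 0 value of channel i,
 * cluster_size is the Source 1 immediate, and bit i of the return value is
 * what channel i reads back.
 */
static uint32_t
model_vote(bool any, uint32_t cond_mask, unsigned width, unsigned cluster_size)
{
   uint32_t result = 0;

   for (unsigned base = 0; base < width; base += cluster_size) {
      const uint32_t cluster = (cluster_size == 32 ? ~0u
                                                   : (1u << cluster_size) - 1) << base;
      const uint32_t bits = cond_mask & cluster;
      const bool v = any ? bits != 0 : bits == cluster;

      if (v)
         result |= cluster;   /* broadcast to all channels in the cluster */
   }

   return result;
}

A quad vote is simply the cluster_size == 4 case, which matches the immediate the NIR conversion below passes for nir_intrinsic_quad_vote_any/all; SHADER_OPCODE_VOTE_EQUAL has no cluster source and is handled separately by the lowering pass.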

@@ -321,6 +321,9 @@ fs_inst::can_do_source_mods(const struct intel_device_info *devinfo) const
    case SHADER_OPCODE_REDUCE:
    case SHADER_OPCODE_INCLUSIVE_SCAN:
    case SHADER_OPCODE_EXCLUSIVE_SCAN:
+   case SHADER_OPCODE_VOTE_ANY:
+   case SHADER_OPCODE_VOTE_ALL:
+   case SHADER_OPCODE_VOTE_EQUAL:
       return false;
    default:
       return true;


@@ -6590,145 +6590,24 @@ fs_nir_emit_intrinsic(nir_to_brw_state &ntb,
               retype(get_nir_src(ntb, instr->src[0]), BRW_TYPE_F));
       break;
 
+   case nir_intrinsic_vote_any:
+   case nir_intrinsic_vote_all:
    case nir_intrinsic_quad_vote_any:
    case nir_intrinsic_quad_vote_all: {
-      struct brw_reg flag = brw_flag_reg(0, 0);
-      if (s.dispatch_width == 32)
-         flag.type = BRW_TYPE_UD;
+      const bool any = instr->intrinsic == nir_intrinsic_vote_any ||
+                       instr->intrinsic == nir_intrinsic_quad_vote_any;
+      const bool quad = instr->intrinsic == nir_intrinsic_quad_vote_any ||
+                        instr->intrinsic == nir_intrinsic_quad_vote_all;
 
       brw_reg cond = get_nir_src(ntb, instr->src[0]);
+      const unsigned cluster_size = quad ? 4 : s.dispatch_width;
 
-      /* Before Xe2, we can use specialized predicates. */
-      if (devinfo->ver < 20) {
-         const bool any = instr->intrinsic == nir_intrinsic_quad_vote_any;
-
-         /* The any/all predicates do not consider channel enables. To prevent
-          * dead channels from affecting the result, we initialize the flag with
-          * with the identity value for the logical operation.
-          */
-         const unsigned identity = any ? 0 : 0xFFFFFFFF;
-         bld.exec_all().group(1, 0).MOV(flag, retype(brw_imm_ud(identity), flag.type));
-
-         bld.CMP(bld.null_reg_ud(), cond, brw_imm_ud(0u), BRW_CONDITIONAL_NZ);
-         bld.exec_all().MOV(retype(dest, BRW_TYPE_UD), brw_imm_ud(0));
-
-         const enum brw_predicate pred = any ? BRW_PREDICATE_ALIGN1_ANY4H
-                                             : BRW_PREDICATE_ALIGN1_ALL4H;
-
-         fs_inst *mov = bld.MOV(retype(dest, BRW_TYPE_D), brw_imm_d(-1));
-         set_predicate(pred, mov);
-         break;
-      }
-
-      /* This code is going to manipulate the results of flag mask, so clear it to
-       * avoid any residual value from disabled channels.
-       */
-      bld.exec_all().group(1, 0).MOV(flag, retype(brw_imm_ud(0), flag.type));
-
-      /* Mask of invocations where condition is true, note that mask is
-       * replicated to each invocation.
-       */
-      bld.CMP(bld.null_reg_ud(), cond, brw_imm_ud(0u), BRW_CONDITIONAL_NZ);
-      brw_reg cond_mask = bld.vgrf(BRW_TYPE_UD);
-      bld.MOV(cond_mask, flag);
-
-      /* Mask of invocations in the quad, each invocation will get
-       * all the bits set for their quad, i.e. invocations 0-3 will have
-       * 0b...1111, invocations 4-7 will have 0b...11110000 and so on.
-       */
-      brw_reg invoc_ud = bld.vgrf(BRW_TYPE_UD);
-      bld.MOV(invoc_ud, bld.LOAD_SUBGROUP_INVOCATION());
-      brw_reg quad_mask =
-         bld.SHL(brw_imm_ud(0xF), bld.AND(invoc_ud, brw_imm_ud(0xFFFFFFFC)));
-
-      /* An invocation will have bits set for each quad that passes the
-       * condition. This is uniform among each quad.
-       */
-      brw_reg tmp = bld.AND(cond_mask, quad_mask);
-
-      if (instr->intrinsic == nir_intrinsic_quad_vote_any) {
-         bld.CMP(retype(dest, BRW_TYPE_UD), tmp, brw_imm_ud(0), BRW_CONDITIONAL_NZ);
-      } else {
-         assert(instr->intrinsic == nir_intrinsic_quad_vote_all);
-
-         /* Filter out quad_mask to include only active channels. */
-         brw_reg active = bld.vgrf(BRW_TYPE_UD);
-         bld.exec_all().emit(SHADER_OPCODE_LOAD_LIVE_CHANNELS, active);
-         bld.MOV(active, brw_reg(component(active, 0)));
-         bld.AND(quad_mask, quad_mask, active);
-
-         bld.CMP(retype(dest, BRW_TYPE_UD), tmp, quad_mask, BRW_CONDITIONAL_Z);
-      }
+      bld.emit(any ? SHADER_OPCODE_VOTE_ANY : SHADER_OPCODE_VOTE_ALL,
+               retype(dest, BRW_TYPE_UD), cond, brw_imm_ud(cluster_size));
 
       break;
    }
 
-   case nir_intrinsic_vote_any: {
-      const fs_builder ubld1 = bld.exec_all().group(1, 0);
-
-      /* The any/all predicates do not consider channel enables. To prevent
-       * dead channels from affecting the result, we initialize the flag with
-       * with the identity value for the logical operation.
-       */
-      if (s.dispatch_width == 32) {
-         /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */
-         ubld1.MOV(retype(brw_flag_reg(0, 0), BRW_TYPE_UD),
-                   brw_imm_ud(0));
-      } else {
-         ubld1.MOV(brw_flag_reg(0, 0), brw_imm_uw(0));
-      }
-      bld.CMP(bld.null_reg_d(), get_nir_src(ntb, instr->src[0]), brw_imm_d(0), BRW_CONDITIONAL_NZ);
-
-      /* For some reason, the any/all predicates don't work properly with
-       * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H
-       * doesn't read the correct subset of the flag register and you end up
-       * getting garbage in the second half. Work around this by using a pair
-       * of 1-wide MOVs and scattering the result.
-       */
-      const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() : ubld1;
-      brw_reg res1 = ubld.MOV(brw_imm_d(0));
-      set_predicate(devinfo->ver >= 20 ? XE2_PREDICATE_ANY :
-                    s.dispatch_width == 8  ? BRW_PREDICATE_ALIGN1_ANY8H :
-                    s.dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ANY16H :
-                                             BRW_PREDICATE_ALIGN1_ANY32H,
-                    ubld.MOV(res1, brw_imm_d(-1)));
-      bld.MOV(retype(dest, BRW_TYPE_D), component(res1, 0));
-      break;
-   }
-
-   case nir_intrinsic_vote_all: {
-      const fs_builder ubld1 = bld.exec_all().group(1, 0);
-
-      /* The any/all predicates do not consider channel enables. To prevent
-       * dead channels from affecting the result, we initialize the flag with
-       * with the identity value for the logical operation.
-       */
-      if (s.dispatch_width == 32) {
-         /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */
-         ubld1.MOV(retype(brw_flag_reg(0, 0), BRW_TYPE_UD),
-                   brw_imm_ud(0xffffffff));
-      } else {
-         ubld1.MOV(brw_flag_reg(0, 0), brw_imm_uw(0xffff));
-      }
-      bld.CMP(bld.null_reg_d(), get_nir_src(ntb, instr->src[0]), brw_imm_d(0), BRW_CONDITIONAL_NZ);
-
-      /* For some reason, the any/all predicates don't work properly with
-       * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H
-       * doesn't read the correct subset of the flag register and you end up
-       * getting garbage in the second half. Work around this by using a pair
-       * of 1-wide MOVs and scattering the result.
-       */
-      const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() : ubld1;
-      brw_reg res1 = ubld.MOV(brw_imm_d(0));
-      set_predicate(devinfo->ver >= 20 ? XE2_PREDICATE_ALL :
-                    s.dispatch_width == 8  ? BRW_PREDICATE_ALIGN1_ALL8H :
-                    s.dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ALL16H :
-                                             BRW_PREDICATE_ALIGN1_ALL32H,
-                    ubld.MOV(res1, brw_imm_d(-1)));
-      bld.MOV(retype(dest, BRW_TYPE_D), component(res1, 0));
-      break;
-   }
-
    case nir_intrinsic_vote_feq:
    case nir_intrinsic_vote_ieq: {
       brw_reg value = get_nir_src(ntb, instr->src[0]);
@@ -6737,38 +6616,7 @@ fs_nir_emit_intrinsic(nir_to_brw_state &ntb,
          value.type = bit_size == 8 ? BRW_TYPE_B :
             brw_type_with_size(BRW_TYPE_F, bit_size);
       }
-
-      brw_reg uniformized = bld.emit_uniformize(value);
-      const fs_builder ubld1 = bld.exec_all().group(1, 0);
-
-      /* The any/all predicates do not consider channel enables. To prevent
-       * dead channels from affecting the result, we initialize the flag with
-       * with the identity value for the logical operation.
-       */
-      if (s.dispatch_width == 32) {
-         /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */
-         ubld1.MOV(retype(brw_flag_reg(0, 0), BRW_TYPE_UD),
-                   brw_imm_ud(0xffffffff));
-      } else {
-         ubld1.MOV(brw_flag_reg(0, 0), brw_imm_uw(0xffff));
-      }
-      bld.CMP(bld.null_reg_d(), value, uniformized, BRW_CONDITIONAL_Z);
-
-      /* For some reason, the any/all predicates don't work properly with
-       * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H
-       * doesn't read the correct subset of the flag register and you end up
-       * getting garbage in the second half. Work around this by using a pair
-       * of 1-wide MOVs and scattering the result.
-       */
-      const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all() : ubld1;
-      brw_reg res1 = ubld.MOV(brw_imm_d(0));
-      set_predicate(devinfo->ver >= 20 ? XE2_PREDICATE_ALL :
-                    s.dispatch_width == 8  ? BRW_PREDICATE_ALIGN1_ALL8H :
-                    s.dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ALL16H :
-                                             BRW_PREDICATE_ALIGN1_ALL32H,
-                    ubld.MOV(res1, brw_imm_d(-1)));
-
-      bld.MOV(retype(dest, BRW_TYPE_D), component(res1, 0));
+      bld.emit(SHADER_OPCODE_VOTE_EQUAL, retype(dest, BRW_TYPE_D), value);
       break;
    }

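With this change the vote_feq/vote_ieq path above collapses to a single SHADER_OPCODE_VOTE_EQUAL; the old inline sequence (uniformize the value, CMP.Z against it, ALL-channels predicate) moves into the lowering pass shown further down. As a host-side sketch of what that sequence computes (illustrative, made-up helper name; the real feq variant compares with a float type per the bit_size handling kept above):

#include <stdbool.h>
#include <stdint.h>

/* Each active channel compares its value against the value of the first
 * active channel (what emit_uniformize() yields); the vote passes only if
 * every active comparison succeeds, which is what the CMP.Z plus
 * ALL-channels predicate in the lowering pass implements.
 */
static bool
model_vote_equal(const uint32_t *values, uint32_t exec_mask, unsigned width)
{
   uint32_t ref = 0;
   bool have_ref = false;

   for (unsigned i = 0; i < width; i++) {
      if (!(exec_mask & (1u << i)))
         continue;            /* inactive channels do not participate */
      if (!have_ref) {
         ref = values[i];     /* first active channel provides the reference */
         have_ref = true;
      } else if (values[i] != ref) {
         return false;
      }
   }

   return true;
}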

@@ -234,6 +234,9 @@ brw_validate_instruction_phase(const fs_visitor &s, fs_inst *inst)
    case SHADER_OPCODE_REDUCE:
    case SHADER_OPCODE_INCLUSIVE_SCAN:
    case SHADER_OPCODE_EXCLUSIVE_SCAN:
+   case SHADER_OPCODE_VOTE_ANY:
+   case SHADER_OPCODE_VOTE_ALL:
+   case SHADER_OPCODE_VOTE_EQUAL:
       invalid_from = BRW_SHADER_PHASE_AFTER_EARLY_LOWERING;
       break;


@@ -345,6 +345,172 @@ brw_lower_scan(fs_visitor &s, bblock_t *block, fs_inst *inst)
    return true;
 }
 
+static brw_reg
+brw_fill_flag(const fs_builder &bld, unsigned v)
+{
+   const fs_builder ubld1 = bld.exec_all().group(1, 0);
+   brw_reg flag = brw_flag_reg(0, 0);
+
+   if (bld.shader->dispatch_width == 32) {
+      /* For SIMD32, we use a UD type so we fill both f0.0 and f0.1. */
+      flag = retype(flag, BRW_TYPE_UD);
+      ubld1.MOV(flag, brw_imm_ud(v));
+   } else {
+      ubld1.MOV(flag, brw_imm_uw(v & 0xFFFF));
+   }
+
+   return flag;
+}
+
+static void
+brw_lower_dispatch_width_vote(const fs_builder &bld, enum opcode opcode, brw_reg dst, brw_reg src)
+{
+   const intel_device_info *devinfo = bld.shader->devinfo;
+   const unsigned dispatch_width = bld.shader->dispatch_width;
+
+   assert(opcode == SHADER_OPCODE_VOTE_ANY ||
+          opcode == SHADER_OPCODE_VOTE_ALL ||
+          opcode == SHADER_OPCODE_VOTE_EQUAL);
+   const bool any = opcode == SHADER_OPCODE_VOTE_ANY;
+   const bool equal = opcode == SHADER_OPCODE_VOTE_EQUAL;
+
+   const brw_reg ref = equal ? bld.emit_uniformize(src) : brw_imm_d(0);
+
+   /* The any/all predicates do not consider channel enables. To prevent
+    * dead channels from affecting the result, we initialize the flag with
+    * with the identity value for the logical operation.
+    */
+   brw_fill_flag(bld, any ? 0 : 0xFFFFFFFF);
+
+   bld.CMP(bld.null_reg_d(), src, ref, equal ? BRW_CONDITIONAL_Z
+                                             : BRW_CONDITIONAL_NZ);
+
+   /* For some reason, the any/all predicates don't work properly with
+    * SIMD32. In particular, it appears that a SEL with a QtrCtrl of 2H
+    * doesn't read the correct subset of the flag register and you end up
+    * getting garbage in the second half. Work around this by using a pair
+    * of 1-wide MOVs and scattering the result.
+    *
+    * TODO: Check if we still need this for newer platforms.
+    */
+   const fs_builder ubld = devinfo->ver >= 20 ? bld.exec_all()
+                                              : bld.exec_all().group(1, 0);
+   brw_reg res1 = ubld.MOV(brw_imm_d(0));
+
+   enum brw_predicate pred;
+   if (any) {
+      pred = devinfo->ver >= 20   ? XE2_PREDICATE_ANY :
+             dispatch_width == 8  ? BRW_PREDICATE_ALIGN1_ANY8H :
+             dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ANY16H :
+                                    BRW_PREDICATE_ALIGN1_ANY32H;
+   } else {
+      pred = devinfo->ver >= 20   ? XE2_PREDICATE_ALL :
+             dispatch_width == 8  ? BRW_PREDICATE_ALIGN1_ALL8H :
+             dispatch_width == 16 ? BRW_PREDICATE_ALIGN1_ALL16H :
+                                    BRW_PREDICATE_ALIGN1_ALL32H;
+   }
+
+   set_predicate(pred, ubld.MOV(res1, brw_imm_d(-1)));
+   bld.MOV(retype(dst, BRW_TYPE_D), component(res1, 0));
+}
+
+static void
+brw_lower_quad_vote_gfx9(const fs_builder &bld, enum opcode opcode, brw_reg dst, brw_reg src)
+{
+   assert(opcode == SHADER_OPCODE_VOTE_ANY || opcode == SHADER_OPCODE_VOTE_ALL);
+   const bool any = opcode == SHADER_OPCODE_VOTE_ANY;
+
+   /* The any/all predicates do not consider channel enables. To prevent
+    * dead channels from affecting the result, we initialize the flag with
+    * with the identity value for the logical operation.
+    */
+   brw_fill_flag(bld, any ? 0 : 0xFFFFFFFF);
+
+   bld.CMP(bld.null_reg_ud(), src, brw_imm_ud(0u), BRW_CONDITIONAL_NZ);
+   bld.exec_all().MOV(retype(dst, BRW_TYPE_UD), brw_imm_ud(0));
+
+   /* Before Xe2, we can use specialized predicates. */
+   const enum brw_predicate pred = any ? BRW_PREDICATE_ALIGN1_ANY4H
+                                       : BRW_PREDICATE_ALIGN1_ALL4H;
+
+   fs_inst *mov = bld.MOV(retype(dst, BRW_TYPE_D), brw_imm_d(-1));
+   set_predicate(pred, mov);
+}
+
+static void
+brw_lower_quad_vote_gfx20(const fs_builder &bld, enum opcode opcode, brw_reg dst, brw_reg src)
+{
+   assert(opcode == SHADER_OPCODE_VOTE_ANY || opcode == SHADER_OPCODE_VOTE_ALL);
+   const bool any = opcode == SHADER_OPCODE_VOTE_ANY;
+
+   /* This code is going to manipulate the results of flag mask, so clear it to
+    * avoid any residual value from disabled channels.
+    */
+   brw_reg flag = brw_fill_flag(bld, 0);
+
+   /* Mask of invocations where condition is true, note that mask is
+    * replicated to each invocation.
+    */
+   bld.CMP(bld.null_reg_ud(), src, brw_imm_ud(0u), BRW_CONDITIONAL_NZ);
+   brw_reg cond_mask = bld.vgrf(BRW_TYPE_UD);
+   bld.MOV(cond_mask, flag);
+
+   /* Mask of invocations in the quad, each invocation will get
+    * all the bits set for their quad, i.e. invocations 0-3 will have
+    * 0b...1111, invocations 4-7 will have 0b...11110000 and so on.
+    */
+   brw_reg invoc_ud = bld.vgrf(BRW_TYPE_UD);
+   bld.MOV(invoc_ud, bld.LOAD_SUBGROUP_INVOCATION());
+   brw_reg quad_mask =
+      bld.SHL(brw_imm_ud(0xF), bld.AND(invoc_ud, brw_imm_ud(0xFFFFFFFC)));
+
+   /* An invocation will have bits set for each quad that passes the
+    * condition. This is uniform among each quad.
+    */
+   brw_reg tmp = bld.AND(cond_mask, quad_mask);
+
+   if (any) {
+      bld.CMP(retype(dst, BRW_TYPE_UD), tmp, brw_imm_ud(0), BRW_CONDITIONAL_NZ);
+   } else {
+      /* Filter out quad_mask to include only active channels. */
+      brw_reg active = bld.vgrf(BRW_TYPE_UD);
+      bld.exec_all().emit(SHADER_OPCODE_LOAD_LIVE_CHANNELS, active);
+      bld.MOV(active, brw_reg(component(active, 0)));
+      bld.AND(quad_mask, quad_mask, active);
+
+      bld.CMP(retype(dst, BRW_TYPE_UD), tmp, quad_mask, BRW_CONDITIONAL_Z);
+   }
+}
+
+static bool
+brw_lower_vote(fs_visitor &s, bblock_t *block, fs_inst *inst)
+{
+   const fs_builder bld(&s, block, inst);
+
+   brw_reg dst = inst->dst;
+   brw_reg src = inst->src[0];
+
+   unsigned cluster_size;
+   if (inst->sources > 1) {
+      assert(inst->src[1].file == IMM);
+      cluster_size = inst->src[1].ud;
+   } else {
+      cluster_size = s.dispatch_width;
+   }
+
+   if (cluster_size == s.dispatch_width) {
+      brw_lower_dispatch_width_vote(bld, inst->opcode, dst, src);
+   } else {
+      assert(cluster_size == 4);
+      if (s.devinfo->ver < 20)
+         brw_lower_quad_vote_gfx9(bld, inst->opcode, dst, src);
+      else
+         brw_lower_quad_vote_gfx20(bld, inst->opcode, dst, src);
+   }
+
+   inst->remove(block);
+
+   return true;
+}
+
 bool
 brw_fs_lower_subgroup_ops(fs_visitor &s)
 {
@@ -361,6 +527,12 @@ brw_fs_lower_subgroup_ops(fs_visitor &s)
          progress |= brw_lower_scan(s, block, inst);
          break;
 
+      case SHADER_OPCODE_VOTE_ANY:
+      case SHADER_OPCODE_VOTE_ALL:
+      case SHADER_OPCODE_VOTE_EQUAL:
+         progress |= brw_lower_vote(s, block, inst);
+         break;
+
       default:
          /* Nothing to do. */
          break;

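Two details of the lowering pass above are easier to see outside the IR builder. The snippet below is a host-side illustration with made-up helper names (not part of the change): first, why brw_fill_flag() pre-loads the flag register with the identity value (0 for ANY, all ones for ALL) before the CMP; second, how the SHL/AND pair in brw_lower_quad_vote_gfx20() derives each invocation's quad mask.

#include <stdbool.h>
#include <stdint.h>

/* Why the identity prefill matters: disabled channels never write their CMP
 * result into the flag register, so for an ALL vote they must read back as 1
 * and for an ANY vote as 0, or dead channels would corrupt the reduction done
 * by the ALLnH/ANYnH predicate.
 */
static bool
model_flag_vote(bool any, uint32_t exec_mask, uint32_t cond_mask, unsigned width)
{
   uint32_t flag = any ? 0u : ~0u;                       /* identity prefill  */
   flag = (flag & ~exec_mask) | (cond_mask & exec_mask); /* CMP writes only
                                                            enabled channels  */
   const uint32_t lanes = width == 32 ? ~0u : (1u << width) - 1;
   return any ? (flag & lanes) != 0 : (flag & lanes) == lanes;
}

/* The quad mask built by bld.SHL(0xF, bld.AND(invocation, 0xFFFFFFFC)):
 * invocations 0-3 get 0x0000000F, 4-7 get 0x000000F0, and so on, so ANDing it
 * with the replicated condition mask isolates the bits of the caller's quad.
 */
static uint32_t
model_quad_mask(uint32_t invocation)
{
   return 0xFu << (invocation & 0xFFFFFFFCu);
}

Both helpers only mirror the comments in the pass; the real code operates on f0.0/f0.1 and the EU predication hardware rather than plain integers.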

@@ -297,6 +297,12 @@ brw_instruction_name(const struct brw_isa_info *isa, enum opcode op)
       return "inclusive_scan";
    case SHADER_OPCODE_EXCLUSIVE_SCAN:
       return "exclusive_scan";
+   case SHADER_OPCODE_VOTE_ANY:
+      return "vote_any";
+   case SHADER_OPCODE_VOTE_ALL:
+      return "vote_all";
+   case SHADER_OPCODE_VOTE_EQUAL:
+      return "vote_equal";
    }
 
    unreachable("not reached");