diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index a1e27721934..69ba9dddbaa 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -108,18 +108,20 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, break; } - case SpvOpGroupNonUniformBallotBitExtract: + case SpvOpGroupNonUniformBallotBitExtract: { + nir_def *src0 = vtn_get_nir_ssa(b, w[4]); + nir_def *src1 = vtn_get_nir_ssa(b, w[5]); + nir_def *dest = nir_ballot_bitfield_extract(&b->nb, src0, src1); + vtn_push_nir_ssa(b, w[2], dest); + break; + } + case SpvOpGroupNonUniformBallotBitCount: case SpvOpGroupNonUniformBallotFindLSB: case SpvOpGroupNonUniformBallotFindMSB: { - nir_def *src0, *src1 = NULL; + nir_def *src0 = NULL; nir_intrinsic_op op; switch (opcode) { - case SpvOpGroupNonUniformBallotBitExtract: - op = nir_intrinsic_ballot_bitfield_extract; - src0 = vtn_get_nir_ssa(b, w[4]); - src1 = vtn_get_nir_ssa(b, w[5]); - break; case SpvOpGroupNonUniformBallotBitCount: switch ((SpvGroupOperation)w[4]) { case SpvGroupOperationReduce: @@ -152,8 +154,6 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, nir_intrinsic_instr_create(b->nb.shader, op); intrin->src[0] = nir_src_for_ssa(src0); - if (src1) - intrin->src[1] = nir_src_for_ssa(src1); nir_def_init_for_type(&intrin->instr, &intrin->def, dest_type->type);