diff --git a/src/compiler/spirv/vtn_subgroup.c b/src/compiler/spirv/vtn_subgroup.c index 69ba9dddbaa..9f676a1e238 100644 --- a/src/compiler/spirv/vtn_subgroup.c +++ b/src/compiler/spirv/vtn_subgroup.c @@ -119,47 +119,43 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode, case SpvOpGroupNonUniformBallotBitCount: case SpvOpGroupNonUniformBallotFindLSB: case SpvOpGroupNonUniformBallotFindMSB: { - nir_def *src0 = NULL; - nir_intrinsic_op op; + nir_def *dest; switch (opcode) { - case SpvOpGroupNonUniformBallotBitCount: + case SpvOpGroupNonUniformBallotBitCount: { + nir_def *src = vtn_get_nir_ssa(b, w[5]); switch ((SpvGroupOperation)w[4]) { case SpvGroupOperationReduce: - op = nir_intrinsic_ballot_bit_count_reduce; + dest = nir_ballot_bit_count_reduce(&b->nb, src); break; case SpvGroupOperationInclusiveScan: - op = nir_intrinsic_ballot_bit_count_inclusive; + dest = nir_ballot_bit_count_inclusive(&b->nb, src); break; case SpvGroupOperationExclusiveScan: - op = nir_intrinsic_ballot_bit_count_exclusive; + dest = nir_ballot_bit_count_exclusive(&b->nb, src); break; default: UNREACHABLE("Invalid group operation"); } - src0 = vtn_get_nir_ssa(b, w[5]); break; - case SpvOpGroupNonUniformBallotFindLSB: - op = nir_intrinsic_ballot_find_lsb; - src0 = vtn_get_nir_ssa(b, w[4]); + } + default: { + nir_def *src = vtn_get_nir_ssa(b, w[4]); + switch (opcode) { + case SpvOpGroupNonUniformBallotFindLSB: + dest = nir_ballot_find_lsb(&b->nb, src); + break; + case SpvOpGroupNonUniformBallotFindMSB: + dest = nir_ballot_find_msb(&b->nb, src); + break; + default: + UNREACHABLE("Unhandled opcode"); + } break; - case SpvOpGroupNonUniformBallotFindMSB: - op = nir_intrinsic_ballot_find_msb; - src0 = vtn_get_nir_ssa(b, w[4]); - break; - default: - UNREACHABLE("Unhandled opcode"); + } } - nir_intrinsic_instr *intrin = - nir_intrinsic_instr_create(b->nb.shader, op); - - intrin->src[0] = nir_src_for_ssa(src0); - - nir_def_init_for_type(&intrin->instr, &intrin->def, - dest_type->type); - nir_builder_instr_insert(&b->nb, &intrin->instr); - - vtn_push_nir_ssa(b, w[2], &intrin->def); + dest = nir_i2iN(&b->nb, dest, glsl_get_bit_size(dest_type->type)); + vtn_push_nir_ssa(b, w[2], dest); break; }