spirv: ensure ballot find_lsb/find_msb/bit_count have 32bit result

Reviewed-by: Alyssa Rosenzweig <alyssa@rosenzweig.io>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/37178>
This commit is contained in:
Georg Lehmann 2025-09-03 14:16:38 +02:00 committed by Marge Bot
parent f8633511be
commit 516c766c71

View file

@ -119,47 +119,43 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
case SpvOpGroupNonUniformBallotBitCount:
case SpvOpGroupNonUniformBallotFindLSB:
case SpvOpGroupNonUniformBallotFindMSB: {
nir_def *src0 = NULL;
nir_intrinsic_op op;
nir_def *dest;
switch (opcode) {
case SpvOpGroupNonUniformBallotBitCount:
case SpvOpGroupNonUniformBallotBitCount: {
nir_def *src = vtn_get_nir_ssa(b, w[5]);
switch ((SpvGroupOperation)w[4]) {
case SpvGroupOperationReduce:
op = nir_intrinsic_ballot_bit_count_reduce;
dest = nir_ballot_bit_count_reduce(&b->nb, src);
break;
case SpvGroupOperationInclusiveScan:
op = nir_intrinsic_ballot_bit_count_inclusive;
dest = nir_ballot_bit_count_inclusive(&b->nb, src);
break;
case SpvGroupOperationExclusiveScan:
op = nir_intrinsic_ballot_bit_count_exclusive;
dest = nir_ballot_bit_count_exclusive(&b->nb, src);
break;
default:
UNREACHABLE("Invalid group operation");
}
src0 = vtn_get_nir_ssa(b, w[5]);
break;
case SpvOpGroupNonUniformBallotFindLSB:
op = nir_intrinsic_ballot_find_lsb;
src0 = vtn_get_nir_ssa(b, w[4]);
}
default: {
nir_def *src = vtn_get_nir_ssa(b, w[4]);
switch (opcode) {
case SpvOpGroupNonUniformBallotFindLSB:
dest = nir_ballot_find_lsb(&b->nb, src);
break;
case SpvOpGroupNonUniformBallotFindMSB:
dest = nir_ballot_find_msb(&b->nb, src);
break;
default:
UNREACHABLE("Unhandled opcode");
}
break;
case SpvOpGroupNonUniformBallotFindMSB:
op = nir_intrinsic_ballot_find_msb;
src0 = vtn_get_nir_ssa(b, w[4]);
break;
default:
UNREACHABLE("Unhandled opcode");
}
}
nir_intrinsic_instr *intrin =
nir_intrinsic_instr_create(b->nb.shader, op);
intrin->src[0] = nir_src_for_ssa(src0);
nir_def_init_for_type(&intrin->instr, &intrin->def,
dest_type->type);
nir_builder_instr_insert(&b->nb, &intrin->instr);
vtn_push_nir_ssa(b, w[2], &intrin->def);
dest = nir_i2iN(&b->nb, dest, glsl_get_bit_size(dest_type->type));
vtn_push_nir_ssa(b, w[2], dest);
break;
}