diff --git a/src/compiler/nir/nir_lower_subgroups.c b/src/compiler/nir/nir_lower_subgroups.c index 33ba7821752..c809a1783d6 100644 --- a/src/compiler/nir/nir_lower_subgroups.c +++ b/src/compiler/nir/nir_lower_subgroups.c @@ -78,8 +78,18 @@ static nir_ssa_def * uint_to_ballot_type(nir_builder *b, nir_ssa_def *value, unsigned num_components, unsigned bit_size) { - value = nir_bitcast_vector(b, value, bit_size); - return nir_pad_vector_imm_int(b, value, 0, num_components); + assert(util_is_power_of_two_nonzero(num_components)); + assert(util_is_power_of_two_nonzero(value->num_components)); + + /* The ballot type must always have enough bits */ + unsigned total_bits = bit_size * num_components; + assert(total_bits >= value->bit_size * value->num_components); + + /* If the source doesn't have enough bits, zero-pad */ + if (total_bits > value->bit_size * value->num_components) + value = nir_pad_vector_imm_int(b, value, 0, total_bits / value->bit_size); + + return nir_bitcast_vector(b, value, bit_size); } static nir_ssa_def *