diff --git a/.pick_status.json b/.pick_status.json index 8675bcec70b..92d56bcf064 100644 --- a/.pick_status.json +++ b/.pick_status.json @@ -224,7 +224,7 @@ "description": "spirv: Handle constant cooperative matrices in OpCompositeExtract", "nominated": true, "nomination_type": 1, - "resolution": 0, + "resolution": 1, "main_sha": null, "because_sha": "b98f87612bc14fe88184dc099d9d4f8e6b3b23cb", "notes": null diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c index f2c9ac300d9..7c76c915c76 100644 --- a/src/compiler/spirv/spirv_to_nir.c +++ b/src/compiler/spirv/spirv_to_nir.c @@ -2278,31 +2278,38 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode, int elem = -1; const struct vtn_type *type = comp->type; for (unsigned i = deref_start; i < count; i++) { - vtn_fail_if(w[i] > type->length, - "%uth index of %s is %u but the type has only " - "%u elements", i - deref_start, - spirv_op_to_string(opcode), w[i], type->length); + if (type->base_type == vtn_base_type_cooperative_matrix) { + /* Cooperative matrices are always scalar constants. We don't + * care about the index w[i] because it's always replicated. + */ + type = type->component_type; + } else { + vtn_fail_if(w[i] > type->length, + "%uth index of %s is %u but the type has only " + "%u elements", i - deref_start, + spirv_op_to_string(opcode), w[i], type->length); - switch (type->base_type) { - case vtn_base_type_vector: - elem = w[i]; - type = type->array_element; - break; + switch (type->base_type) { + case vtn_base_type_vector: + elem = w[i]; + type = type->array_element; + break; - case vtn_base_type_matrix: - case vtn_base_type_array: - c = &(*c)->elements[w[i]]; - type = type->array_element; - break; + case vtn_base_type_matrix: + case vtn_base_type_array: + c = &(*c)->elements[w[i]]; + type = type->array_element; + break; - case vtn_base_type_struct: - c = &(*c)->elements[w[i]]; - type = type->members[w[i]]; - break; + case vtn_base_type_struct: + c = &(*c)->elements[w[i]]; + type = type->members[w[i]]; + break; - default: - vtn_fail("%s must only index into composite types", - spirv_op_to_string(opcode)); + default: + vtn_fail("%s must only index into composite types", + spirv_op_to_string(opcode)); + } } }